首页 > 其他分享 >yolov8模型转为onnx后的推理测试(分为两个py文件)

yolov8模型转为onnx后的推理测试(分为两个py文件)

时间:2024-08-14 16:06:20浏览次数:6  
标签:img onnx image py yolov8 class input label self

点击查看代码
import torch

from ultralytics.utils import ASSETS, yaml_load
from ultralytics.utils.checks import check_requirements, check_yaml
import numpy as np
import cv2
import onnxruntime as ort

class YOLOv8:
    """YOLOv8 object detection model class for handling inference and visualization."""

    def __init__(self, onnx_model, input_image, confidence_thres, iou_thres):
        """
        Initializes an instance of the YOLOv8 class.

        Args:
            onnx_model: Path to the ONNX model.
            input_image: Path to the input image.
            confidence_thres: Confidence threshold for filtering detections.
            iou_thres: IoU (Intersection over Union) threshold for non-maximum suppression.
        """
        self.onnx_model = onnx_model
        self.input_image = input_image
        self.confidence_thres = confidence_thres
        self.iou_thres = iou_thres

        # Load the class names from the COCO dataset
        self.classes = yaml_load(check_yaml("mycoco.yaml"))["names"]

        # Generate a color palette for the classes
        self.color_palette = np.random.uniform(0, 255, size=(len(self.classes), 3))

    def draw_detections(self, img, box, score, class_id):
        """
        Draws bounding boxes and labels on the input image based on the detected objects.

        Args:
            img: The input image to draw detections on.
            box: Detected bounding box.
            score: Corresponding detection score.
            class_id: Class ID for the detected object.

        Returns:
            None
        """

        # Extract the coordinates of the bounding box
        x1, y1, w, h = box

        # Retrieve the color for the class ID
        color = self.color_palette[class_id]

        # Draw the bounding box on the image
        cv2.rectangle(img, (int(x1), int(y1)), (int(x1 + w), int(y1 + h)), color, 2)

        # Create the label text with class name and score
        label = f"{self.classes[class_id]}: {score:.2f}"

        # Calculate the dimensions of the label text
        (label_width, label_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)

        # Calculate the position of the label text
        label_x = x1
        label_y = y1 - 10 if y1 - 10 > label_height else y1 + 10

        # Draw a filled rectangle as the background for the label text
        cv2.rectangle(
            img, (label_x, label_y - label_height), (label_x + label_width, label_y + label_height), color, cv2.FILLED
        )

        # Draw the label text on the image
        cv2.putText(img, label, (label_x, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)

    def preprocess(self):
        """
        Preprocesses the input image before performing inference.

        Returns:
            image_data: Preprocessed image data ready for inference.
        """
        # Read the input image using OpenCV
        self.img = cv2.imread(self.input_image)

        # Get the height and width of the input image
        self.img_height, self.img_width = self.img.shape[:2]

        # Convert the image color space from BGR to RGB
        img = cv2.cvtColor(self.img, cv2.COLOR_BGR2RGB)

        # Resize the image to match the input shape
        img = cv2.resize(img, (self.input_width, self.input_height))

        # Normalize the image data by dividing it by 255.0
        image_data = np.array(img) / 255.0

        # Transpose the image to have the channel dimension as the first dimension
        image_data = np.transpose(image_data, (2, 0, 1))  # Channel first

        # Expand the dimensions of the image data to match the expected input shape
        image_data = np.expand_dims(image_data, axis=0).astype(np.float32)

        # Return the preprocessed image data
        return image_data

    def postprocess(self, input_image, output):
        """
        Performs post-processing on the model's output to extract bounding boxes, scores, and class IDs.

        Args:
            input_image (numpy.ndarray): The input image.
            output (numpy.ndarray): The output of the model.

        Returns:
            numpy.ndarray: The input image with detections drawn on it.
        """

        # Transpose and squeeze the output to match the expected shape
        outputs = np.transpose(np.squeeze(output[0]))

        # Get the number of rows in the outputs array
        rows = outputs.shape[0]

        # Lists to store the bounding boxes, scores, and class IDs of the detections
        boxes = []
        scores = []
        class_ids = []

        # Calculate the scaling factors for the bounding box coordinates
        x_factor = self.img_width / self.input_width
        y_factor = self.img_height / self.input_height

        # Iterate over each row in the outputs array
        for i in range(rows):
            # Extract the class scores from the current row
            classes_scores = outputs[i][4:]

            # Find the maximum score among the class scores
            max_score = np.amax(classes_scores)

            # If the maximum score is above the confidence threshold
            if max_score >= self.confidence_thres:
                # Get the class ID with the highest score
                class_id = np.argmax(classes_scores)

                # Extract the bounding box coordinates from the current row
                x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3]

                # Calculate the scaled coordinates of the bounding box
                left = int((x - w / 2) * x_factor)
                top = int((y - h / 2) * y_factor)
                width = int(w * x_factor)
                height = int(h * y_factor)

                # Add the class ID, score, and box coordinates to the respective lists
                class_ids.append(class_id)
                scores.append(max_score)
                boxes.append([left, top, width, height])

        # Apply non-maximum suppression to filter out overlapping bounding boxes
        indices = cv2.dnn.NMSBoxes(boxes, scores, self.confidence_thres, self.iou_thres)

        # Iterate over the selected indices after non-maximum suppression
        for i in indices:
            # Get the box, score, and class ID corresponding to the index
            box = boxes[i]
            score = scores[i]
            class_id = class_ids[i]

            # Draw the detection on the input image
            self.draw_detections(input_image, box, score, class_id)

        # Return the modified input image
        return input_image

    def main(self):
        """
        Performs inference using an ONNX model and returns the output image with drawn detections.

        Returns:
            output_img: The output image with drawn detections.
        """
        # Create an inference session using the ONNX model and specify execution providers
        session = ort.InferenceSession(self.onnx_model, providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
        # Get the model inputs
        model_inputs = session.get_inputs()
        # Store the shape of the input for later use
        input_shape = model_inputs[0].shape
        self.input_width = input_shape[2]
        self.input_height = input_shape[3]
        # Preprocess the image data
        img_data = self.preprocess()
        # Run inference using the preprocessed image data
        outputs = session.run(None, {model_inputs[0].name: img_data})
        # Perform post-processing on the outputs to obtain output image.
        return self.postprocess(self.img, outputs)  # output image

class YOLOv8Inference:
    def __init__(self, model_path, conf_thres, iou_thres):
        # Check the requirements and select the appropriate backend (CPU or GPU)
        check_requirements("onnxruntime-gpu" if torch.cuda.is_available() else "onnxruntime")

        # Initialize the YOLOv8 instance
        self.detection = YOLOv8(model_path, None, conf_thres, iou_thres)

    def process_image(self, img_path):
        # Set the image path for the YOLOv8 instance
        print("207:img_path",img_path)
        self.detection.input_image = img_path

        # Perform object detection and obtain the output image
        output_image = self.detection.main()

        return output_image
点击查看代码
# Ultralytics YOLO 

标签:img,onnx,image,py,yolov8,class,input,label,self
From: https://www.cnblogs.com/littlecute555/p/18359188

相关文章

  • YOLOv8改进系列,YOLOv8添加EMA注意力机制,并与C2f融合
    原文摘要在各种计算机视觉任务中,通道或空间注意力机制在生成更具辨识度的特征表示方面表现出显著的效果。然而,通过通道维度减少来建模跨通道关系可能会在提取深度视觉表示时带来副作用。本文提出了一种新颖的高效多尺度注意力(EMA)模块。该模块重点在于保留每个通道的信息......
  • YOLOv8改进系列,YOLOv8替换主干网络为MobileNetV2(轻量化架构+助力涨点)
    原论文摘要MobileNetV2架构在多个任务和基准测试中提高了移动模型的最先进性能,并在不同的模型规模中表现出色。我们还介绍了在一种我们称之为SSDLite的新框架中应用这些移动模型进行目标检测的高效方法。MobileNetV2理论详解可以参考链接:论文地址本文在YOLOv8中的主干......
  • Python - 详情介绍Zmail发送邮件(支持普通&企业邮箱,163、QQ、gmail...)
    Python-详情介绍Zmail发送邮件为了满足在python项目中收发邮件给其他人,可利用自己的邮箱账号结合Zmail来完成。Zmail使得在python3中发送和接受邮件变得更简单。你不需要手动添加服务器地址、端口以及适合的协议。Zmail仅支持python3,不需要任何外部依赖.不支持python2......
  • springboot+vue《Python数据分析》的教学系统【程序+论文+开题】-计算机毕业设计
    系统程序文件列表开题报告内容研究背景随着大数据时代的到来,数据分析技能已成为各行各业不可或缺的核心竞争力之一。Python,作为一门高效、灵活且拥有丰富数据分析库的编程语言,正逐步成为数据分析领域的主流工具。然而,当前高等教育体系中,《Python数据分析》课程的教学仍面临......
  • Robyn与FastAPI全面对比:选择最适合你的Python Web框架
    引言1.1背景介绍在当今的软件开发领域,选择合适的Web框架对于项目的成功至关重要。Python作为一种广泛使用的编程语言,其生态系统中涌现出了众多优秀的Web框架,如FastAPI和Robyn。FastAPI自发布以来,因其高性能、易用性和自动生成API文档的特性,迅速成为开发者的首选。而Robyn......
  • 3163:【例27.3】 第几项(C、C++、python)
    3163:【例27.3】第几项信息学奥赛一本通-编程启蒙(C++版)在线评测系统[例27.3]第几项2020:【例4.5】第几项信息学奥赛一本通(C++版)在线评测系统27.3_哔哩哔哩_bilibiliC语言代码:#include<stdio.h>#include<stdlib.h>intmain(){intm,s=0,n=0;s......
  • Python轻量级 NoSQL 数据库之tinydb使用详解
    概要在现代应用开发中,使用数据库来存储和管理数据是非常常见的需求。对于简单的数据存储需求,关系型数据库可能显得过于复杂。TinyDB是一个纯Python实现的轻量级NoSQL数据库,专为嵌入式场景设计,适用于小型项目、原型开发和教学等场景。本文将详细介绍TinyDB库,包括其安......
  • Python之sys.argv功能使用详解
    概要在Python编程中,命令行参数是与程序交互的重要方式之一。通过命令行参数,用户可以在运行脚本时传递输入值,从而影响程序的行为。Python提供了一个非常方便的模块——sys,其中的sys.argv列表可以轻松地获取命令行参数。在本文中,将深入探讨sys.argv的使用方法,结合实际示例展示......
  • Python编程中不可忽视的docstring使用详解
    概要在Python编程中,代码的可读性和可维护性至关重要。除了清晰的命名和结构良好的代码外,良好的文档字符串(docstring)也是确保代码易于理解和使用的关键工具。docstring是Python中用于记录模块、类、方法和函数行为的字符串,帮助开发者和用户快速了解代码的功能和用法。本文将......
  • 一个Web服务器及python作web开发的框架:Tornado 托内科及python提示报错:ImportError:
    一、一个Web服务器及python作web开发的框架:Tornado托内科    tornado,是使用Python编写的一个强大的、可扩展的Web服务器及Python作web开发框架。网上说Tornado和现在的主流Web服务器框架(包括大多数Python的框架)有着明显的区别:它是非阻塞式服务器,而且速度相当快。得利......