Reproducing the YOLOv5 inference module

Date: 2022-11-27 10:12:29
Tags: box index yolov5 inverse image torch grid reproduction module

import cv2
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as T

checkpoint = torch.load("D:/yolov5m.pt", map_location="cpu")
model = checkpoint["model"].float()
model.eval()

model.fuse()
model.model[-1].export = True

# Run inference in Python; convert the image from BGR to RGB

image = cv2.imread("inference/images/zidane.jpg")
show = image.copy()
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# The model was trained on 640*640 images, so for prediction the long side is also scaled to 640
train_image_size = 640
image_height, image_width = image.shape[:2]
scale = train_image_size / max(image_height, image_width)
x_offset = train_image_size * 0.5 - image_width * scale * 0.5
y_offset = train_image_size * 0.5 - image_height * scale * 0.5

M = np.array([
    [scale, 0, x_offset],
    [0, scale, y_offset]
], dtype=np.float32)
inverse_M = cv2.invertAffineTransform(M)
image = cv2.warpAffine(image, M, (train_image_size, train_image_size), borderMode=cv2.BORDER_CONSTANT,
                       borderValue=(114, 114, 114))

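# The image needs its dimensions reordered and its values normalized
# to_tensor: for integer (uint8) input it divides by 255 and converts the layout to C,H,W
# to_tensor: for floating-point input it only converts the layout to C,H,W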
image = T.to_tensor(image).unsqueeze(dim=0)

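# For the TensorRT conversion the Detect decoding is done by hand and the Focus module in the model was changed,
# so the Focus slicing has to be applied to the image here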

image = torch.cat([
    image[..., ::2, ::2],
    image[..., 1::2, ::2],
    image[..., ::2, 1::2],
    image[..., 1::2, 1::2]
], dim=1)

# Everything above is image preprocessing; now the inference stage
with torch.no_grad():
    predicts = model(image)

# Decode the predictions back into boxes
image_objects = []
for level_index in range(3):
    predict = predicts[level_index]
    stride = model.model[-1].stride[level_index]
    anchor = model.model[-1].anchors[level_index]

    # predict.shape is 1*255*80*80; 255 = 3*(5+80), with 80 classes here
    num_classes = int(predict.shape[1] / 3 - 5)
    predict.sigmoid_()

    threshold = 0.25
    num_anchor = 3

    # Iterate over the anchors of this stride level
    for anchor_index in range(num_anchor):
        channel_begin = anchor_index * (5 + num_classes)
        # Box regression values
        regression = predict[0, channel_begin + 0:channel_begin + 4]
        # Objectness score
        objectness = predict[0, channel_begin + 4]
        # Per-class scores
        classifier = predict[0, channel_begin + 5:channel_begin + 5 + num_classes]
        # torch.where returns a tuple holding the row (y) and column (x) indices of the
        # entries in objectness that satisfy the condition
        # a = np.array([
        #     [1, 1, 0],
        #     [1, 1, 0]
        # ])
        # a = torch.tensor(a)
        # y1, x1 = torch.where(a >= 0.5)
        # y1 = tensor([0, 0, 1, 1])  x1 = tensor([0, 1, 0, 1])
        grid_y, grid_x = torch.where(objectness >= threshold)
        if len(grid_y) == 0:
            continue
        select_classifier = classifier[:, grid_y, grid_x]
        max_class_id = select_classifier.argmax(dim=0)
        max_class_score = select_classifier[max_class_id, torch.arange(len(max_class_id))]
        select_objectness_score = objectness[grid_y, grid_x]
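        # This is how the official YOLOv5 computes the confidence
        select_object_confidence = select_objectness_score * max_class_score

        # Second filter: keep only detections whose combined confidence reaches the threshold
        # torch.where returns a tuple, so take element [0] to get the index tensor
        keep_object_index = torch.where(select_object_confidence >= threshold)[0]
        if len(keep_object_index) == 0:
            continue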
        object_confidence = select_object_confidence[keep_object_index]
        object_class = max_class_id[keep_object_index]
        grid_x = grid_x[keep_object_index]
        grid_y = grid_y[keep_object_index]

        # YOLOv5 regresses box values relative to the grid cell (starting from 0),
        # so the network only predicts offsets
        grid_xy = torch.stack([grid_x, grid_y], dim=0)
        box_cx, box_cy = (regression[:2, grid_y, grid_x].view(2, -1) * 2 - 0.5 + grid_xy) * stride
        anchor_wh = anchor[anchor_index].view(2, 1)
        box_width, box_height = torch.pow(regression[2:4, grid_y, grid_x] * 2, 2) * anchor_wh * stride

        box_left = box_cx - (box_width - 1) * 0.5
        box_right = box_cx + (box_width - 1) * 0.5
        box_top = box_cy - (box_height - 1) * 0.5
        box_bottom = box_cy + (box_height - 1) * 0.5

        box_source_left = box_left * inverse_M[0, 0] + box_top * inverse_M[0, 1] + inverse_M[0, 2]
        box_source_top = box_left * inverse_M[1, 0] + box_top * inverse_M[1, 1] + inverse_M[1, 2]
        box_source_right = box_right * inverse_M[0, 0] + box_bottom * inverse_M[0, 1] + inverse_M[0, 2]
        box_source_bottom = box_right * inverse_M[1, 0] + box_bottom * inverse_M[1, 1] + inverse_M[1, 2]

        objs = torch.stack([
            box_source_left,
            box_source_top,
            box_source_right,
            box_source_bottom,
            object_confidence,
            object_class.float()  # cast the int64 class ids so every stacked tensor has the same dtype
        ], dim=1)
        image_objects.append(objs)

# image_objs = torch.cat(image_objects, dim=0)
#
# for left, top, right, bottom, confidence, class_id in image_objs:
#     cv2.rectangle(show,
#                   (int(left), int(top)),
#                   (int(right), int(bottom)),
#                   (0, 255, 0),
#                   2)
# cv2.imwrite("show.jpg", show)
# This yields many boxes, so NMS is needed
# NMS is applied within each class, not across classes
image_objs = torch.cat(image_objects, dim=0)
max_image_size = 4096
# Columns: left, top, right, bottom, confidence, class.
# Shift each box by class * max_image_size so boxes of different classes never overlap
nms_input_box = image_objs[:, :4] + image_objs[:, 5][:, None] * max_image_size
# boxes: Tensor, scores: Tensor, iou_threshold: float
keep_index = torchvision.ops.nms(nms_input_box, image_objs[:, 4], 0.5)
image_objs = image_objs[keep_index]

for left, top, right, bottom, confidence, class_id in image_objs:
    cv2.rectangle(show,
                  (int(left), int(top)),
                  (int(right), int(bottom)),
                  (0, 255, 0),
                  2)
cv2.imwrite("show.jpg", show)
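
The letterbox step in the listing builds a scale-plus-translation matrix `M` and later maps boxes back with `inverse_M`. The following is a minimal, dependency-free sketch of that round trip. The helper names (`letterbox_matrix`, `invert_scale_translate`, `apply_affine`) and the 1280x720 image size are assumptions made for illustration; the listing itself uses `cv2.invertAffineTransform` and `cv2.warpAffine`.

```python
def letterbox_matrix(image_width, image_height, target_size=640):
    """Return the 2x3 matrix that scales the long side to target_size and centers the image."""
    scale = target_size / max(image_width, image_height)
    x_offset = target_size * 0.5 - image_width * scale * 0.5
    y_offset = target_size * 0.5 - image_height * scale * 0.5
    return [[scale, 0.0, x_offset],
            [0.0, scale, y_offset]]


def invert_scale_translate(matrix):
    """Invert a matrix of the form [[s, 0, tx], [0, s, ty]] (no rotation or shear)."""
    scale = matrix[0][0]
    return [[1.0 / scale, 0.0, -matrix[0][2] / scale],
            [0.0, 1.0 / scale, -matrix[1][2] / scale]]


def apply_affine(matrix, x, y):
    """Map the point (x, y) through a 2x3 affine matrix."""
    new_x = matrix[0][0] * x + matrix[0][1] * y + matrix[0][2]
    new_y = matrix[1][0] * x + matrix[1][1] * y + matrix[1][2]
    return new_x, new_y


# Hypothetical example: a 1280x720 source image (size chosen for round numbers).
M = letterbox_matrix(1280, 720)
inverse_M = invert_scale_translate(M)

# The center of the source image lands on the center of the 640x640 canvas.
center = apply_affine(M, 640, 360)
# Applying the inverse recovers the source coordinates.
restored = apply_affine(inverse_M, *center)
print(M)
print(center, restored)
```

Because the matrix contains only a uniform scale and a translation, the inverse can be written in closed form, which is why the box corners in the listing can be mapped back to the source image one coordinate at a time.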
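
The decoding formulas in the listing can be checked by hand for a single grid cell. The sketch below repeats the same arithmetic with plain Python numbers. The function name `decode_box` and the example values (cell (10, 20), stride 8, an anchor of 1.25 x 1.625 grid units) are assumptions for illustration; the anchor is taken to be already divided by the stride, as the final multiplication by `stride` in the listing implies.

```python
def decode_box(tx, ty, tw, th, grid_x, grid_y, anchor_w, anchor_h, stride):
    """Decode sigmoid-activated outputs of one grid cell into a box in network-input pixels.

    tx, ty, tw, th are values in (0, 1) after the sigmoid.
    anchor_w and anchor_h are in grid units (pixels divided by the stride).
    """
    box_cx = (tx * 2 - 0.5 + grid_x) * stride
    box_cy = (ty * 2 - 0.5 + grid_y) * stride
    box_width = (tw * 2) ** 2 * anchor_w * stride
    box_height = (th * 2) ** 2 * anchor_h * stride
    return box_cx, box_cy, box_width, box_height


# Hypothetical example: every output equals 0.5, i.e. "no offset, anchor-sized box".
box = decode_box(0.5, 0.5, 0.5, 0.5,
                 grid_x=10, grid_y=20,
                 anchor_w=1.25, anchor_h=1.625,
                 stride=8)
print(box)
```

With outputs restricted to (0, 1), the center can move between -0.5 and 1.5 cells from the cell origin, and the width and height can range from 0 to 4 times the anchor size.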

Be careful and thorough; when a computation goes wrong, track down the error patiently.
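
The NMS step in the listing adds `class * max_image_size` to every box coordinate before calling `torchvision.ops.nms`, so that boxes of different classes can never overlap. The sketch below illustrates the effect with a small pure-Python greedy NMS. The helpers `iou`, `nms`, and `per_class_nms` are hypothetical stand-ins written for this example, not the torchvision implementation.

```python
def iou(box_a, box_b):
    """Intersection over union of two (left, top, right, bottom) boxes."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def nms(boxes, scores, iou_threshold):
    """Greedy NMS: return the indices of the kept boxes, highest score first."""
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for index in order:
        # Keep a box only if it does not overlap any already-kept box too much.
        if all(iou(boxes[index], boxes[kept]) <= iou_threshold for kept in keep):
            keep.append(index)
    return keep


def per_class_nms(boxes, scores, class_ids, iou_threshold, max_image_size=4096):
    """Shift each box by class_id * max_image_size, then run ordinary NMS."""
    shifted = [
        tuple(coord + class_id * max_image_size for coord in box)
        for box, class_id in zip(boxes, class_ids)
    ]
    return nms(shifted, scores, iou_threshold)


# Hypothetical detections: two near-duplicates of class 0 and one box of class 1
# at exactly the same position as the first.
boxes = [(10, 10, 110, 110), (12, 12, 112, 112), (10, 10, 110, 110)]
scores = [0.9, 0.8, 0.7]
class_ids = [0, 0, 1]

agnostic_keep = nms(boxes, scores, 0.5)
per_class_keep = per_class_nms(boxes, scores, class_ids, 0.5)
print(agnostic_keep, per_class_keep)
```

Without the shift the class 1 box is suppressed by the identical class 0 box; with the shift it survives. torchvision also provides `torchvision.ops.batched_nms(boxes, scores, idxs, iou_threshold)`, which performs per-class NMS directly and may be a simpler alternative to the manual offset.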

From: https://www.cnblogs.com/xiaoruirui/p/16929051.html
