首页 > 其他分享 >yolov5 mAP计算代码分析

yolov5 mAP计算代码分析

时间:2024-08-06 16:40:37浏览次数:5  
标签:mAP yolov5 False 300 代码 0.00000 matches True iou

前言

模型训练过程中每一轮都会计算P,R,mAP,[email protected]等数值,本篇分析这些数值的计算过程,分析最核心部分。我的感受是计算的过程比想象的复杂。
主要的流程在yolov5/val.py文件的process_batch处理函数中。

 if nl:
    tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
    scale_boxes(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
    labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
    correct = process_batch(predn, labelsn, iouv, im[si])
    if plots:
        confusion_matrix.process_batch(predn, labelsn)

计算[email protected]的计算基本原理是:

  1. 给定一个阈值数组,从0.5-0.95每间隔0.05,生成一共10个数据
  2. 获取预测结果和标注,计算预测框和gt框的IOU,判断类别是否一致。
  3. 用阈值数组中的每一个阈值去筛选IOU,大于阈值的预测为True, 小于IOU的预测为False
  4. 统计结果,输出结果

入参分析

detections:bbox,预测框信息,格式为:x1, y1, x2, y2, conf, class
labels:gt框信息,格式为:class, x1, y1, x2, y2
iouv:阈值数组,内容是:[0.50000, 0.55000, 0.60000, 0.65000, 0.70000, 0.75000, 0.80000, 0.85000, 0.90000, 0.95000]

def process_batch(detections, labels, iouv, img=None):
    """
    Return correct prediction matrix
    Arguments:
        detections (array[N, 6]), x1, y1, x2, y2, conf, class
        labels (array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (array[N, 10]), for 10 IoU levels
    """

预测框的可视化:

构造返回结果数组

对于每一个预测框,在不同的阈值下会判断成True 或 False。构造一个所有预测框在所有阈值下的结果二维数组。

    # 预测正确的记录,所有iouv [0.5, 0.55, 0.6 ... 0.95]阈值下iou大于阈值且类型匹配的结果。
    # shape:(300, 10)  300个预测,10个阈值
    correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)
  1. detections.shape[0]: 300
  2. iouv.shape[0]:10

最终的结果可以理解成300个预测框在10个不同阈值下一共会有3000个结果,结果为布尔数组。

计算所有预测框和标注框的IOU结果

将gt框和预测框传入box_iou函数,得到结果。

"""
计算gt框和预测的iou,这个过程会出现广播机制,5个gt框分别和300个预测框计算iou,得到的结果为5,300
(Pdb) pp detections[:, :4].shape
torch.Size([300, 4])
(Pdb) pp labels[:, 1:].shape
torch.Size([5, 4])
(Pdb) pp iou.shape
torch.Size([5, 300])
"""
iou = box_iou(labels[:, 1:], detections[:, :4])
  1. labels[:, 1:]:(5, 4)
  2. detections[:, :4]:(300, 4)

5个标注框和300个预测框匹配,也就是拿每一个标注框都和300个预测框匹配,得到5*300个结果。相当于每一行代表这一个标注框和预测框的iou匹配结果。

tensor([[0.00000, 0.00000, 0.78682, 0.02270, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.02236, 0.00000, 0.00000, 0.00000, 0.00000, 0.43876, 0.70469, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000,
         0.00000, 0.00000, 0.00000, 0.04540, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.46606, 0.00000, 0.00000, 0.00000, 0.00000, 0.13432, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.00000, 0.53255,

匹配标签

上一步计算了预测框和标注框的iou,这一步计算预测类别和标注类别的匹配。

    
    """
    计算标签匹配。labels[:, 0:1]代表所有标注的标签。用每一个gt的标签和所有的预测的标签匹配。
    过程存在广播机制,结果为:torch.Size([5, 300])
    结果中 0维也就是5是gt的label  1维也就是300是预测label
    (Pdb) labels[:, 0:1].shape
    torch.Size([5, 1])
    (Pdb) detections[:, 5].shape
    torch.Size([300])
    correct_class.shape
    torch.Size([5, 300])
    """
    correct_class = labels[:, 0:1] == detections[:, 5]

labels[:, 0:1]:[5, 1]
detections[:, 5]: [300]
和上一步类似,每一个标注类表和预测类别匹配,5个标注类别和300个预测类别匹配,得到一个5*300的结果。得到的是一个布尔数组

(Pdb) correct_class
tensor([[ True,  True,  True,  True,  True,  True,  True, False,  True,  True, False, False,  True, False,  True,  True, False, False,  True, False,  True, False,  True,  True,  True,  True,  True,  True, False, False,  True,  True,  True, False,  True,  True, False,  True,  True, False,  True, False, False,  True,
         False,  True,  True,  True,  True,  True,  True,  True,  True, False, False, False,  True,  True, False,  True,  True, False, False,  True,  True,  True, False,  True,  True,  True,  True, False,  True,  True,  True,  True, False, False,  True,  True,  True,  True,  True,  True,  True,  True,  True, False

遍历所有阈值计算在不同的阈值下预测正确的个数

阈值数组是[0.50000, 0.55000, 0.60000, 0.65000, 0.70000, 0.75000, 0.80000, 0.85000, 0.90000, 0.95000],计算在不同的阈值下预测框和标注框的iou匹配结果。以0.5为例,大于0.5为True,小于0.5为False

# 循环遍历所有的iou阈值,计算在不同的阈值下预测正确的个数
for i in range(len(iouv)):
    # (iou >= iouv[i]) & correct_class 得到一个布尔数组
    # torch.where((iou >= iouv[i]) & correct_class) 筛选出布尔数组中为True的元素,返回和布尔数组结构一致,返回为True元素在当前张量中的坐标。当前为二维数组。
    # 返回两个元素,元素1是数组,代表为True元素的x轴下标, 元素2也是数组,代表为True元素的y轴下标。
    # 类似:(tensor([0, 0, 1, 1]), tensor([0, 1, 1, 3]))
    x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match

(iou >= iouv[i]) & correct_class
iou 是 [5, 300]的iou结果
correct_class 是[5, 300]的类别结果
将两者的判断与操作,得到的就是iou大于阈值,且类别正确的结果。结果是5*300的布尔数组

torch.where((iou >= iouv[i]) & correct_class)
筛选出布尔数组中为True的元素,返回Ture在张量中的坐标。
x的结果如下,结果代表的含义是第一个张量代表标注标签的ID,第二个张量代表匹配的对应的预测标签的ID。

(tensor([0, 0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 4, 4, 5], device='cuda:0'),
 tensor([ 2, 67, 68,  3, 79,  5, 31, 90,  4, 86,  1, 85, 92,  0], device='cuda:0'))

判断是否有正确匹配的结果

# 判断是否有结果
if x[0].shape[0]:

    """
    将结果汇总成[label, detect, iou]的格式。
    label:gt的标签的下标
    detect: 预测结果标签的下标
    iou: 预测和gt的iou值
    
    torch.stack(x, 1) 将x轴和y轴表示的坐标转换成[x, y]的格式
    iou[x[0], x[1]][:, None]) 获取iou中匹配上的元素值
    """
    matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]

torch.stack(x, 1)将标注框的ID和预测框的ID组合在一次
iou[x[0], x[1]][:, None]筛选出对应的IOU值
将两者组合在一起,得到的就是[label_id, detect_id, iou]

去除重复匹配

一个gt框可能和多个预测框的iou都大于阈值,这是完全可能的。这里只保留最好的匹配,所以把那些多余的匹配都去掉。一个gt框只保留一个预测框,同时一个预测框只对应一个gt框。

# 匹配存在多个元素,可能是因为广播机制导致的重复,去除重复。
if x[0].shape[0] > 1:
    # 对置信度排序,倒序排列
    matches = matches[matches[:, 2].argsort()[::-1]]

    # 对预测detect进行去重,一个预测框只保留一个iou匹配。返回去重之后的matches
    matches = matches[np.unique(matches[:, 1], return_index=True)[1]]

    # 对label进行去重,一个gt框只保留一个最高iou匹配,返回去重之后的matches
    # matches = matches[matches[:, 2].argsort()[::-1]]
    matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

保存结果

correct是300*10的二维数组,将对应行和列的元素设置为True。

# 将correct中相应的元素置为True
correct[matches[:, 1].astype(int), i] = True
  1. matches[:, 1].astype(int)预测框的下标
  2. i某一个阈值

将某一个阈值下对应的预测框设置成True,也就是代表着在该阈值下,预测框预测的类别正确,同时IOU也超过阈值。

完整代码

def process_batch(detections, labels, iouv, img=None):
    """
    Return correct prediction matrix
    Arguments:
        detections (array[N, 6]), x1, y1, x2, y2, conf, class
        labels (array[M, 5]), class, x1, y1, x2, y2
    Returns:
        correct (array[N, 10]), for 10 IoU levels
    """

    # 预测正确的记录,所有iouv [0.5, 0.55, 0.6 ... 0.95]阈值下iou大于阈值且类型匹配的结果。
    # shape:(300, 10)  300个预测,10个阈值
    correct = np.zeros((detections.shape[0], iouv.shape[0])).astype(bool)

    """
    计算gt框和预测的iou,这个过程会出现广播机制,5个gt框分别和300个预测框计算iou,得到的结果为5,300
    (Pdb) pp detections[:, :4].shape
    torch.Size([300, 4])
    (Pdb) pp labels[:, 1:].shape
    torch.Size([5, 4])
    (Pdb) pp iou.shape
    torch.Size([5, 300])
    """
    iou = box_iou(labels[:, 1:], detections[:, :4])

    breakpoint()

    # 计算标签匹配。labels[:, 0:1]代表所有标注的标签。用每一个gt的标签和所有的预测的标签匹配。
    # 过程存在广播机制,结果为:torch.Size([5, 300])
    # 结果中 0维也就是5是gt的label  1维也就是300是预测label
    """
    (Pdb) labels[:, 0:1].shape
    torch.Size([5, 1])
    (Pdb) detections[:, 5].shape
    torch.Size([300])
    correct_class.shape
    torch.Size([5, 300])
    """
    correct_class = labels[:, 0:1] == detections[:, 5]

    # 循环遍历所有的iou阈值,计算在不同的阈值下预测正确的个数
    for i in range(len(iouv)):
        # (iou >= iouv[i]) & correct_class 得到一个布尔数组
        # torch.where((iou >= iouv[i]) & correct_class) 筛选出布尔数组中为True的元素,返回和布尔数组结构一致,返回为True元素在当前张量中的坐标。当前为二维数组。
        # 返回两个元素,元素1是数组,代表为True元素的x轴下标, 元素2也是数组,代表为True元素的y轴下标。
        # 类似:(tensor([0, 0, 1, 1]), tensor([0, 1, 1, 3]))
        x = torch.where((iou >= iouv[i]) & correct_class)  # IoU > threshold and classes match

        # 判断是否有结果
        if x[0].shape[0]:

            """
            将结果汇总成[label, detect, iou]的格式。
            label:gt的标签的下标
            detect: 预测结果标签的下标
            iou: 预测和gt的iou值
            
            torch.stack(x, 1) 将x轴和y轴表示的坐标转换成[x, y]的格式
            iou[x[0], x[1]][:, None]) 获取iou中匹配上的元素值
            """
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]

            # 匹配存在多个元素,可能是因为广播机制导致的重复,去除重复。
            if x[0].shape[0] > 1:
                # 对置信度排序,倒序排列
                matches = matches[matches[:, 2].argsort()[::-1]]

                # 对预测detect进行去重,一个预测框只保留一个iou匹配。返回去重之后的matches
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]

                # 对label进行去重,一个gt框只保留一个最高iou匹配,返回去重之后的matches
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

            # 将correct中相应的元素置为True
            correct[matches[:, 1].astype(int), i] = True

            """
            (Pdb) pp correct
            array([[False, False, False, ..., False, False, False],
                   [False, False, False, ..., False, False, False],
                   [False, False, False, ..., False, False, False],
                   ...,
                   [False, False, False, ..., False, False, False],
                   [False, False, False, ..., False, False, False],
                   [False, False, False, ..., False, False, False]])
            (Pdb) pp correct.shape
            (300, 10)
            """

    return torch.tensor(correct, dtype=torch.bool, device=iouv.device)

标签:mAP,yolov5,False,300,代码,0.00000,matches,True,iou
From: https://www.cnblogs.com/goldsunshine/p/18345485

相关文章

  • 毕业设计:基于Springboot的宿舍管理系统微信小程序【代码+论文+PPT】
    全文内容包括:1、采用技术;2、系统功能;3、系统截图;4、配套内容。索取方式见文末微信号,欢迎关注收藏!一、采用技术语言:Java1.8框架:SpringBoot数据库:MySQL5.7、8.0开发工具:IntelliJIDEA旗舰版、微信开发工具其他:Maven3.8以上二、系统功能报修管理:学生可提交宿舍报修申请,管理......
  • 我正在 python 中使用 aspose.pdf 将 pdf 转换为 excel 。但问题是它只能将 pdf 的前
    `从tkinter导入*将aspose.pdf导入为ap从tkinter导入文件对话框importpandasaspdinput_pdf=filedialog.askopenfilename(filetypes=(("PDF文件",".pdf"),("所有文件",".")))output_file=filedialog.asksaveasfil......
  • 【C++/STL】map和set的封装(红黑树)
     ......
  • kimi写代码:tls singleton
    #include<iostream>#include<mutex>#include<string>#include<thread>classThreadLocalSingleton{private:ThreadLocalSingleton(){std::cout<<"createdforthread"<<std::this_thread::g......
  • 利用miniprogram-ci工具实现一键上传微信小程序代码
    本文由ChatMoney团队出品利用miniprogram-ci工具在后台实现一键上传微信小程序代码,避免了微信开发者工具的繁琐。一、部署node环境我用的是宝塔,可以直接在宝塔上安装Node.js版本管理器二、安装miniprogram-cinpminstallminiprogram-ci--save安装在指定文件夹里,这个......
  • 代码随想录Day7
    454.四数相加Ⅱ给你四个整数数组nums1、nums2、nums3和nums4,数组长度都是n,请你计算有多少个元组(i,j,k,l)能满足:0<=i,j,k,l<nnums1[i]+nums2[j]+nums3[k]+nums4[l]==0示例1:输入:nums1=[1,2],nums2=[-2,-1],nums3=[-1,2],nums4=[0,2]输......
  • 代码随想录Day6
    454.四数相加Ⅱ给你四个整数数组nums1、nums2、nums3和nums4,数组长度都是n,请你计算有多少个元组(i,j,k,l)能满足:0<=i,j,k,l<nnums1[i]+nums2[j]+nums3[k]+nums4[l]==0示例1:输入:nums1=[1,2],nums2=[-2,-1],nums3=[-1,2],nums4=[0,2]输......
  • 6 大推荐给开发者的无代码工具
    在不断发展的软件开发领域,无代码工具正迅速普及。最初,这些工具是为非技术背景的业务用户设计的,而如今,它们对开发者来说也同样不可或缺。无代码工具结合了效率、灵活性和创新性,让开发者能够在无需编写传统代码的情况下快速构建应用程序。那么,为什么开发者也应该考虑使用无代码工......
  • Java集合:Collection and Map;ArrayList;LinkList;HashSet;TreeSet;HashMap;TreeMap;Iterator:
        集合介绍:                        是一组变量类型(容器),跟数组很像。一,引用集合的原因(必要性):                  A:数组的空间长度固定,一旦确定不可以更改。多了浪费,少了报错。          B:使用数......
  • MapperScannerConfigurer中获取applicayion.yml配置,进行动态加载BasePackage
     由于在MapperScannerConfigurer的bean优先于@value,导致@value取出来的时候都是null,所以只能使用Environment来获取值importorg.mybatis.spring.mapper.MapperScannerConfigurer;importorg.springframework.beans.factory.annotation.Value;importorg.springframework......