首页 > 其他分享 >读loss.py

读loss.py

时间:2022-12-31 22:55:17浏览次数:32  
标签:loss py pred torch 目标 num label anchor

阅读utils/loss.py,掌握YOLO算法的loss计算过程

这个损失函数跟YOLO-V5的损失函数相同,最关键的函数是build_target原理很抽象,代码更抽象。想要读懂耗时大概2~3天。由于学长时间有限,所以就只把最关键的部分讲一下。

原理解析

对于某一张输入图片,YOLO算法把图片划分成多个网格(e.g. 13x13),然后根据图片中目标所在的位置,使用对应位置的网格来进行预测,这一步骤实际上可以理解成:把目标分配给某一个网格。在使用anchor的情况下,每一个网格包含多个anchor,此时还要把目标分配给网格中的某一个anchor。所以整体上看就是把目标分配给anchor,称为“anchor assign”。这就是build_target函数所做的事情。它大致分成以下几个步骤:
  1. 取出某一个批次图片中所有的目标。
  2. 根据每一个目标的位置x,y,找到负责预测它的anchor。
  3. 通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值
  4. 计算每个anchor的损失值,然后就可以反向传播了。

代码解析

  1. 取出某一个批次图片中所有的目标

utils/datasets.py中,有一个函数:
def collate_fn(batch): 
    img, label = zip(*batch) 
    for i, l in enumerate(label): 
        if l.shape[0] > 0: 
            l[:, 0] = i 
    return torch.stack(img), torch.cat(label, 0)
各个变量的含义如下:
  • batch:DataLoader调用batch_size次TensorDataset类的__getitem__函数,获得的返回值放在一起就是batch。
  • img,label:zip(*xxx)利用 * 号操作符,可以将元组解压为列表。__getitem__函数会返回一张图片跟它的标签,图片形状是(通道数,高,宽),简写成(c, h, w),标签是个矩阵,形状是(目标个数,6)。每一个目标本来只有5个属性x, y, w, h, class_id,但是多预留一个属性,用来保存这个目标所在的图片在这一个批次中的索引。
  • 返回值:有俩。第一个是把图片拼起来,形状是(batch_size, c, h, w)。第二个是批次中所有的目标,是一个矩阵,形状是(所有目标数量,6)。
这个函数的返回值就包含了这个批次图片中所有的目标。 注:“(所有目标数量,6)”中的6,指的就是每个目标的属性:"x,y,w,h,class_id和该图片在这个batch中的索引"
  1. 把每一个目标分配给合适的anchor

utils/loss.pybuild_target函数完成这件事儿。细分成以下几个步骤:
  1. 函数输入值
preds:网络的输出值,长度为6的元组。 targets:函数collate_fn的第二个返回值,形状是(所有目标数量,6)。
  1. 加载配置文件中设定的anchor
#加载anchor配置 
anchors = np.array(cfg["anchors"]) 
anchors = torch.from_numpy(anchors.reshape(len(preds) // 3, anchor_num, 2)).to(device)
变量anchors的形状是(尺度数量,anchor数量,2),2表示每一个anchor的宽高。如果你不知道尺度是什么意思,说明看的不认真,往回看看。
  1. 谜语部分
gain = torch.ones(7, device = device)

at = torch.arange(anchor_num, device = device).float().view(anchor_num, 1).repeat(1, label_num)
targets = torch.cat((targets.repeat(anchor_num, 1, 1), at[:, :, None]), 2)

g = 0.5  # bias
off = torch.tensor([[0, 0],
                    [1, 0], [0, 1], [-1, 0], [0, -1],  # j,k,l,m
                    # [1, 1], [1, -1], [-1, 1], [-1, -1],  # jk,jm,lk,lm
                   ], device = device).float() * g  # offsets
这一部分代码充分体现了乱起变量名带来的后果。各个变量的含义:
  • gain:后面就知道了,用来存在特征图的大小。形状是(7, ),并且初始化为1。
  • at:后面就知道了。形状是(anchor_num,label_num),label_num就是整个批次图片包含的目标数量。可以查一下repeat函数的作用。
  • targets批次中所有的目标。各个表达返回值的形状:
      • targets.repeat(anchor_num, 1, 1):(anchor_num, label_num, 6)
      • at[:, :, None]:(anchor_num,label_num, 1)
      • torch.cat:(anchor_num,label_num, 7),这也是最终targets的形状。
      由于每一个目标(label)都会分配给某一个网格,每一个网格中的anchor数量为anchor_num,所以targets先把目标复制了anchor_num遍,等价于把目标同时分配给了网格内的所有anchor。然后后面的代码中会以"向量化"的方式判断这几个anchor是否适合这个目标。   targets形状最后一个维度7个分量的含义是:(img_index, x, y, w, h, class_id, anchor_id)。7个分量描述了某一个目标,比如说它所在的图片索引,它被分配给了哪一个anchor。
  • goff:后面就知道了。
  1. 关键部分
进入for循环。首先是:
 
#将label坐标映射到特征图上 gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]] gt = targets * gain
 
各个表达式的含义:
  • torch.tensor(pred.shape)[[3, 2, 3, 2]]:用了一个非常简单的fancy indexing,pred.shape返回值是(n, c, h, w),w, h是特征图的宽高,由于这个特征图是网络预测值,所以如果网格划分是13x13的话,那么h, w都是13。索引2对应的是h,3对应的是w,所以这个表达式返回值是(w, h, w, h)。
  • gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]]:把(w, h, w, h)赋值给gain,此时gain的内容是:(1, 1, w, h, w, h, 1)
  • gt = targets * gain:targets的形状是(anchor_num,label_num, 7),7个分量的含义是:(img_index, x, y, w, h, class_id, anchor_id),乘上gain后会发生"数组广播操作",相当于这7个分量跟gain对应位置相乘,把坐标从原来的百分比坐标转成了特征图坐标。gt的形状是(anchor_num,label_num, 7),仍然表示批次内所有的目标。举个例子,算了不举了。
然后是:
 
#anchor iou匹配 r = gt[:, :, 4:6] / anchors_cfg[:, None] j = torch.max(r, 1. / r).max(2)[0] < 2 t = gt[j]
 
各行代码的含义:
  • 第一行:之前咱们把某一个目标分配给了所属的网格中所有的anchor,这一行就是计算目标跟分配的anchor的长宽比。
  • 第二行:根据长宽比判断是否合适,返回的j是一个长度跟gt一样的一维bool数组,用来表示每一个目标跟分配给它的anchor是否合适。合适的标准就是目标大小跟anchor差不多。
  • 第三行:过滤出那些跟分配的anchor很合适,已经找到了自己的归宿的目标。
最后是:
 
#扩充维度并复制数据 # Offsets gxy = t[:, 2:4] # grid xy gxi = gain[[2, 3]] - gxy # inverse j, k = ((gxy % 1. < g) & (gxy > 1.)).T l, m = ((gxi % 1. < g) & (gxi > 1.)).T j = torch.stack((torch.ones_like(j), j, k, l, m)) t = t.repeat((5, 1, 1))[j] offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
 
首先需要知道YoloV5所提出的特殊的anchor分配策略,这个网上资料很多,所以大家自行查找就行。简单来说就是:
  1. 如果目标的中心点(x,y)落在所属网格的左上方,那么除了所属的网格外,额外让左边跟上面的两个网格也预测这个目标。
  2. 如果目标的中心点(x,y)落在所属网格的右上方,那么除了所属的网格外,额外让右边跟上面的两个网格也预测这个目标。
  3. 以此类推。
这一小部分代码大家自己看,难点在于怎么用高维数组操作来表示上面的扩充过程,网上(知乎,bilibili)也有一些解析,搜索yolov5 build_target就能找到。 后面的函数返回值部分应该很简单就看懂了。
  1. 通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值

对应compute_loss函数中的这段代码:
 
#构建gt tcls, tbox, indices, anchors = build_target(preds, targets, cfg, device) for i, pred in enumerate(preds): #计算reg分支loss if i % 3 == 0: pred = pred.reshape(pred.shape[0], cfg["anchor_num"], -1, pred.shape[2], pred.shape[3]) pred = pred.permute(0, 1, 3, 4, 2) #判断当前batch数据是否有gt if len(indices): b, a, gj, gi = indices[layer_index[i]] nb = b.shape[0] if nb: ps = pred[b, a, gj, gi]
 
首先接收build_target函数的返回值,也就是anchor assign的结果。然后把返回值作为索引取出模型预测值的对应部分。这一行就解释了build_target函数的最终目的。 后面就是计算损失值了。

标签:loss,py,pred,torch,目标,num,label,anchor
From: https://www.cnblogs.com/yuxiyuxi/p/17017496.html

相关文章

  • 在vmware启动anolis自制iso镜像出现 ERROR: could not insert 'floppy' : No such dev
    解决linux启动出现ERROR:couldnotinsert'floppy':Nosuchdevice报错如下:最后屏幕狂刷Warning:dracut-initqueuetimeout-startingtimeoutscript如下:......
  • 如何播放无损音乐格式tak(Tom's lossless Audio Kompressor)?
    无损格式tak,全称Tom'slosslessAudioKompressorwindows下,下载TAKSDK:官方解码器解压后Applications文件夹里面有Tak.exe点击Decompress->AddFiles->Decompress......
  • AGC006D Median Pyramid Hard
    ​​\(AGC006D\)\(Median\)\(Pyramid\)\(Hard\)​​一、题目描述二、题目解析这道例题看到时毫无头绪,因为课程是二分,所以往二分的方向想,猜到是二分枚举最上面的那个数是......
  • Python之路【第六篇】:socket
    1.Socketsocket通常也称作"套接字",用于描述IP地址和端口,是一个通信链的句柄,应用程序通常通过"套接字"向网络发出请求或者应答网络请求。socket起源于Unix,而Unix/Lin......
  • 在pycharm里debug以学习huggingface/transformers
    把https://github.com/huggingface/transformers整个zip下载下来把src/transformers文件夹复制出来,放pycharm里,成这样:根据https://github.com/huggingface/transform......
  • NumPy - 入门
    目录NumPy,是NumericalPython的简称,它是目前Python数值计算中最为重要的基础包。大多数计算包都提供了基于NumPy的科学函数功能,将NumPy的数组对象作为数据交换的通......
  • Python爬虫学习经历
    requests模块1.处理一个UA反爬importrequestscontent=input("请输入你要搜索的内容:")url=f"https://www.sogou.com/web?query={content}"headers={#添加......
  • 掌握Python中列表生成式的五个原因
    1.引言在Python中我们往往使用列表生成式来代替for循环,本文通过引入实际例子,来阐述这背后的原因。闲话少说,我们直接开始吧!2.简洁性列表生成式允许我们在一行代码中创建一......
  • Python__19--函数调用的参数传递与变量的作用域
    1函数调用的参数传递形参(形式参数):在函数定义的时候用到的参数没有具体值,只是一个占位的符号,成为形参;实参(实际参数):在调用函数的时候输入的值。实际参数和形式参......
  • pycharm配置python环境
    打开pycharm,找到settings中的pythoninterpreter点击addinterpreter,addlocalinterpreter选择python安装路径就可以了......