阅读utils/loss.py
,掌握YOLO算法的loss计算过程
这个损失函数跟YOLO-V5的损失函数相同,最关键的函数是build_target
。原理很抽象,代码更抽象。想要读懂耗时大概2~3天。由于学长时间有限,所以就只把最关键的部分讲一下。
原理解析
对于某一张输入图片,YOLO算法把图片划分成多个网格(e.g. 13x13),然后根据图片中目标所在的位置,使用对应位置的网格来进行预测,这一步骤实际上可以理解成:把目标分配给某一个网格。在使用anchor的情况下,每一个网格包含多个anchor,此时还要把目标分配给网格中的某一个anchor。所以整体上看就是把目标分配给anchor,称为“anchor assign”。这就是build_target函数所做的事情。它大致分成以下几个步骤:- 取出某一个批次图片中所有的目标。
- 根据每一个目标的位置x,y,找到负责预测它的anchor。
- 通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值
- 计算每个anchor的损失值,然后就可以反向传播了。
代码解析
-
取出某一个批次图片中所有的目标
utils/datasets.py
中,有一个函数:
def collate_fn(batch):
img, label = zip(*batch)
for i, l in enumerate(label):
if l.shape[0] > 0:
l[:, 0] = i
return torch.stack(img), torch.cat(label, 0)
各个变量的含义如下:
-
batch
:DataLoader调用batch_size次TensorDataset类的__getitem__函数,获得的返回值放在一起就是batch。 -
img,label
:zip(*xxx)利用 * 号操作符,可以将元组解压为列表。__getitem__函数会返回一张图片跟它的标签,图片形状是(通道数,高,宽),简写成(c, h, w),标签是个矩阵,形状是(目标个数,6)。每一个目标本来只有5个属性x, y, w, h, class_id,但是多预留一个属性,用来保存这个目标所在的图片在这一个批次中的索引。 - 返回值:有俩。第一个是把图片拼起来,形状是(batch_size, c, h, w)。第二个是批次中所有的目标,是一个矩阵,形状是(所有目标数量,6)。
-
把每一个目标分配给合适的anchor
utils/loss.py
的build_target
函数完成这件事儿。细分成以下几个步骤:
- 函数输入值
preds
:网络的输出值,长度为6的元组。
targets
:函数collate_fn
的第二个返回值,形状是(所有目标数量,6)。
- 加载配置文件中设定的anchor
#加载anchor配置
anchors = np.array(cfg["anchors"])
anchors = torch.from_numpy(anchors.reshape(len(preds) // 3, anchor_num, 2)).to(device)
变量anchors
的形状是(尺度数量,anchor数量,2),2表示每一个anchor的宽高。如果你不知道尺度是什么意思,说明看的不认真,往回看看。
- 谜语部分
gain = torch.ones(7, device = device)
at = torch.arange(anchor_num, device = device).float().view(anchor_num, 1).repeat(1, label_num)
targets = torch.cat((targets.repeat(anchor_num, 1, 1), at[:, :, None]), 2)
g = 0.5 # bias
off = torch.tensor([[0, 0],
[1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m
# [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
], device = device).float() * g # offsets
这一部分代码充分体现了乱起变量名带来的后果。各个变量的含义:
-
gain
:后面就知道了,用来存在特征图的大小。形状是(7, ),并且初始化为1。 -
at
:后面就知道了。形状是(anchor_num,label_num),label_num就是整个批次图片包含的目标数量。可以查一下repeat函数的作用。 -
targets
:批次中所有的目标。各个表达返回值的形状:-
targets.repeat(anchor_num, 1, 1)
:(anchor_num, label_num, 6) -
at[:, :, None]
:(anchor_num,label_num, 1) -
torch.cat
:(anchor_num,label_num, 7),这也是最终targets的形状。
targets
先把目标复制了anchor_num遍,等价于把目标同时分配给了网格内的所有anchor。然后后面的代码中会以"向量化"的方式判断这几个anchor是否适合这个目标。 targets形状最后一个维度7个分量的含义是:(img_index, x, y, w, h, class_id, anchor_id)。7个分量描述了某一个目标,比如说它所在的图片索引,它被分配给了哪一个anchor。 -
-
g
,off
:后面就知道了。
- 关键部分
#将label坐标映射到特征图上 gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]] gt = targets * gain
各个表达式的含义:
-
torch.tensor(pred.shape)[[3, 2, 3, 2]]
:用了一个非常简单的fancy indexing,pred.shape返回值是(n, c, h, w),w, h是特征图的宽高,由于这个特征图是网络预测值,所以如果网格划分是13x13的话,那么h, w都是13。索引2对应的是h,3对应的是w,所以这个表达式返回值是(w, h, w, h)。 -
gain[2:6] = torch.tensor(pred.shape)[[3, 2, 3, 2]]
:把(w, h, w, h)赋值给gain,此时gain的内容是:(1, 1, w, h, w, h, 1) -
gt = targets * gain
:targets的形状是(anchor_num,label_num, 7),7个分量的含义是:(img_index, x, y, w, h, class_id, anchor_id),乘上gain后会发生"数组广播操作",相当于这7个分量跟gain对应位置相乘,把坐标从原来的百分比坐标转成了特征图坐标。gt的形状是(anchor_num,label_num, 7),仍然表示批次内所有的目标。举个例子,算了不举了。
#anchor iou匹配 r = gt[:, :, 4:6] / anchors_cfg[:, None] j = torch.max(r, 1. / r).max(2)[0] < 2 t = gt[j]
各行代码的含义:
- 第一行:之前咱们把某一个目标分配给了所属的网格中所有的anchor,这一行就是计算目标跟分配的anchor的长宽比。
-
第二行:根据长宽比判断是否合适,返回的
j
是一个长度跟gt一样的一维bool数组,用来表示每一个目标跟分配给它的anchor是否合适。合适的标准就是目标大小跟anchor差不多。 - 第三行:过滤出那些跟分配的anchor很合适,已经找到了自己的归宿的目标。
#扩充维度并复制数据 # Offsets gxy = t[:, 2:4] # grid xy gxi = gain[[2, 3]] - gxy # inverse j, k = ((gxy % 1. < g) & (gxy > 1.)).T l, m = ((gxi % 1. < g) & (gxi > 1.)).T j = torch.stack((torch.ones_like(j), j, k, l, m)) t = t.repeat((5, 1, 1))[j] offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
首先需要知道YoloV5所提出的特殊的anchor分配策略,这个网上资料很多,所以大家自行查找就行。简单来说就是:
- 如果目标的中心点(x,y)落在所属网格的左上方,那么除了所属的网格外,额外让左边跟上面的两个网格也预测这个目标。
- 如果目标的中心点(x,y)落在所属网格的右上方,那么除了所属的网格外,额外让右边跟上面的两个网格也预测这个目标。
- 以此类推。
yolov5 build_target
就能找到。
后面的函数返回值部分应该很简单就看懂了。
-
通过高维数组的特殊索引操作:"fancy indexing",取出被分配到目标的anchor对应的预测值
#构建gt tcls, tbox, indices, anchors = build_target(preds, targets, cfg, device) for i, pred in enumerate(preds): #计算reg分支loss if i % 3 == 0: pred = pred.reshape(pred.shape[0], cfg["anchor_num"], -1, pred.shape[2], pred.shape[3]) pred = pred.permute(0, 1, 3, 4, 2) #判断当前batch数据是否有gt if len(indices): b, a, gj, gi = indices[layer_index[i]] nb = b.shape[0] if nb: ps = pred[b, a, gj, gi]
首先接收build_target函数的返回值,也就是anchor assign的结果。然后把返回值作为索引取出模型预测值的对应部分。这一行就解释了build_target函数的最终目的。
后面就是计算损失值了。
标签:loss,py,pred,torch,目标,num,label,anchor
From: https://www.cnblogs.com/yuxiyuxi/p/17017496.html