首页 > 其他分享 >YOLOv6-4.0部分代码阅读笔记-assigner_utils.py

YOLOv6-4.0部分代码阅读笔记-assigner_utils.py

时间:2024-10-28 22:44:53浏览次数:7  
标签:YOLOv6 gt 4.0 assigner max 张量 锚点 bboxes 边界

assigner_utils.py

yolov6\assigners\assigner_utils.py

目录

assigner_utils.py

1.所需的库和模块

2.def dist_calculator(gt_bboxes, anchor_bboxes): 

3.def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): 

4.def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): 

5.def iou_calculator(box1, box2, eps=1e-9): 


1.所需的库和模块

import torch
import torch.nn.functional as F

2.def dist_calculator(gt_bboxes, anchor_bboxes): 

def dist_calculator(gt_bboxes, anchor_bboxes):
    # 计算所有 bbox 和 gt 之间的中心距离。
    """compute center distance between all bbox and gt

    Args:
        gt_bboxes (Tensor): shape(bs*n_max_boxes, 4)
        anchor_bboxes (Tensor): shape(num_total_anchors, 4)
    Return:
        distances (Tensor): shape(bs*n_max_boxes, num_total_anchors)
        ac_points (Tensor): shape(num_total_anchors, 2)
    """
    # gt_bboxes为真实边界框(ground truth bounding boxes,简称gt_bboxes)
    # 1. gt_bboxes : 这是一个二维张量,包含了真实边界框的坐标。通常,每个边界框由四个值组成: (x_min, y_min, x_max, y_max) ,其中 (x_min, y_min) 是边界框左上角的坐标, (x_max, y_max) 是边界框右下角的坐标。
    # 2. gt_bboxes[:, 0] : 这个表达式选择了 gt_bboxes 张量中所有边界框的 x_min 坐标。
    # 3. gt_bboxes[:, 2] : 这个表达式选择了 gt_bboxes 张量中所有边界框的 x_max 坐标。
    # 4. 将每个边界框的 x_min 和 x_max 坐标相加 后 / 2.0 : 将每个边界框的 x_min 和 x_max 坐标的和除以2,得到边界框中心点的x坐标。
    # 最终, gt_cx 包含所有真实边界框中心点的x坐标。
    gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0    # gt_cx 包含所有真实边界框中心点的x坐标。
    gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0    # 同理,gt_cy 包含所有真实边界框中心点的y坐标。
    # torch.stack(tensors, dim=0, out=None) → Tensor
    # torch.stack 函数用于将一系列张量沿着一个新的维度连接起来。与 torch.cat (concatenate)不同, torch.stack 会创建一个新的维度来存放输入张量,而不是在已有的维度上进行连接。
    # tensors :一个序列(如元组或列表)的张量,它们需要有相同的形状。
    # dim :沿着哪个维度堆叠张量。默认为0,表示在最前面添加一个新的维度。
    # out :一个可选的张量,用于存储输出结果。
    # 返回值:
    # 返回一个新的张量,它是输入张量沿着 dim 维度堆叠的结果。
    gt_points = torch.stack([gt_cx, gt_cy], dim=1)    #真实边界框的中心点xy坐标。
    ac_cx = (anchor_bboxes[:, 0] + anchor_bboxes[:, 2]) / 2.0    # ac_cx 包含所有锚点边界框中心点的x坐标。
    ac_cy = (anchor_bboxes[:, 1] + anchor_bboxes[:, 3]) / 2.0    # ac_cy 包含所有锚点边界框中心点的y坐标。
    ac_points = torch.stack([ac_cx, ac_cy], dim=1)    #锚点边界框的中心点xy坐标。

    distances = (gt_points[:, None, :] - ac_points[None, :, :]).pow(2).sum(-1).sqrt()

    # 1. distances : 这是一个张量,包含了每个锚点与对应真实边界框之间的距离。在YOLOv6中,由于采用了anchor-free的设计,这个距离可能是指从锚点到边界框四边的距离(top, bottom, left, right)。
        # 这些距离信息用于计算损失函数,特别是在anchor-free的检测框架中,模型需要预测从锚点到边界框边缘的距离,而不是预测相对于预定义锚点的偏移量。
    # 2. ac_points : 这是一个张量,包含了所有锚点的中心点坐标。这些坐标通常用于确定锚点的位置,以便与真实边界框进行比较。
        # 在anchor-free的检测框架中,锚点的中心点坐标对于计算损失函数至关重要,因为它们直接影响到预测边界框的准确性。
    return distances, ac_points

3.def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): 

def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
    # 在 gt 中选择正锚框的中心。
    """select the positive anchors's center in gt

    Args:
        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
    Return:
        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    n_anchors = xy_centers.size(0)
    # gt_bboxes 是一个三维张量,通常代表了一批(batch)真实边界框(ground truth bounding boxes)的数据。这行代码将这个张量的形状分解为三个部分,并分别赋值给变量 bs 、 n_max_boxes 和 _  。
    # 1. bs : 这个变量代表批量大小(batch size),即这批数据中包含的图像数量。
    # 2. n_max_boxes :这个变量代表每个图像中的最大边界框数量。在目标检测任务中,由于一张图像可能包含多个目标,因此每个图像可能有多个边界框。
    # 3. _ : 这个变量是一个占位符,用于接收张量的第三个维度的大小。在目标检测任务中,这个维度通常代表每个边界框的属性数量,例如 (x_min, y_min, x_max, y_max) 这四个坐标值。
    bs, n_max_boxes, _ = gt_bboxes.size()
    # 1. gt_bboxes : 这是一个包含地面真实边界框(ground truth bounding boxes)的张量。
        # 它的形状通常是 [batch_size, num_objects, 4] ,其中 batch_size 是批次中图像的数量, num_objects 是每张图像中对象的数量,而 4 代表每个边界框的四个坐标值(通常是 x_min, y_min, x_max, y_max )。
    # 2. reshape([-1, 4]) : reshape 方法用于将 gt_bboxes 张量重塑为一个新的形状。
        # -1 在这里是一个特殊的参数,表示该维度的大小会自动计算,以便保持数据的总元素数量不变。这样做通常是为了将多个维度合并为一个维度。 4 表示新的张量应该有一个大小为4的维度,这对应于每个边界框的四个坐标值。
    # 通过这个操作, _gt_bboxes 将是一个二维张量,其中第一个维度是所有边界框的展平版本,第二个维度包含每个边界框的四个坐标值。这种形状的张量通常用于计算损失函数,因为它允许你轻松地访问每个边界框的坐标值。
    # 例如,如果 gt_bboxes 的原始形状是 [32, 80, 4] (假设有32张图像,每张图像最多有80个边界框), reshape 操作后 _gt_bboxes 的形状将是 [2560, 4] ,这意味着有2560个边界框( 32 * 80 ),每个边界框由4个坐标值表示。
    # 这样的展平操作使得处理每个边界框变得更加方便,尤其是在批量操作中。
    _gt_bboxes = gt_bboxes.reshape([-1, 4])
    # 1. unsqueeze(0) : unsqueeze 方法在指定维度上增加一个维度。这里, unsqueeze(0) 在张量 xy_centers 的最前面添加一个新的维度。这通常用于创建一个额外的批次维度,使得张量的形状从 [n] 变为 [1, n] 。
    # 例如,如果 xy_centers 的原始形状是 [n, 2] (其中 n 是边界框的数量,2 代表每个边界框的中心点坐标 (x, y) ), unsqueeze(0) 后的形状将变为 [1, n, 2] 。
    # 2. repeat(bs * n_max_boxes, 1, 1) : repeat 方法将张量在指定的维度上重复指定的次数。这里, repeat(bs * n_max_boxes, 1, 1) 表示: 在第一个维度(批次维度)上重复 bs * n_max_boxes 次。 在第二个和第三个维度上各重复1次,这意味着这些维度上的内容不会改变。
    # 例如,如果 bs 是批次大小, n_max_boxes 是每个图像中的最大边界框数量,那么 bs * n_max_boxes 就是批次中所有图像的边界框总数。这确保了 xy_centers 张量在第一个维度上有足够的重复来匹配批次中所有边界框的数量。
    # 通过这个操作, xy_centers 将被扩展为一个新的张量,其形状变为 [bs * n_max_boxes, n, 2] 。这种形状的张量通常用于目标检测模型中,其中需要将每个边界框的中心点坐标与相应的锚点或预测边界框进行比较。
    # 总结来说,这行代码的目的是将 xy_centers 张量扩展到与批次中所有边界框的数量相匹配,同时保持每个边界框的中心点坐标不变。这是在准备数据以进行损失计算时的一个常见步骤,特别是在处理需要大量重复操作的批次数据时。
    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
    # 1. _gt_bboxes[:, 0:2] 和 _gt_bboxes[:, 2:4] :
    # _gt_bboxes 是一个二维张量,形状为 [total_bboxes, 4] ,其中 total_bboxes 是批次中所有图像的边界框总数,而 4 代表每个边界框的四个坐标值(通常是 x_min, y_min, x_max, y_max )。
    #  _gt_bboxes[:, 0:2] 选择所有边界框的前两个坐标( x_min, y_min ),代表左上角坐标。 _gt_bboxes[:, 2:4] 选择所有边界框的后两个坐标( x_max, y_max ),代表右下角坐标。
    # 2. .unsqueeze(1) : unsqueeze 方法在指定维度上增加一个维度。这里, unsqueeze(1) 在张量的第二个维度上添加一个新的维度,将形状从 [total_bboxes, 2] 变为 [total_bboxes, 1, 2] 。
    # 3. .repeat(1, n_anchors, 1) : repeat 方法将张量在指定的维度上重复指定的次数。
    # 这里, repeat(1, n_anchors, 1) 表示: 在第一个维度(边界框维度)上重复1次,这意味着这个维度上的内容不会改变。
    # 在第二个维度(锚点维度)上重复 n_anchors 次,其中 n_anchors 是每个图像中每个特征图单元格的锚点数量。在第三个维度(坐标维度)上重复1次,这意味着这个维度上的内容不会改变。
    # 这确保了每个边界框的左上角和右下角坐标被重复以匹配每个锚点的数量。通过这个操作, gt_bboxes_lt 和 gt_bboxes_rb 将被扩展为新的形状 [total_bboxes, n_anchors, 2] 。这种形状的张量通常用于目标检测模型中,其中需要将每个真实边界框的左上角和右下角坐标与相应的锚点进行比较。
    # 总结来说,这行代码的目的是将每个真实边界框的左上角和右下角坐标扩展到与批次中所有锚点的数量相匹配,同时保持每个边界框的坐标不变。这是在准备数据以进行损失计算时的一个常见步骤,特别是在处理需要大量重复操作的批次数据时。
    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
    # 通过相减, b_lt 计算了每个预测中心点到最近真实边界框左上角的距离。这个距离向量 b_lt 可以用来计算损失函数,特别是在需要考虑边界框定位精度的场景中。
    b_lt = xy_centers - gt_bboxes_lt
    # 通过相减, b_rb 计算了每个预测中心点到最近真实边界框右下角的距离。这个距离向量 b_rb 同样用于计算损失函数,帮助模型学习如何更准确地预测边界框的位置。
    b_rb = gt_bboxes_rb - xy_centers
    # torch.cat 函数将这两个向量沿着最后一个维度( dim=-1 )连接起来,形成一个包含四个距离值的新向量。这个新向量通常表示为 [dx, dy, dw, dh] ,其中 dx 和 dy 是中心点到左上角的距离, dw 和 dh 是中心点到右下角的距离。
    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
    # bbox_deltas 是通过连接 b_lt 和 b_rb 得到的张量,形状为 [total_anchors, 4] 。
    # reshape 方法将这个张量重新调整为一个新的形状,以匹配批次大小( bs )、每张图像中的最大边界框数量( n_max_boxes )、每个图像中每个特征图单元格的锚点数量( n_anchors )以及每个锚点的四个距离值。
    # -1 在这里是一个特殊的参数,表示该维度的大小会自动计算,以便保持数据的总元素数量不变。
    # 通过这个操作, bbox_deltas 将被重新调整为一个四维张量,其形状为 [bs, n_max_boxes, n_anchors, 4] 。这种形状的张量通常用于目标检测模型中,其中需要将每个预测边界框与相应的真实边界框进行比较。
    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
    # 返回值的是一个布尔张量,表示每个锚点是否为有效的候选锚点。
    # 1. bbox_deltas.min(axis=-1)[0] : bbox_deltas 是一个张量,包含了每个锚点与真实边界框之间的距离差异,形状为 [bs, n_max_boxes, n_anchors, 4] 。
    # .min(axis=-1) 计算每个锚点的四个距离差异中的最小值,这四个值分别代表锚点中心到真实边界框左上角和右下角的距离。
    # axis=-1 指定了沿着最后一个维度(即每个锚点的四个距离值)进行操作。 [0] 取出最小值操作的结果,得到一个形状为 [bs, n_max_boxes, n_anchors] 的张量。
    # 2. > eps : eps 是一个非常小的正数,用于避免除以零的情况。 > eps 比较操作,检查每个锚点的最小距离是否大于 eps ,这确保了锚点的中心点确实位于真实边界框内部。
    # 3. .to(gt_bboxes.dtype) : 将返回的布尔张量转换为与 gt_bboxes 相同的数据类型,这通常是为了确保后续计算的一致性。
    # 最终, select_candidates_in_gts 方法返回一个布尔张量,形状为 [bs, n_max_boxes, n_anchors] ,其中 True 表示对应的锚点中心点位于至少一个真实边界框的内部,这些锚点被认为是有效的候选锚点。这些信息将被用于后续的损失函数计算,特别是在确定正样本和负样本时。
    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)

4.def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): 

# 1. mask_pos : 这是一个布尔张量,形状为 [bs, n_max_boxes, num_total_anchors] ,其中 bs 是批次大小, n_max_boxes 是每张图像中的最大边界框数量, num_total_anchors 是所有特征图层上锚点的总数。 mask_pos 表示每个锚点是否与至少一个真实边界框有交集,即是否是候选的正样本。
# 2. overlaps : 这是一个张量,形状为 [bs, n_max_boxes, num_total_anchors] ,包含了每个锚点与每个真实边界框之间的重叠度(通常是IoU,即交并比)。 overlaps 用于确定每个锚点与每个真实边界框的相似度,这是选择正样本的关键信息。
# 3. n_max_boxes : 这是一个整数,表示每张图像中的最大边界框数量。 这个参数用于确保在处理重叠度时,索引不会超出边界框数量的范围。
def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
    # 如果一个anchor box被分配给多个gts,那么将选择iou最高的那个。
    """if an anchor box is assigned to multiple gts,
        the one with the highest iou will be selected.

    Args:
        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    Return:
        target_gt_idx (Tensor): shape(bs, num_total_anchors)
        fg_mask (Tensor): shape(bs, num_total_anchors)
        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    fg_mask = mask_pos.sum(axis=-2)
    # 这个条件检查是否有任何锚点与超过一个真实边界框重叠。如果 fg_mask 中的最大值大于1,这意味着至少有一个锚点与多个真实边界框重叠。
    if fg_mask.max() > 1:
        # fg_mask.unsqueeze(1) 在第二个维度上增加一个维度,将 fg_mask 从 [bs, num_total_anchors] 变为 [bs, 1, num_total_anchors] 。
        # > 1 检查是否有任何锚点与超过一个真实边界框重叠。
        # repeat([1, n_max_boxes, 1]) 将这个布尔张量在第二个维度上重复 n_max_boxes 次,以便与真实边界框的数量对齐。
        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
        # overlaps 是一个张量,包含了每个锚点与每个真实边界框之间的重叠度。 argmax(axis=1) 找到每个锚点与真实边界框重叠度最高的那个边界框的索引。
        max_overlaps_idx = overlaps.argmax(axis=1)
        # F.one_hot 将 max_overlaps_idx 中的索引转换为一个one-hot编码张量,其中每个索引位置有一个值为1,其余为0。 这个one-hot编码张量的形状为 [bs, num_total_anchors, n_max_boxes] 。
        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
        # permute(0, 2, 1) 重新排列张量的维度,从 [bs, num_total_anchors, n_max_boxes] 变为 [bs, n_max_boxes, num_total_anchors] 。 to(overlaps.dtype) 将张量转换为与 overlaps 相同的数据类型。
        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
        # torch.where 根据 mask_multi_gts 中的布尔值选择 is_max_overlaps 或 mask_pos 中的值。 如果 mask_multi_gts 为 True ,则选择 is_max_overlaps 中的值;否则,选择 mask_pos 中的值。
        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
        # sum(axis=-2) 计算每个锚点与所有真实边界框的交集数量。 这确保了与多个真实边界框重叠的锚点只被计算一次。
        fg_mask = mask_pos.sum(axis=-2)
    target_gt_idx = mask_pos.argmax(axis=-2)
    # 1. target_gt_idx :这是一个张量,形状为 [bs, num_total_anchors] ,包含了每个锚点被分配到的真实边界框的索引。如果一个锚点与多个真实边界框有交集, target_gt_idx 将指向重叠度最高的那个真实边界框。
    # 2. fg_mask : 这是一个布尔张量,形状为 [bs, num_total_anchors] ,表示每个锚点是否被分配给了至少一个真实边界框。 fg_mask 用于区分正样本(foreground masks)和负样本(background masks)。
    # 3. mask_pos (可能被更新) : 在处理重叠度后, mask_pos 可能被更新以反映每个真实边界框与锚点之间的最高重叠度分配。
    return target_gt_idx, fg_mask , mask_pos

5.def iou_calculator(box1, box2, eps=1e-9): 

def iou_calculator(box1, box2, eps=1e-9):
    # 计算批量的iou。
    """Calculate iou for batch

    Args:
        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
    Return:
        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
    """
    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
    x1y1 = torch.maximum(px1y1, gx1y1)
    x2y2 = torch.minimum(px2y2, gx2y2)
    overlap = (x2y2 - x1y1).clip(0).prod(-1)
    area1 = (px2y2 - px1y1).clip(0).prod(-1)
    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
    union = area1 + area2 - overlap + eps

    return overlap / union

标签:YOLOv6,gt,4.0,assigner,max,张量,锚点,bboxes,边界
From: https://blog.csdn.net/m0_58169876/article/details/143315787

相关文章

  • 给虚拟机挂载一块硬盘(以ubuntu24.04为例)
    一、新增、分区、格式化新盘1、首先在虚拟机中增加一块新硬盘(500G)例如:Vmware、Exsi软件,增加完成后,查看一下:root@ubuntu:~#lsblk-fNAMEFSTYPEFSVERLABELUUIDFSAVAILFSUSE%MOUNTPOINTSsda......
  • ubuntu24.04安装完以后发现硬盘空间少一半
    1、查看现在硬盘情况root@ubuntu:~#df-hFilesystemSizeUsedAvailUse%Mountedontmpfs1.6G1.1M1.6G1%/runefivarfs256K64K188K26%/sys/firmware/efi/efivars/dev......
  • VS2022 添加旧版本.NET Framework 3.5/4.0支持
    鉴于vs2022最旧只支持到.netframework4.6.2有些项目.netframework版本比较低,又想要用新版本vs以3.5为例要使vs2022支持低版本.netframework项目,可参考以下步骤实现下载.netframeworknuget包下载链接如下,可根据需要下载对应版本v3.5v4.0v4.5修改后缀为zip或直接......
  • chatGpt4.0Plus,Claude3最新保姆级教程开通升级
     如何使用WildCard服务注册Claude3随着Claude3的震撼发布,最强AI模型的桂冠已不再由GPT-4独揽。Claude3推出了三个备受瞩目的模型:Claude3Haiku、Claude3Sonnet以及Claude3Opus,每个模型都展现了卓越的性能与特色。其中,Claude3Opus更是实现了对GPT-4的全......
  • Redis4.0.12集群搭建
    服务器:节点1:10.10.175.55 端口:6379/7379节点2:10.10.175.56 端口:6379/7379节点3:10.10.175.57 端口:6379/7379以下操作均需在每台服务器上执行安装依赖关系yuminstallmakezlibopenssl*ImageMagick-develgcc*rubygems-y2、创建节点目录mkdir-p/usr/local/redis-cl......
  • IObit Uninstaller Pro v14.0.0.17 解锁版 (强悍的驱动级软件卸载)
    IObitUninstallerProv14.0.0.17解锁版(强悍的驱动级软件卸载)IObitUninstaller,软件卸载程序。IObitUninstaller是款强悍的驱动级软件卸载工具,有强制卸载、批量卸载、安装监视器、文件粉碎、软件健康检查、卸载Windows更新补丁、移除浏览器工具栏和插件等功能。一、下载地址链......
  • Ubuntu 24.04自带RDP远程桌面
    Ubuntu24.04自带远程桌面啦,在Setting-->System-->DesktopSharing中可以开启远程桌面共享和远程控制,默认使用3390端口,开启远程服务后,在Windows机器中就可以用自带的远程桌面软件连接Ubuntu桌面了.另外一种可选方案是RemoteLogging,但启用此功能会导致当前用户被......
  • Ubuntu 24.04使用virtualBox启动虚拟机提示Kernel driver not installed的解决办法
    1.Ubuntu安装virtualBoxvirtualBox官方下载对应ubuntu24.04系统的deb安装包进入到下载文件所在目录使用如下apt命令安装下载好的deb安装包sudoaptinstall-f./virtualBox*2.启动虚拟机提示“Kerneldrivernotinstalled”由于我装的是双系统,ubuntu挂载了windows下使......
  • pear-admin-layui-main 4.0 admin.js bug 修复
    pearAdmin.instances.tabPage=tabPage.render({ elem:'content', session:param.tab.session, index:0, tabMax:param.tab.max, preload:param.tab.preload, closeEvent:function(id){ pearAdmin.instan......
  • Advanced Renamer v4.05.0 文件批量重命名工具绿色版
    AdvancedRenamer是一款界面简洁友好功能强大的轻量型批量重命名工具,用户无需专业知识就能掌握运用的高级批量修改文件名软件,AdvancedRenamer拥有比较强大的修改文件名功能,能够快速方便地对文件或文件夹进行修改名称,你可以在它的命名方案列表中添加方案,有新名称、新写法、移......