atss_assigner.py
yolov6\assigners\atss_assigner.py
目录
class ATSSAssigner(nn.Module):
1.所需的库和模块
import torch
import torch.nn as nn
import torch.nn.functional as F
from yolov6.assigners.iou2d_calculator import iou2d_calculator
from yolov6.assigners.assigner_utils import dist_calculator, select_candidates_in_gts, select_highest_overlaps, iou_calculator
2.class ATSSAssigner(nn.Module):
# 选择Anchor :对于每个Ground Truth(GT),根据中心距离最近的原则,从每个特征金字塔网络(FPN)层选择 k 个anchor。
# ATSSAssigner 的优点在于 :能够根据 IOU 的 均值 和 方差 动态调整阈值,从而更好地匹配目标对象,特别是在目标大小和形状变化较大的情况下表现出色。
class ATSSAssigner(nn.Module):
# 自适应训练样本选择分配器
'''Adaptive Training Sample Selection Assigner'''
# topk : 这个参数定义了在每个特征图层上,对于每个真实边界框,只考虑与其中心点距离最近的前 top_k 个锚点。这有助于减少计算量,并且专注于最有可能包含目标的锚点。
def __init__(self,
topk=9,
num_classes=80):
super(ATSSAssigner, self).__init__()
self.topk = topk
self.num_classes = num_classes
# self.bg_idx :是一个类属性,通常用于表示背景类别的索引。
# 在目标检测中,除了所有目标类别之外,还有一个特殊的类别代表背景(即不包含任何目标的区域)。
# 通过设置 self.bg_idx = num_classes ,标签分配器将背景类别的索引设置为目标类别总数的下一个索引。这意味着如果有 80 个目标类别,背景类别的索引将被设置为 81 。
self.bg_idx = num_classes
# torch.no_grad()
# 在 PyTorch 中, torch.no_grad() 是一个上下文管理器,用于在不需要计算梯度的场景中禁用梯度计算。这在模型评估和推理过程中尤为重要,因为它可以显著减少内存消耗并提高计算效率。
# 1. anc_bboxes : 这是一个张量,包含了所有锚点的边界框信息。在YOLOv6中,由于采用了anchor-free的设计,这个参数可能不是必需的,因为锚点是在特征图上动态生成的。
# 2. n_level_bboxes :通常指的是在不同特征层级(level)上的预测边界框(bounding boxes)的数量。
# 3. gt_labels : 这是一个张量,包含了真实边界框的类别标签信息。
# 4. gt_bboxes : 这是一个张量,包含了真实边界框的坐标信息。
# 5. mask_gt : 这是一个布尔张量,用于指示哪些真实边界框是有效的。
# 6. pd_bboxes : 这是一个张量,包含了预测的边界框信息。
@torch.no_grad()
def forward(self,
anc_bboxes,
n_level_bboxes,
gt_labels,
gt_bboxes,
mask_gt,
pd_bboxes):
# 本代码基于
# https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
r"""This code is based on
https://github.com/fcjian/TOOD/blob/master/mmdet/core/bbox/assigners/atss_assigner.py
Args:
anc_bboxes (Tensor): shape(num_total_anchors, 4)
n_level_bboxes (List):len(3)
gt_labels (Tensor): shape(bs, n_max_boxes, 1)
gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
mask_gt (Tensor): shape(bs, n_max_boxes, 1)
pd_bboxes (Tensor): shape(bs, n_max_boxes, 4)
Returns:
target_labels (Tensor): shape(bs, num_total_anchors)
target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
fg_mask (Tensor): shape(bs, num_total_anchors)
"""
# anc_bboxes 是一个张量,包含了所有锚点(anchors)的边界框信息。 size(0) 获取这个张量的第一个维度的大小,即锚点的总数。
# self.n_anchors 存储了这个数量,表示 每个特征图单元格上锚点的数量 。这个属性在后续的锚点分配和损失计算中非常重要。
self.n_anchors = anc_bboxes.size(0)
# gt_bboxes 是一个张量,包含了真实边界框(ground truth bounding boxes)的坐标信息。 size(0) 获取这个张量的批次大小,即这 批数据中图像的数量 。
# self.bs 存储了这个批次大小,用于后续的计算,确保处理的数据与批次中的图像数量一致。
self.bs = gt_bboxes.size(0)
# gt_bboxes 张量的第二个维度通常表示每个图像中的最大边界框数量。 size(1) 获取这个维度的大小。
# self.n_max_boxes 存储了这个最大边界框数量,它代表了 每个图像中可以包含的最大目标数量 。这个属性在处理不同图像中不同数量的目标时非常重要。
self.n_max_boxes = gt_bboxes.size(1)
if self.n_max_boxes == 0:
device = gt_bboxes.device
# torch.full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) → Tensor
# 返回创建 size 大小的维度,里面元素全部填充为 fill_value 。
# torch.zeros(size, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False)
# torch.zeros() 是 PyTorch 中用来创建全 0 张量的函数。
# size :参数表示张量的形状(shape),可以是一个整数或者一个包含多个整数的 tuple 。例如, torch.zeros(3, 4) 就会创建一个 3x4 的全 0 矩阵。
# out :参数表示输出张量。如果指定了这个参数,那么函数会将结果存储在这个张量里,而不是新建一个张量。
# dtype :参数表示张量的数据类型,可以是 torch.float32 、 torch.int64 等。
# layout 、 device 、 requires_grad :分别表示张量的内存布局、存储设备、是否需要求导,一般不用设置。
return torch.full( [self.bs, self.n_anchors], self.bg_idx).to(device), \
torch.zeros([self.bs, self.n_anchors, 4]).to(device), \
torch.zeros([self.bs, self.n_anchors, self.num_classes]).to(device), \
torch.zeros([self.bs, self.n_anchors]).to(device)
# def iou2d_calculator(bboxes1, bboxes2, mode='iou', is_aligned=False, scale=1., dtype=None): -> 2D 重叠(例如 IoU、GIoU)计算器。计算 2D bbox 之间的 IoU。
overlaps = iou2d_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
# 用于重新调整 overlaps 张量的形状,以适应模型训练中对维度的特定需求。
# 1. self.bs :这是模型的一个属性,表示批次大小(batch size),即当前批次中的图像数量。
# 2. -1 : 在 reshape 函数中, -1 是一个特殊的值,表示该维度的大小将自动计算,以便保持数据的总元素数量不变。这通常用于让PyTorch自动推断正确的维度大小。
# 3. self.n_anchors : 这是模型的另一个属性,表示每个特征图单元格上的锚点数量。
# 通过这个操作, overlaps 张量被重新调整为一个三维张量,其形状为 [bs, num_gts, n_anchors] 。
# 其中: bs :是批次大小。 num_gts :是每个图像中真实边界框的数量,由于使用了 -1 ,这个维度的大小将自动计算。 n_anchors :是每个特征图单元格上的锚点数量。
# 这种重新调整形状的操作通常发生在锚点分配步骤中,用于确保重叠度张量与批次中的图像数量和每个图像的锚点数量相匹配。
# 这样,模型可以为每个锚点计算与每个真实边界框的重叠度,并执行后续的锚点分配策略,例如选择与真实边界框重叠度最高的锚点作为正样本(positive samples)。
# 例如,如果 overlaps 的原始形状是 [bs * num_gts * n_anchors] ,那么重新调整形状后, overlaps 的形状将变为 [bs, num_gts, n_anchors] ,这使得每个批次中的每张图像都有对应的锚点与真实边界框的重叠度信息。
overlaps = overlaps.reshape([self.bs, -1, self.n_anchors])
# def dist_calculator(gt_bboxes, anchor_bboxes): -> 计算所有 bbox 和 gt 之间的中心距离。 -> return distances, ac_points
# distances : 这是一个张量,包含了每个锚点与对应真实边界框之间的距离。
# ac_points : 这是一个张量,包含了所有锚点的中心点坐标。
distances, ac_points = dist_calculator(gt_bboxes.reshape([-1, 4]), anc_bboxes)
# 用于重新调整 distances 张量的形状,以适应模型训练中对维度的特定需求。
# 1. self.bs :这是模型的一个属性,表示批次大小(batch size),即当前批次中的图像数量。
# 2. -1 : 在 reshape 函数中, -1 是一个特殊的值,表示该维度的大小将自动计算,以便保持数据的总元素数量不变。这通常用于让PyTorch自动推断正确的维度大小。
# 3. self.n_anchors : 这是模型的另一个属性,表示每个特征图单元格上的锚点数量。
# 通过这个操作, distances 张量被重新调整为一个三维张量,其形状为 [bs, num_gts, n_anchors] 。
# 其中: bs :是批次大小。 num_gts :是每个图像中真实边界框的数量,由于使用了 -1 ,这个维度的大小将自动计算。 n_anchors :是每个特征图单元格上的锚点数量。
distances = distances.reshape([self.bs, -1, self.n_anchors])
# def select_topk_candidates(self, distances, n_level_bboxes, mask_gt): -> 筛选后的候选列表,筛选前的候选列表。 -> return is_in_candidate_list, candidate_idxs
# is_in_candidate_list :这是一个包含了多个元素的列表,每个元素都是一个张量(tensor),表示在每个目标检测层级中,哪些先验框(anchor boxes)被选为了候选框。
# candidate_idxs :这是一个包含了多个元素的列表,每个元素都是一个张量,包含了在每个目标检测层级中被选为候选框的先验框的索引。这些索引是基于全局的先验框索引空间的,意味着它们指向了在整个特征图中的先验框位置。
is_in_candidate, candidate_idxs = self.select_topk_candidates(
distances, n_level_bboxes, mask_gt)
# def thres_calculator(self,is_in_candidate, candidate_idxs, overlaps): -> 候选列表重叠区域阈值,后选列表重叠区域。 -> return overlaps_thr_per_gt, _candidate_overlaps
# 1. overlaps_thr_per_gt :这个变量代表每个Ground Truth(GT,即真实边界框)对应的样本阈值。只有当锚框与GT的IoU值超过这个阈值时,该锚框才被认为是正样本。
# 2. iou_candidates :这个变量代表每个候选锚框与对应GT的IoU值。
overlaps_thr_per_gt, iou_candidates = self.thres_calculator(
is_in_candidate, candidate_idxs, overlaps)
# 选择候选 iou >= 阈值作为正
# select candidates iou >= threshold as positive
# 对于每个候选锚框,检查它与对应GT的IoU值( iou_candidates )是否大于该GT的样本阈值( overlaps_thr_per_gt )。
# 如果IoU值大于样本阈值,那么这个锚框应该被标记为正样本,所以 is_pos 的对应位置被设置为 is_in_candidate 的值(即1)。
# 如果IoU值不大于样本阈值,那么这个锚框不应该被标记为正样本,所以 is_pos 的对应位置被设置为0。
# 最终, is_pos 张量将包含每个候选锚框是否应该被标记为正样本的信息。这个信息将用于后续的训练过程中,以确保模型只从那些真正有助于检测GT的目标中学习。
is_pos = torch.where(
iou_candidates > overlaps_thr_per_gt.repeat([1, 1, self.n_anchors]),
is_in_candidate, torch.zeros_like(is_in_candidate))
# def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9): -> 在 gt 中选择正锚框的中心。 -> return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
# is_in_gts :是一个布尔张量,表示每个锚点是否为有效的候选锚点。
is_in_gts = select_candidates_in_gts(ac_points, gt_bboxes)
# mask_pos :张量将包含所有被选为正样本的锚框的掩码,其中值为1的位置表示该锚框是一个正样本,值为0的位置表示不是。
# 用于生成最终的正样本掩码( mask_pos ),它结合了三个不同的条件:
# 1. is_pos :这是一个布尔张量,表示哪些锚框与对应的Ground Truth(GT)的IoU超过了设定的阈值,即这些锚框被认为是正样本。
# 2. is_in_gts :这是一个布尔张量,表示哪些锚框落在了GT的边界框内。通常,这个条件用于确保锚框不仅与GT有较高的IoU,而且实际上与GT有空间上的重叠。
# 3. mask_gt :这是一个掩码张量,用于标记哪些锚框是对应于特定GT的。在YOLOv6中,每个GT可能对应多个锚框, mask_gt 用于区分这些锚框属于哪个GT。
mask_pos = is_pos * is_in_gts * mask_gt
# def select_highest_overlaps(mask_pos, overlaps, n_max_boxes): -> 如果一个anchor box被分配给多个gts,那么将选择iou最高的那个。 -> return target_gt_idx, fg_mask , mask_pos
# 1. target_gt_idx :这是一个张量,形状为 [bs, num_total_anchors] ,包含了每个锚点被分配到的真实边界框的索引。如果一个锚点与多个真实边界框有交集, target_gt_idx 将指向重叠度最高的那个真实边界框。
# 2. fg_mask : 这是一个布尔张量,形状为 [bs, num_total_anchors] ,表示每个锚点是否被分配给了至少一个真实边界框。 fg_mask 用于区分正样本(foreground masks)和负样本(background masks)。
# 3. mask_pos (可能被更新) : 在处理重叠度后, mask_pos 可能被更新以反映每个真实边界框与锚点之间的最高重叠度分配。
target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
mask_pos, overlaps, self.n_max_boxes)
# 指定目标。
# assigned target
# def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask): -> 指定目标标签,指定目标边界框,指定目标分数。 -> return target_labels, target_bboxes, target_scores
# target_labels :张量将包含每个锚框的目标类别标签,其中前景锚框有对应的类别标签,背景锚框则被标记为背景。
# target_bboxes :选择的边界框坐标。形为三维张量,其中第一维是批次大小 self.bs ,第二维是每个样本的锚框数量 self.n_anchors ,第三维是边界框的四个坐标值 (x_min, y_min, x_max, y_max) 。
# target_scores :张量将包含每个锚框的目标类别的独热编码分数。
target_labels, target_bboxes, target_scores = self.get_targets(
gt_labels, gt_bboxes, target_gt_idx, fg_mask)
# soft label with iou 使用IoU soft标签。
# 检查是否存在预测的边界框。
if pd_bboxes is not None:
# def iou_calculator(box1, box2, eps=1e-9): -> 计算批量的iou。 -> return overlap / union
ious = iou_calculator(gt_bboxes, pd_bboxes) * mask_pos
# ious.max(axis=-2)[0]:使用 .max() 函数沿着 axis=-2 维度找到最大值。
# 在PyTorch中, axis=-2 通常指的是最后一个维度之前的维度,即在形状为 [batch_size, num_boxes1, num_boxes2] 的张量中, axis=-2 指的是 num_boxes1 维度。 [0] 表示从结果中选择第一个元素,这通常是一个包含最大IoU值的张量。
# unsqueeze(-1):使用 .unsqueeze() 函数在最后一个维度添加一个额外的维度。这将IoU值的张量从一维变为二维,形状从 [batch_size, num_boxes] 变为 [batch_size, num_boxes, 1] 。这个操作通常是为了保持张量维度的一致性,以便与其他张量进行操作,如在某些损失函数中。
# 最终,这行代码得到的 ious 张量包含了每个边界框与其最匹配的边界框之间的最大IoU值,并且增加了一个额外的维度。
ious = ious.max(axis=-2)[0].unsqueeze(-1)
target_scores *= ious
# 1. target_labels.long() : target_labels 是一个张量,包含了每个正样本锚框对应的类别标签。 .long() 方法将 target_labels 转换为长整型(LongTensor),这是因为标签通常是整数类型。 这个张量用于分类损失的计算,指导模型学习如何正确分类每个检测到的目标。
# 2. target_bboxes : target_bboxes 是一个张量,包含了每个正样本锚框对应的真实边界框坐标。 这些坐标通常用于回归损失的计算,指导模型学习如何精确地定位目标。
# 3. target_scores: target_scores 是一个张量,包含了每个正样本锚框的类别分数或概率,通常以独热编码形式表示。 这个张量用于计算分类损失,特别是在模型输出层使用交叉熵损失时。
# 4. fg_mask.bool(): fg_mask 是一个掩码张量,用于标记哪些锚框是前景(即包含目标的锚框)。 .bool() 方法将 fg_mask 转换为布尔类型,其中True表示对应的锚框是前景,False表示背景。 这个掩码用于筛选出正样本,确保只有前景锚框参与损失的计算。
return target_labels.long(), target_bboxes, target_scores, fg_mask.bool()
# 1. distances : 这是一个张量,包含了每个锚点与每个真实边界框之间的距离或重叠度(例如,IoU)。这个张量的形状通常是 [batch_size, num_anchors, num_gt_bboxes] ,其中 batch_size 是批次大小, num_anchors 是锚点的总数, num_gt_bboxes 是真实边界框的总数。
# 2. n_level_bboxes : 通常指的是在不同特征层级(level)上的预测边界框(bounding boxes)的数量。
# 3. mask_gt : 这是一个布尔张量,用于指示哪些真实边界框是有效的。这可以用于过滤掉那些不应该被考虑的边界框,例如由于数据清洗或增强操作而产生的无效边界框。形状通常是 [batch_size, num_gt_bboxes] 。
def select_topk_candidates(self,
distances,
n_level_bboxes,
mask_gt):
# mask_gt 是一个布尔张量,形状可能是 [bs, num_gts] ,其中 bs 是批次大小, num_gts 是每个图像中真实边界框的数量。
# repeat(1, 1, self.topk) 将 mask_gt 张量沿着第三个维度(索引从0开始计数,所以是最后一个维度)重复 self.topk 次。
# self.topk 是一个整数,表示你想要为每个真实边界框选择的锚点数量。
# 这个操作将 mask_gt 张量的形状从 [bs, num_gts] 改变为 [bs, num_gts, self.topk] 。
# 最终, mask_gt 将是一个形状为 [bs, num_gts, self.topk] 的布尔张量,其中每个 True 值表示对应的真实边界框是有效的,并且这个有效性被重复了 self.topk 次以匹配后续操作中锚点的数量。
mask_gt = mask_gt.repeat(1, 1, self.topk).bool()
# torch.split(tensor, split_size_or_sections, dim=0)
# torch.split 函数在PyTorch中用于将一个张量(Tensor)分割成多个较小的张量,这些张量在指定的维度上具有相等或不同的大小。这个函数非常灵活,可以根据需要分割张量。
# tensor :要分割的输入张量。
# split_size_or_sections :一个整数或张量大小的序列。 如果是一个整数,表示每个分割块的大小(除了可能的最后一块)。 如果是一个序列,表示每个分割块的大小。
# dim :要沿哪个维度进行分割。默认是0。
# 返回值:
# 返回一个张量元组,包含分割后的各个张量。
# torch.split 函数被用来将 distances 张量沿着最后一个维度( dim=-1 )分割成多个较小的张量,每个张量的大小由 n_level_bboxes 指定。
# 例如,如果 distances 的形状是 [bs, num_anchors, num_level_bboxes] ,那么分割后的每个张量的形状将是 [bs, num_anchors] ,每个张量对应一个特征图层上的所有锚点与真实边界框之间的距离。
level_distances = torch.split(distances, n_level_bboxes, dim=-1)
is_in_candidate_list = []
# candidate_idxs 是一个用于存储候选框索引的列表。这个列表通常用于目标检测过程中的样本匹配阶段,即确定哪些先验框(anchor points)是潜在的正样本或负样本。
# 在目标检测的上下文中,"候选框"(candidate boxes)或 "先验框"(anchor points)是指网络预测的框,它们用于与真实框(ground truth boxes)进行比较,以计算损失并更新模型权重。
# 在训练过程中,模型需要决定哪些候选框与真实框足够接近,以便被视为正样本,哪些则因为距离过远而被视为负样本。
candidate_idxs = []
start_idx = 0
for per_level_distances, per_level_boxes in zip(level_distances, n_level_bboxes):
end_idx = start_idx + per_level_boxes
selected_k = min(self.topk, per_level_boxes)
# torch.topk(input, k, dim=None, largest=True, sorted=True, out=None) -> (Tensor, LongTensor)
# 沿给定 dim 维度返回输入张量 input 中 k 个最大值。
# 如果不指定 dim ,则默认为 input 的最后一维。
# 如果为 largest 为 False ,则返回最小的 k 个值。
# 返回一个元组 ( values , indices ),其中 values 是原始输入张量, indices 中测元素下标。
# 如果设定布尔值 sorted 为 True ,将会确保返回的 k 个值被排序。
# input(Tensor) :输入的张量。
# k(int) :前 k 个大小中的 k。
# dim(int, optional) :需要进行排序的维度, dim = 0 表示按照列来排序, dim = 1 表示按照行来排序, 默认情况下, dim = 1。
# largest(bool, optional) :控制是否返回最大值或最小值。
# sorted(bool, optional) :控制是否对元素进行排序后再返回。
# out(tuple,可选) :(Tensor,LongTensor)的输出元组,可以可选地指定用作输出缓冲区。
# 返回值
# values (Tensor) :沿着指定维度返回的最大(或最小)的k个元素。
# indices (LongTensor) :返回的元素的索引。
# per_level_distances 是一个张量,包含了每个先验框与真实框之间的距离。
# topk 函数用于选择每个目标检测层级中距离最小的 selected_k 个先验框。
# dim=-1 表示在最后一个维度(即每个目标的先验框)上进行操作。
# largest=False 表示选择距离最小的先验框,而不是最大的。
# _ 表示忽略 topk 函数返回的第一个值(即每个先验框的距离)。
# per_level_topk_idxs 存储了每个目标检测层级中距离最小的 selected_k 个先验框的索引。
_, per_level_topk_idxs = per_level_distances.topk(selected_k, dim=-1, largest=False)
# start_idx 是当前目标检测层级的起始索引。
# per_level_topk_idxs + start_idx 将 per_level_topk_idxs 中的索引转换为全局索引。
# candidate_idxs.append(...) 将转换后的索引添加到 candidate_idxs 列表中,用于后续处理。
candidate_idxs.append(per_level_topk_idxs + start_idx)
# mask_gt 是一个掩码张量,用于标识哪些目标是正样本(即与真实框有较高 IoU 的目标)。
# torch.where(mask_gt, per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs)) 用于将非正样本的索引设置为 0,只保留正样本的索引。
per_level_topk_idxs = torch.where(mask_gt,
per_level_topk_idxs, torch.zeros_like(per_level_topk_idxs))
# F.one_hot(per_level_topk_idxs, per_level_boxes) 将 per_level_topk_idxs 转换为 one-hot 编码形式。
# sum(dim=-2) 计算每个先验框在 one-hot 编码中的总和,用于标识哪些先验框被选为候选框。
is_in_candidate = F.one_hot(per_level_topk_idxs, per_level_boxes).sum(dim=-2)
# 这个操作用于确保每个先验框最多只能被选为一个候选框。如果一个先验框被选为多个候选框,则将其设置为 0。
is_in_candidate = torch.where(is_in_candidate > 1,
torch.zeros_like(is_in_candidate), is_in_candidate)
# 将 is_in_candidate 转换为与 distances 相同的数据类型,并添加到 is_in_candidate_list 列表中。
is_in_candidate_list.append(is_in_candidate.to(distances.dtype))
# 更新 start_idx 为 end_idx ,为下一个目标检测层级的处理做准备。
start_idx = end_idx
is_in_candidate_list = torch.cat(is_in_candidate_list, dim=-1)
candidate_idxs = torch.cat(candidate_idxs, dim=-1)
# is_in_candidate_list :这是一个包含了多个元素的列表,每个元素都是一个张量(tensor),表示在每个目标检测层级中,哪些先验框(anchor boxes)被选为了候选框。
# 具体来说, is_in_candidate 张量中的每个元素都是一个 one-hot 编码,指示了每个先验框是否被选为候选框。如果一个先验框被选为候选框,那么在对应的位置会有非零值(通常是1)。这个列表在后续的损失计算中非常重要,因为它决定了哪些先验框将参与损失的计算。
# candidate_idxs :这是一个包含了多个元素的列表,每个元素都是一个张量,包含了在每个目标检测层级中被选为候选框的先验框的索引。这些索引是基于全局的先验框索引空间的,意味着它们指向了在整个特征图中的先验框位置。
# 在训练过程中,这些索引用于从预测的边界框、对象置信度和类别概率中选择出对应的候选框,以便计算损失。
return is_in_candidate_list, candidate_idxs
def thres_calculator(self,
is_in_candidate,
candidate_idxs,
overlaps):
# 用于计算在一次批量(batch)处理中可以处理的最大目标数量。
# self.bs :这通常代表批量大小(batch size),即一次训练迭代中同时处理的样本数量。
# self.n_max_boxes :这代表每个样本中可以处理的最大目标数量。在目标检测任务中,一个图像可能包含多个目标, n_max_boxes 就是限制每个图像中可以被检测到的目标的最大数量。
# n_bs_max_boxes :这是将批量大小与每个样本的最大目标数量相乘得到的结果,表示在一次批量处理中可以处理的最大目标数量。这个值在确定网络输出的尺寸时非常重要,因为网络需要为每个目标预测边界框、置信度和类别概率。
# 例如,如果批量大小是 8 ( self.bs = 8 ),并且每个样本中可以处理的最大目标数量是 50 ( self.n_max_boxes = 50 ),那么 n_bs_max_boxes 将会是 400。这意味着在一次批量处理中,网络可以处理最多 400 个目标。
n_bs_max_boxes = self.bs * self.n_max_boxes
_candidate_overlaps = torch.where(is_in_candidate > 0,
overlaps, torch.zeros_like(overlaps))
candidate_idxs = candidate_idxs.reshape([n_bs_max_boxes, -1])
# 用于生成一个辅助索引张量,这个张量将用于后续的索引操作,以处理每个目标检测层级的先验框(anchor boxes)。
# torch.arange(n_bs_max_boxes) 生成一个从 0 到 n_bs_max_boxes - 1 的整数序列。 n_bs_max_boxes 是一次批量处理中可以处理的最大目标数量。
# self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device) : 这个操作将生成的整数序列乘以 self.n_anchors ,生成一个新的序列。这个序列将用于索引操作,以处理每个目标检测层级的先验框。
# assist_idxs :这个辅助索引张量将用于后续的索引操作,以处理每个目标检测层级的先验框。
assist_idxs = self.n_anchors * torch.arange(n_bs_max_boxes, device=candidate_idxs.device)
# 增加张量的维度。
# None (或者在 Python 中也可以写作 newaxis )被插入到张量的维度中,创建了一个新的轴。
assist_idxs = assist_idxs[:,None]
# 这行代码的目的是将候选框索引( candidate_idxs )与辅助索引( assist_idxs )相加以生成一个扁平化的索引张量( flatten_idxs )。这个操作通常用于处理多尺度特征图和多个先验框(anchor boxes)。
# 1. candidate_idxs :这是一个张量,包含了在每个目标检测层级中被选为候选框的先验框的索引。这些索引是基于全局的先验框索引空间的。
# 2. assist_idxs :这是一个辅助索引张量,用于处理每个目标检测层级的先验框。它通过将 self.n_anchors (每个像素点上的先验框数量)乘以一个从 0 到 n_bs_max_boxes - 1 的整数序列来生成。
# 3. candidate_idxs + assist_idxs :这个操作将 candidate_idxs 和 assist_idxs 相加,生成一个新的索引张量 flatten_idxs 。这个操作的目的是将候选框索引与辅助索引相结合,以便在后续操作中快速索引到每个候选框的相关信息。
# 如果 candidate_idxs 是一个形状为 (n,) 的一维张量,而 assist_idxs 是一个形状为 (n, 1) 的二维张量,那么 PyTorch 会自动应用广播机制,使得 assist_idxs 的第二个维度被广播以匹配 candidate_idxs 的形状,从而使得每个元素都与 assist_idxs 的对应元素相加。
# 4. 结果:相加操作的结果是一个新的张量 flatten_idxs ,它包含了每个候选框的全局索引。这个张量可以用于索引操作,例如从一个大的张量中选择特定的行或列。
faltten_idxs = candidate_idxs + assist_idxs
candidate_overlaps = _candidate_overlaps.reshape(-1)[faltten_idxs]
candidate_overlaps = candidate_overlaps.reshape([self.bs, self.n_max_boxes, -1])
overlaps_mean_per_gt = candidate_overlaps.mean(axis=-1, keepdim=True)
overlaps_std_per_gt = candidate_overlaps.std(axis=-1, keepdim=True)
overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
# 1. overlaps_thr_per_gt :这个变量代表每个Ground Truth(GT,即真实边界框)对应的样本阈值。在YOLOv6中,为了决定哪些锚框(anchor boxes)将被标记为正样本(即负责预测某个GT的框),模型需要计算锚框和GT之间的交并比(IoU)。
# overlaps_thr_per_gt 就是基于这些IoU值计算出来的一个阈值,只有当锚框与GT的IoU值超过这个阈值时,该锚框才被认为是正样本。这个阈值通常是基于每个GT的所有候选锚框的IoU值统计得出的,例如取这些IoU值的某个分位数作为阈值。
# 2. _candidate_overlaps :这个变量代表每个候选锚框与对应GT的IoU值。在样本匹配过程中,对于每个GT,模型会计算所有锚框与该GT的IoU值,这些IoU值就是 _candidate_overlaps 。
# 然后,根据这些IoU值和上面提到的 overlaps_thr_per_gt 阈值,模型可以决定哪些锚框应该被选为正样本。
return overlaps_thr_per_gt, _candidate_overlaps
# 1. gt_labels :这是一个张量,包含了每个Ground Truth(GT)对象的类别标签。这些标签指示了每个GT对象属于哪个类别。
# 2. gt_bboxes :这是一个张量,包含了每个GT对象的边界框坐标。通常,这些坐标以 (x_min, y_min, x_max, y_max) 的形式表示,其中 (x_min, y_min) 是边界框左上角的坐标, (x_max, y_max) 是边界框右下角的坐标。
# 3. target_gt_idx :这是一个张量,包含了每个锚框对应的GT对象的索引。在样本匹配过程中,每个锚框会被分配给一个GT对象(如果它与某个GT对象的IoU超过某个阈值), target_gt_idx 就是用来记录这种分配关系的。
# 4. fg_mask :这是一个掩码张量,用于标记哪些锚框是前景(即包含目标的锚框)。在目标检测中,通常只有与GT对象有较高IoU的锚框才会被标记为前景,这些锚框将被用于计算损失函数。
def get_targets(self,
gt_labels,
gt_bboxes,
target_gt_idx,
fg_mask):
# assigned target labels 指定目标标签。
# batch_idx:这是一个张量,包含了从0到 self.bs-1 的整数序列,其中 self.bs 是批次大小。这个序列用于为每个样本在批次中生成唯一的索引。
batch_idx = torch.arange(self.bs, dtype=gt_labels.dtype, device=gt_labels.device)
# 通过在最后添加一个维度,将 batch_idx 从一维张量变为二维张量。
batch_idx = batch_idx[...,None]
# target_gt_idx + batch_idx * self.n_max_boxes : 这里, target_gt_idx 是一个张量,包含了每个锚框对应的GT对象的索引。
# 通过将 target_gt_idx 与 batch_idx 相加,并将结果乘以 self.n_max_boxes (每个样本的最大GT数量),为每个样本生成一个唯一的索引。这个操作确保了不同样本中的相同GT索引不会发生冲突。
target_gt_idx = (target_gt_idx + batch_idx * self.n_max_boxes).long()
# 首先,将 gt_labels 展平为一维张量。然后,使用上面计算的索引从展平的 gt_labels 中选择对应的类别标签。这将为每个锚框分配正确的类别标签。
target_labels = gt_labels.flatten()[target_gt_idx.flatten()]
# 将选择的类别标签重新塑形为二维张量,其中第一维是 批次大小 ,第二维是每个样本的 锚框数量 。
target_labels = target_labels.reshape([self.bs, self.n_anchors])
# 使用 torch.where 函数根据 fg_mask (前景掩码)来设置目标标签。如果 fg_mask 中的值为正(即锚框是前景),则使用对应的 target_labels 值。否则,使用 self.bg_idx (背景索引)填充目标标签。这一步确保了只有前景锚框保留其类别标签,而背景锚框则被标记为背景。
# 最终, target_labels 张量将包含每个锚框的目标类别标签,其中前景锚框有对应的类别标签,背景锚框则被标记为背景。这些目标标签将用于计算模型在训练过程中的损失,指导模型学习如何准确地检测和分类目标对象。
target_labels = torch.where(fg_mask > 0,
target_labels, torch.full_like(target_labels, self.bg_idx))
# assigned target boxes 指定目标边界框
# gt_bboxes.reshape([-1, 4]) :首先,将包含所有GT边界框坐标的张量 gt_bboxes 重塑为一个二维张量,其中每一行代表一个边界框的 (x_min, y_min, x_max, y_max) 坐标。 -1 在这里表示自动计算该维度的大小,以确保所有边界框都被包含。
# target_gt_idx.flatten() : target_gt_idx 是一个张量,包含了每个锚框对应的GT边界框的索引。通过调用 flatten() 方法,将这个张量展平成一维张量,以便从中选择边界框坐标。
# gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()] : 使用上面得到的一维索引张量从重塑后的 gt_bboxes 中选择对应的边界框坐标。这将为每个锚框分配正确的GT边界框坐标。
target_bboxes = gt_bboxes.reshape([-1, 4])[target_gt_idx.flatten()]
# 最后,将选择的边界框坐标重新塑形为三维张量,其中第一维是批次大小 self.bs ,第二维是每个样本的锚框数量 self.n_anchors ,第三维是边界框的四个坐标值 (x_min, y_min, x_max, y_max) 。
target_bboxes = target_bboxes.reshape([self.bs, self.n_anchors, 4])
# assigned target scores 指定目标分数
# target_labels 是一个张量,包含了每个锚框的目标类别标签。 self.num_classes + 1 是独热编码的类别数,其中 self.num_classes 是数据集中的类别总数,额外的 +1 通常是为了包含背景类别(即没有目标的类别)。
# .float() :将独热编码的张量转换为浮点数类型,这是因为模型训练中通常使用浮点数进行计算。
target_scores = F.one_hot(target_labels.long(), self.num_classes + 1).float()
# target_scores[:, :, :self.num_classes] :这一步是将独热编码张量中的背景类别(通常是最后一个类别)切掉,只保留数据集中实际的类别。
# self.num_classes 是实际的类别数,不包括背景类别。这样, target_scores 张量的形状将是 [batch_size, num_anchors, num_classes] ,其中 batch_size 是批次大小, num_anchors 是每个网格点上的锚框数量, num_classes 是数据集中的类别数。
# 通过这个过程, target_scores 张量将包含每个锚框的目标类别的独热编码分数,这些分数将用于计算模型在训练过程中的分类损失,指导模型学习如何准确地分类目标对象。背景类别的分数通常在后续的损失计算中被忽略或特别处理。
target_scores = target_scores[:, :, :self.num_classes]
# 1. target_labels :张量将包含每个锚框的目标类别标签,其中前景锚框有对应的类别标签,背景锚框则被标记为背景。
# 2. target_bboxes :选择的边界框坐标。形为三维张量,其中第一维是批次大小 self.bs ,第二维是每个样本的锚框数量 self.n_anchors ,第三维是边界框的四个坐标值 (x_min, y_min, x_max, y_max) 。
# 3. target_scores :张量将包含每个锚框的目标类别的独热编码分数。
return target_labels, target_bboxes, target_scores
标签:YOLOv6,gt,4.0,assigner,self,candidate,张量,bboxes,target
From: https://blog.csdn.net/m0_58169876/article/details/143350427