
A Line-by-Line Walkthrough of the YOLOv8 Detect Head


YOLOv8 takes feature maps of different sizes from different feature levels. For every grid cell (anchor point) of each map it predicts the class probabilities, plus the object's box expressed as offsets from the cell center to the left/top/right/bottom edges.

The input x is the list of outputs from the different upsampling levels, with shapes [(1, 144, 80, 80), (1, 144, 40, 40), (1, 144, 20, 20)].

x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)  # (1, 144, 8400)
box, cls = x_cat.split((self.reg_max * 4, self.nc), 1)  # (1, 64, 8400), (1, 80, 8400)

Concatenating these results gives a tensor of shape (1, 144, 8400), where:
       8400 = 80 * 80 + 40 * 40 + 20 * 20, the total number of predictions;
       144 = 80 class channels + 4 * 16 box channels;
       4 is the number of predicted distances from the cell center to the box edges, the Anchor-Free regression target, in [left, top, right, bottom] order;
        self.reg_max = 16 is the range of each distance prediction: a side extends at most reg_max units from the center, but these units are not pixels, since the predictions are scaled by each level's stride. This parameter also caps the largest detectable box at reg_max * stride * 2 per level, e.g. 16 * 32 * 2 = 1024 pixels at the coarsest level.
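
As a sanity check of the concat/split step above, here is a minimal shape walk-through using random tensors in place of the real head outputs:

import torch

nc, reg_max = 80, 16
no = nc + reg_max * 4  # 144 channels per anchor
x = [torch.randn(1, no, s, s) for s in (80, 40, 20)]  # three feature levels

x_cat = torch.cat([xi.view(1, no, -1) for xi in x], 2)
print(x_cat.shape)  # torch.Size([1, 144, 8400]); 8400 = 80*80 + 40*40 + 20*20

box, cls = x_cat.split((reg_max * 4, nc), 1)
print(box.shape, cls.shape)  # torch.Size([1, 64, 8400]) torch.Size([1, 80, 8400])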

self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5))  # (2, 8400), (1, 8400)
self.shape = shape  # (1, 144, 80, 80)

def make_anchors(feats, strides, grid_cell_offset=0.5):
    """Generate anchors from features."""
    anchor_points, stride_tensor = [], []
    assert feats is not None
    dtype, device = feats[0].dtype, feats[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = feats[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
        anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_points), torch.cat(stride_tensor)


self.anchors[:,:10]
tensor([[0.5000, 1.5000, 2.5000, 3.5000, 4.5000, 5.5000, 6.5000, 7.5000, 8.5000, 9.5000],
        [0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000]], device='cuda:0')

self.strides[:,:10]
tensor([[8., 8., 8., 8., 8., 8., 8., 8., 8., 8.]], device='cuda:0')

make_anchors generates the grid points used for prediction. Here x has shapes [(1, 144, 80, 80), (1, 144, 40, 40), (1, 144, 20, 20)] and self.stride is tensor([8., 16., 32.]).

For the 80 * 80 feature map it produces 80 * 80 anchor points and 80 * 80 stride entries. Each anchor point is the center of a 1 * 1 grid cell, and the stride is the scale factor back to input-image pixels; larger feature maps have smaller strides and are used to detect small objects.
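A quick sanity check of make_anchors, assuming the definition above is in scope (TORCH_1_10 is a torch >= 1.10 version flag from ultralytics.utils; it is set manually here):

import torch

TORCH_1_10 = True  # torch >= 1.10 supports meshgrid(..., indexing="ij")

feats = [torch.zeros(1, 144, s, s) for s in (80, 40, 20)]
anchors, strides = make_anchors(feats, torch.tensor([8., 16., 32.]), 0.5)
print(anchors.shape, strides.shape)  # torch.Size([8400, 2]) torch.Size([8400, 1])
print(anchors[:3])  # tensor([[0.5000, 0.5000], [1.5000, 0.5000], [2.5000, 0.5000]])

The Detect head then transposes both tensors, which is why self.anchors is (2, 8400) and self.strides is (1, 8400).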

dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides  # (1, 4, 8400), (1, 2, 8400) => (1, 4, 8400)

class DFL(nn.Module):
    def __init__(self, c1=16):
        """Initialize a convolutional layer with a given number of input channels."""
        super().__init__()
        self.conv = nn.Conv2d(c1, 1, 1, bias=False).requires_grad_(False)
        x = torch.arange(c1, dtype=torch.float)
        self.conv.weight.data[:] = nn.Parameter(x.view(1, c1, 1, 1))
        self.c1 = c1

    def forward(self, x):
        """Apply DFL: softmax over the reg_max bins, then take the expected bin index per side."""
        # (1, 64, 8400) => (1, 4, 16, 8400) => (1, 16, 4, 8400) => (1, 1, 4, 8400) => (1, 4, 8400)
        b, _, a = x.shape  # batch, channels, anchors
        return self.conv(x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1)).view(b, 4, a)

def decode_bboxes(self, bboxes, anchors):
    """Decode bounding boxes."""
    if self.export:
        return dist2bbox(bboxes, anchors, xywh=False, dim=1)
    return dist2bbox(bboxes, anchors, xywh=True, dim=1)

def dist2bbox(distance, anchor_points, xywh=True, dim=-1):
    """Transform distance(ltrb) to box(xywh or xyxy)."""
    lt, rb = distance.chunk(2, dim)
    x1y1 = anchor_points - lt
    x2y2 = anchor_points + rb
    if xywh:
        c_xy = (x1y1 + x2y2) / 2
        wh = x2y2 - x1y1
        return torch.cat((c_xy, wh), dim)  # xywh bbox
    return torch.cat((x1y1, x2y2), dim)  # xyxy bbox

self.dfl(box) computes the box offsets.

x.view(b, 4, self.c1, a).transpose(2, 1).softmax(1): first reshapes box from (1, 64, 8400) => (1, 4, 16, 8400) => (1, 16, 4, 8400), then applies softmax over dim=1, assigning a weight to each of the 16 candidate distances.

Because self.conv is created with requires_grad_(False), its weights stay fixed at x = torch.arange(c1, dtype=torch.float), i.e. tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.]). The nn.Conv2d(c1, 1, 1, bias=False) then multiplies the softmax weights by the corresponding bin values and sums them, producing the final offset.
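
In other words, the frozen 1x1 conv computes an expected value over the 16 bins. A small equivalence check, assuming the DFL class above is in scope:

import torch

dfl = DFL(16)
x = torch.randn(1, 64, 8400)  # raw box logits: 4 sides * 16 bins each

out_conv = dfl(x)  # (1, 4, 8400)

# The same computation as an explicit expectation: softmax weights times bin index.
prob = x.view(1, 4, 16, 8400).softmax(2)
bins = torch.arange(16, dtype=torch.float).view(1, 1, 16, 1)
out_manual = (prob * bins).sum(2)  # (1, 4, 8400)

print(torch.allclose(out_conv, out_manual, atol=1e-6))  # True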

decode_bboxes uses the dist2bbox function. The box is in [left, top, right, bottom] format; dist2bbox splits it in two, subtracts (left, top) from the center point to get the top-left corner x1y1, and adds (right, bottom) to get the bottom-right corner x2y2. This yields xyxy-format coordinates (which can also be converted to xywh), which are then multiplied by the corresponding stride to obtain the final coordinates. (1, 4, 8400)
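
A worked example of dist2bbox on a single anchor, assuming the function above and made-up distances:

import torch

distance = torch.tensor([[2.0, 1.0, 3.0, 4.0]])  # [left, top, right, bottom]
anchor = torch.tensor([[10.5, 20.5]])  # grid-cell center (x, y)

print(dist2bbox(distance, anchor, xywh=False))  # tensor([[ 8.5000, 19.5000, 13.5000, 24.5000]])
print(dist2bbox(distance, anchor, xywh=True))   # tensor([[11.0000, 22.0000,  5.0000,  5.0000]])

Multiplying by the level's stride (e.g. 8) then maps these grid-unit coordinates to input-image pixels.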

y = torch.cat((dbox, cls.sigmoid()), 1)  #(1,84,8400)

Concatenating the predicted coordinates with the sigmoid-activated class scores gives the final returned result.
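
Downstream, a typical way to read this output before NMS looks like the following sketch (the 0.25 threshold is a hypothetical choice, not part of the head):

import torch

y = torch.rand(1, 84, 8400)  # stand-in for the head output
boxes = y[:, :4, :]   # xywh, already scaled by stride
scores = y[:, 4:, :]  # per-class probabilities (post-sigmoid)

conf, cls_id = scores.max(1)  # best class per anchor: (1, 8400) each
keep = conf[0] > 0.25  # hypothetical confidence threshold
print(boxes[0, :, keep].T.shape, cls_id[0, keep].shape)  # (n, 4), (n,)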


Full code:

class Detect(nn.Module):
    """YOLOv8 Detect head for detection models."""

    dynamic = False  # force grid reconstruction
    export = False  # export mode
    shape = None
    anchors = torch.empty(0)  # init
    strides = torch.empty(0)  # init

    def __init__(self, nc=80, ch=()):
        """Initializes the YOLOv8 detection layer with specified number of classes and channels."""
        super().__init__()
        self.nc = nc  # number of classes
        self.nl = len(ch)  # number of detection layers
        self.reg_max = 16  # DFL channels (ch[0] // 16 to scale 4/8/12/16/20 for n/s/m/l/x)
        self.no = nc + self.reg_max * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build
        c2, c3 = max((16, ch[0] // 4, self.reg_max * 4)), max(ch[0], min(self.nc, 100))  # channels
        self.cv2 = nn.ModuleList(
            nn.Sequential(Conv(x, c2, 3), Conv(c2, c2, 3), nn.Conv2d(c2, 4 * self.reg_max, 1)) for x in ch
        )
        self.cv3 = nn.ModuleList(nn.Sequential(Conv(x, c3, 3), Conv(c3, c3, 3), nn.Conv2d(c3, self.nc, 1)) for x in ch)
        self.dfl = DFL(self.reg_max) if self.reg_max > 1 else nn.Identity()

    def inference(self, x):#[(1, 144, 80, 80),(1, 144, 40, 40),(1,144,20,20)]
        # Inference path
        shape = x[0].shape  # BCHW  (1, 144, 80, 80)
        x_cat = torch.cat([xi.view(shape[0], self.no, -1) for xi in x], 2)  #(1,144,8400)
        if self.dynamic or self.shape != shape:
            self.anchors, self.strides = (x.transpose(0, 1) for x in make_anchors(x, self.stride, 0.5)) #(2,8400),(1,8400)
            self.shape = shape #(1, 144, 80, 80)

        if self.export and self.format in ("saved_model", "pb", "tflite", "edgetpu", "tfjs"):  # avoid TF FlexSplitV ops
            box = x_cat[:, : self.reg_max * 4]
            cls = x_cat[:, self.reg_max * 4 :]
        else:
            box, cls = x_cat.split((self.reg_max * 4, self.nc), 1) #(1,64,8400),(1,80,8400)

        if self.export and self.format in ("tflite", "edgetpu"):
            # Precompute normalization factor to increase numerical stability
            # See https://github.com/ultralytics/ultralytics/issues/7371
            grid_h = shape[2]
            grid_w = shape[3]
            grid_size = torch.tensor([grid_w, grid_h, grid_w, grid_h], device=box.device).reshape(1, 4, 1)
            norm = self.strides / (self.stride[0] * grid_size)
            dbox = self.decode_bboxes(self.dfl(box) * norm, self.anchors.unsqueeze(0) * norm[:, :2])
        else:
            dbox = self.decode_bboxes(self.dfl(box), self.anchors.unsqueeze(0)) * self.strides  #(1,4,8400),(1,2,8400) => (1,4,8400)

        y = torch.cat((dbox, cls.sigmoid()), 1)  #(1,84,8400)
        return y if self.export else (y, x)

    def forward_feat(self, x, cv2, cv3):
        y = []
        for i in range(self.nl):
            y.append(torch.cat((cv2[i](x[i]), cv3[i](x[i])), 1))
        return y

    def forward(self, x):
        """Concatenates and returns predicted bounding boxes and class probabilities."""
        y = self.forward_feat(x, self.cv2, self.cv3)
        
        if self.training:
            return y

        return self.inference(y)

    def bias_init(self):
        """Initialize Detect() biases, WARNING: requires stride availability."""
        m = self  # self.model[-1]  # Detect() module
        # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1
        # ncf = math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum())  # nominal class frequency
        for a, b, s in zip(m.cv2, m.cv3, m.stride):  # from
            a[-1].bias.data[:] = 1.0  # box
            b[-1].bias.data[: m.nc] = math.log(5 / m.nc / (640 / s) ** 2)  # cls (.01 objects, 80 classes, 640 img)

    def decode_bboxes(self, bboxes, anchors):
        """Decode bounding boxes."""
        if self.export:
            return dist2bbox(bboxes, anchors, xywh=False, dim=1)
        return dist2bbox(bboxes, anchors, xywh=True, dim=1)
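
A minimal smoke test of the full head, assuming the Detect, DFL, make_anchors, and dist2bbox definitions above are in scope, along with Conv from the ultralytics package (the per-level input channels below are hypothetical):

import torch
from ultralytics.nn.modules import Conv  # standard Conv-BN-SiLU block used by cv2/cv3

m = Detect(nc=80, ch=(64, 128, 256))  # hypothetical neck channel widths
m.stride = torch.tensor([8., 16., 32.])  # normally computed during model build
m.eval()

x = [torch.randn(1, 64, 80, 80), torch.randn(1, 128, 40, 40), torch.randn(1, 256, 20, 20)]
with torch.no_grad():
    y, feats = m(x)
print(y.shape)  # torch.Size([1, 84, 8400])
print([tuple(f.shape) for f in feats])  # [(1, 144, 80, 80), (1, 144, 40, 40), (1, 144, 20, 20)]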

