首页 > 其他分享 >YOLOv9改进策略【注意力机制篇】| GAM全局注意力机制: 保留信息以增强通道与空间的相互作用

YOLOv9改进策略【注意力机制篇】| GAM全局注意力机制: 保留信息以增强通道与空间的相互作用

时间:2024-09-06 14:51:42浏览次数:15  
标签:ADown models YOLOv9 common RepNCSPELAN4 机制 512 256 注意力

一、本文介绍

本文记录的是基于GAM注意力模块的YOLOv9目标检测改进方法研究GAM注意力模块通过3D排列和重新设计的子模块,能够在通道和空间方面保留信息,避免了先前方法中由于信息减少和维度分离而导致的全局空间-通道交互丢失的问题。本文利用GAM改进YOLOv9,以增强模型的跨维度交互能力。

文章目录


二、GAM注意力原理

全局注意力机制: 保留信息以增强通道与空间的相互作用

GAM(Global Attention Mechanism)是一种全局注意力机制,其设计目的是减少信息减少并放大全局维度交互特征,以增强深度神经网络的性能。

2.1、设计原理

  1. 整体结构:采用了来自CBAM的顺序通道 - 空间注意力机制,并重新设计了子模块。给定输入特征图 F 1 ∈ R C × H × W F_{1} \in \mathbb{R}^{C ×H ×W} F1​∈RC×H×W,中间状态 F 2 F_{2} F2​和输出 F 3 F_{3} F3​的定义为:
    • F 2 = M c ( F 1 ) ⊗ F 1 F_{2}=M_{c}\left(F_{1}\right) \otimes F_{1} F2​=Mc​(F1​)⊗F1​
    • F 3 = M s ( F 2 ) ⊗ F 2 F_{3}=M_{s}\left(F_{2}\right) \otimes F_{2} F3​=Ms​(F2​)⊗F2​
      其中 M c M_{c} Mc​和 M s M_{s} Ms​分别是通道和空间注意力图, ⊗ \otimes ⊗表示元素级乘法。
  2. 通道注意力子模块:使用3D排列来保留跨三个维度的信息,然后通过两层MLP(多层感知机)放大跨维度的通道 - 空间依赖性。(MLP是具有压缩比 r r r的编码器 - 解码器结构,与BAM相同。)
  3. 空间注意力子模块:为了关注空间信息,使用两个卷积层进行空间信息融合,并使用与通道注意力子模块相同的压缩比 r r r(与BAM相同)。同时,由于最大池化会减少信息并产生负面影响,所以移除了池化以进一步保留特征图。为了防止参数显著增加,在ResNet50中采用了具有通道打乱的组卷积。

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

2.2、优势

  1. 保留信息:通过3D排列和重新设计的子模块,GAM能够在通道和空间方面保留信息,避免了先前方法中由于信息减少和维度分离而导致的全局空间 - 通道交互的丢失。
  2. 放大交互:能够放大“全局”跨维度交互,捕获所有三个维度(通道、空间宽度和空间高度)上的重要特征,从而增强了跨维度的交互能力。
  3. 性能提升:在CIFAR - 100和ImageNet - 1K数据集上的评估表明,GAM稳定地优于其他几种近期的注意力机制,无论是在ResNet还是轻量级MobileNet上,都能提高性能。例如,在ImageNet - 1K数据集上,对于ResNet18,GAM以更少的参数和更高的效率优于ABN。

论文:https://arxiv.org/pdf/2112.05561v1
源码:https://github.com/dengbuqi/GAM_Pytorch/blob/main/CAM.py

三、GAM的实现代码

GAM模块的实现代码如下:

class GAMAttention(nn.Module):

    def __init__(self, c1, c2, group=True, rate=4):
        super(GAMAttention, self).__init__()

        self.channel_attention = nn.Sequential(
            nn.Linear(c1, int(c1 / rate)),
            nn.ReLU(inplace=True),
            nn.Linear(int(c1 / rate), c1),
        )
        self.spatial_attention = nn.Sequential(
            (
                nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate)
                if group
                else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3)
            ),
            nn.BatchNorm2d(int(c1 / rate)),
            nn.ReLU(inplace=True),
            (
                nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate)
                if group
                else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3)
            ),
            nn.BatchNorm2d(c2),
        )

    def forward(self, x):
        b, c, h, w = x.shape
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)
        x = x * x_channel_att

        x_spatial_att = self.spatial_attention(x).sigmoid()
        x_spatial_att = channel_shuffle(x_spatial_att, 4)  # last shuffle
        out = x * x_spatial_att
        return out


def channel_shuffle(x, groups=2):  ##shuffle channel
    # RESHAPE----->transpose------->Flatten
    B, C, H, W = x.size()
    out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
    out = out.view(B, C, H, W)
    return out


四、添加步骤

4.1 修改common.py

此处需要修改的文件是models/common.py

common.py中定义了网络结构的通用模块,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。

4.1.1 基础模块1

模块改进方法1️⃣:直接加入GAMAttention模块
GAMAttention模块添加后如下:

在这里插入图片描述

注意❗:在4.2小节中的yolo.py文件中需要声明的模块名称为:GAMAttention

4.1.2 创新模块2⭐

模块改进方法2️⃣:基于GAMAttention模块RepNCSPELAN4

相较方法一中的直接插入注意力模块,利用注意力模块对卷积等其他模块进行改进,其新颖程度会更高一些,训练精度可能会表现的更高。

第二种改进方法是对YOLOv9中的RepNCSPELAN4模块进行改进。RepNCSPELAN4模块的创新思想是将CSPELAN相结合。CSP可以有效地分割梯度流,减少计算量的同时保持准确性。ELAN则通过灵活的层聚合方式,增强网络的学习能力。此处的改进方法是将GAMAttention注意力模块替换RepNCSPELAN4中的卷积模块,生成GAMRepNCSPELAN4模块GAM 模块能够捕捉通道、空间宽度和空间高度等多个维度的重要特征,加强了跨维度的交互,在将其添加到RepNCSPELAN4模块中有助于在分流过程中更好地分配注意力,减少无关信息的干扰,提高特征质量。

改进代码如下:

class GAMRepNCSPELAN4(nn.Module):
    # csp-elan
    def __init__(self, c1, c2, c3, c4, c5=1):  # ch_in, ch_out, number, shortcut, groups, expansion
        super().__init__()
        self.c = c3//2
        self.cv1 = Conv(c1, c3, 1, 1)
        self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), GAMAttention(c4, c4))
        self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), GAMAttention(c4, c4))
        self.cv4 = Conv(c3+(2*c4), c2, 1, 1)

    def forward(self, x):
        y = list(self.cv1(x).chunk(2, 1))
        y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
        return self.cv4(torch.cat(y, 1))

    def forward_split(self, x):
        y = list(self.cv1(x).split((self.c, self.c), 1))
        y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
        return self.cv4(torch.cat(y, 1))

在这里插入图片描述

注意❗:在4.2小节中的yolo.py文件中需要声明的模块名称为:GAMRepNCSPELAN4

4.2 修改yolo.py

此处需要修改的文件是models/yolo.py

yolo.py用于函数调用,我们只需要将common.py中定义的新的模块名添加到parse_model函数下即可。

GAMAttention模块以及GAMRepNCSPELAN4模块添加后如下:

在这里插入图片描述


五、yaml模型文件

5.1 模型改进版本一

在代码配置完成后,配置模型的YAML文件。

此处以models/detect/yolov9-c.yaml为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-gam.yaml

yolov9-c.yaml中的内容复制到yolov9-c-gam.yaml文件下,修改nc数量等于自己数据中目标的数量。
在骨干网络的最后一层添加GAMAttention模块,即下方代码中的第45行,只需要填入一个参数,通道数,和前一层通道数一致

标签:ADown,models,YOLOv9,common,RepNCSPELAN4,机制,512,256,注意力
From: https://blog.csdn.net/qq_42591591/article/details/141864873

相关文章

  • 从内存层面分析Java 参数传递机制
    在Java中,理解参数传递机制对于编写高效和可维护的代码至关重要。本文将探讨基本数据类型和引用数据类型的参数传递方式,并介绍System.identityHashCode方法及其作用。我们将结合栈帧的概念,通过示例代码来详细解释这些机制。System.identityHashCode的作用System.ident......
  • YOLOv8改进 | 注意力篇 | YOLOv8引入YOLO-Face提出的SEAM注意力机制优化物体遮挡检测
    1. SEAM介绍1.1 摘要:近年来,基于深度学习的人脸检测算法取得了长足的进步。这些算法通常可以分为两类,即像FasterR-CNN这样的两级检测器和像YOLO这样的一级检测器。由于精度和速度之间具有更好的平衡,一级探测器已广泛应用于许多应用中。在本文中,我们提出了一种基于......
  • 逐行讲解Transformer的代码实现和原理讲解:多头掩码注意力机制
    视频详细讲解(一行一行代码讲解实现过程):逐行讲解Transformer的代码实现和原理讲解:多头掩码注意力机制(1)_哔哩哔哩_bilibili1多头掩码注意力机制总体流程【总体流程图说明】【12个块】【多头掩码注意力机制公式】【计算公式对应的步骤】2向量相似度计算2.1点积向......
  • 【学习笔记】SSL证书安全机制之证书验证
    前言:每当Client从Server收到一张证书,有2件事Client需要去验证:证书是否有效?证书只是文件中的文本Client如何知道内容能够信任?Server是否是证书真正的拥有者?证书可以公开获取Client如何知道Server是真正的拥有者?1、证书是否有效?CertificateAuthority(须知道CA是证书......
  • Falcon Mamba: 首个高效的无注意力机制 7B 模型
    FalconMamba是由阿布扎比的TechnologyInnovationInstitute(TII)开发并基于TIIFalconMamba7BLicense1.0的开放获取模型。该模型是开放获取的,所以任何人都可以在HuggingFace生态系统中这里使用它进行研究或应用。在这篇博客中,我们将深入模型的设计决策、探究模......
  • AbMole|DNA双链断裂修复中的序列与染色质特征:MRX复合体的作用与机制
     在生物学领域中,DNA双链断裂(DSB)作为一种极具破坏性的基因组损伤,其准确且高效的修复对于维持细胞基因组的稳定性和功能至关重要。由来自哥伦比亚大学欧文医学中心微生物学与免疫学系的RobertGnügge和瑞士苏黎世工业大学(ETH)生物化学研究所生物系的 GiordanoReginato,Petr......
  • 【前端面试】事件监听机制&React 的事件系统实现
    目的React实现了自己的事件系统,主要是为了解决以下几个问题:跨浏览器兼容性:不同的浏览器在处理DOM事件时有不同的实现,React的事件系统抽象了这些差异,提供了一致的API给开发者使用。性能优化:React可以对事件进行池化(Pooling),这意味着事件对象可以在事件处理过程......
  • YOLOv8改进 | 注意力篇 | YOLOv8引入MSCAAttention(MSCA)注意力机制
    1. MSCA介绍1.1 摘要:我们提出了SegNeXt,一种用于语义分割的简单卷积网络架构。由于自注意力在编码空间信息方面的效率,最近基于变压器的模型在语义分割领域占据了主导地位。在本文中,我们证明卷积注意力是一种比Transformer中的自注意力机制更高效、更有效的编码上下文......
  • YOLOv8改进:CA注意力机制【注意力系列篇】(附详细的修改步骤,以及代码,CA目标检测效果由
    如果实验环境尚未搭建成功,可以参考这篇文章->【YOLOv8超详细环境搭建以及模型训练(GPU版本)】文章链接为:http://t.csdnimg.cn/8ZmAm---------------------------------------------------------------------------​------------------------------------------------------1......
  • 每天五分钟深度学习:广播机制(以python语言为例)
    本文重点因为向量化的计算方式导致效率的提升,所以现在很多时候,我们都是用向量化的计算,但是向量化计算有一个问题让人头痛就是维度的问题,本节课程我们将讲解python中的广播机制,你会发现这个机制的优秀之处。代码实例importnumpyasnpa=np.random.randn(3,4)b=np.random.r......