一、本文介绍
本文记录的是基于GAM注意力模块的YOLOv9目标检测改进方法研究。GAM注意力模块
通过3D排列和重新设计的子模块,能够在通道和空间方面保留信息,避免了先前方法中由于信息减少和维度分离而导致的全局空间-通道交互丢失的问题。本文利用GAM
改进YOLOv9
,以增强模型的跨维度交互能力。
文章目录
二、GAM注意力原理
全局注意力机制: 保留信息以增强通道与空间的相互作用
GAM(Global Attention Mechanism)
是一种全局注意力机制,其设计目的是减少信息减少并放大全局维度交互特征,以增强深度神经网络的性能。
2.1、设计原理
- 整体结构:采用了来自CBAM的顺序通道 - 空间注意力机制,并重新设计了子模块。给定输入特征图
F
1
∈
R
C
×
H
×
W
F_{1} \in \mathbb{R}^{C ×H ×W}
F1∈RC×H×W,中间状态
F
2
F_{2}
F2和输出
F
3
F_{3}
F3的定义为:
- F 2 = M c ( F 1 ) ⊗ F 1 F_{2}=M_{c}\left(F_{1}\right) \otimes F_{1} F2=Mc(F1)⊗F1
-
F
3
=
M
s
(
F
2
)
⊗
F
2
F_{3}=M_{s}\left(F_{2}\right) \otimes F_{2}
F3=Ms(F2)⊗F2
其中 M c M_{c} Mc和 M s M_{s} Ms分别是通道和空间注意力图, ⊗ \otimes ⊗表示元素级乘法。
- 通道注意力子模块:使用3D排列来保留跨三个维度的信息,然后通过两层
MLP(多层感知机)
放大跨维度的通道 - 空间依赖性。(MLP是具有压缩比 r r r的编码器 - 解码器结构,与BAM相同。) - 空间注意力子模块:为了关注空间信息,使用两个卷积层进行空间信息融合,并使用与通道注意力子模块相同的压缩比 r r r(与BAM相同)。同时,由于最大池化会减少信息并产生负面影响,所以移除了池化以进一步保留特征图。为了防止参数显著增加,在ResNet50中采用了具有通道打乱的组卷积。
2.2、优势
- 保留信息:通过3D排列和重新设计的子模块,
GAM
能够在通道和空间方面保留信息,避免了先前方法中由于信息减少和维度分离而导致的全局空间 - 通道交互的丢失。 - 放大交互:能够放大“全局”跨维度交互,捕获所有三个维度(通道、空间宽度和空间高度)上的重要特征,从而增强了跨维度的交互能力。
- 性能提升:在CIFAR - 100和ImageNet - 1K数据集上的评估表明,
GAM
稳定地优于其他几种近期的注意力机制,无论是在ResNet还是轻量级MobileNet上,都能提高性能。例如,在ImageNet - 1K数据集上,对于ResNet18,GAM以更少的参数和更高的效率优于ABN。
论文:https://arxiv.org/pdf/2112.05561v1
源码:https://github.com/dengbuqi/GAM_Pytorch/blob/main/CAM.py
三、GAM的实现代码
GAM模块
的实现代码如下:
class GAMAttention(nn.Module):
def __init__(self, c1, c2, group=True, rate=4):
super(GAMAttention, self).__init__()
self.channel_attention = nn.Sequential(
nn.Linear(c1, int(c1 / rate)),
nn.ReLU(inplace=True),
nn.Linear(int(c1 / rate), c1),
)
self.spatial_attention = nn.Sequential(
(
nn.Conv2d(c1, c1 // rate, kernel_size=7, padding=3, groups=rate)
if group
else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3)
),
nn.BatchNorm2d(int(c1 / rate)),
nn.ReLU(inplace=True),
(
nn.Conv2d(c1 // rate, c2, kernel_size=7, padding=3, groups=rate)
if group
else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3)
),
nn.BatchNorm2d(c2),
)
def forward(self, x):
b, c, h, w = x.shape
x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)
x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)
x_channel_att = x_att_permute.permute(0, 3, 1, 2)
x = x * x_channel_att
x_spatial_att = self.spatial_attention(x).sigmoid()
x_spatial_att = channel_shuffle(x_spatial_att, 4) # last shuffle
out = x * x_spatial_att
return out
def channel_shuffle(x, groups=2): ##shuffle channel
# RESHAPE----->transpose------->Flatten
B, C, H, W = x.size()
out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()
out = out.view(B, C, H, W)
return out
四、添加步骤
4.1 修改common.py
此处需要修改的文件是models/common.py
common.py中定义了网络结构的通用模块
,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。
4.1.1 基础模块1
模块改进方法1️⃣:直接加入GAMAttention模块
。
GAMAttention模块
添加后如下:
注意❗:在4.2小节
中的yolo.py
文件中需要声明的模块名称为:GAMAttention
。
4.1.2 创新模块2⭐
模块改进方法2️⃣:基于GAMAttention模块
的RepNCSPELAN4
。
相较方法一中的直接插入注意力模块,利用注意力模块对卷积等其他模块进行改进,其新颖程度会更高一些,训练精度可能会表现的更高。
第二种改进方法是对YOLOv9
中的RepNCSPELAN4模块
进行改进。RepNCSPELAN4模块
的创新思想是将CSP
与ELAN
相结合。CSP
可以有效地分割梯度流,减少计算量的同时保持准确性。ELAN
则通过灵活的层聚合方式,增强网络的学习能力。此处的改进方法是将GAMAttention注意力模块
替换RepNCSPELAN4
中的卷积模块,生成GAMRepNCSPELAN4模块
。GAM 模块能够捕捉通道、空间宽度和空间高度等多个维度的重要特征,加强了跨维度的交互,在将其添加到RepNCSPELAN4模块
中有助于在分流过程中更好地分配注意力,减少无关信息的干扰,提高特征质量。
改进代码如下:
class GAMRepNCSPELAN4(nn.Module):
# csp-elan
def __init__(self, c1, c2, c3, c4, c5=1): # ch_in, ch_out, number, shortcut, groups, expansion
super().__init__()
self.c = c3//2
self.cv1 = Conv(c1, c3, 1, 1)
self.cv2 = nn.Sequential(RepNCSP(c3//2, c4, c5), GAMAttention(c4, c4))
self.cv3 = nn.Sequential(RepNCSP(c4, c4, c5), GAMAttention(c4, c4))
self.cv4 = Conv(c3+(2*c4), c2, 1, 1)
def forward(self, x):
y = list(self.cv1(x).chunk(2, 1))
y.extend((m(y[-1])) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
def forward_split(self, x):
y = list(self.cv1(x).split((self.c, self.c), 1))
y.extend(m(y[-1]) for m in [self.cv2, self.cv3])
return self.cv4(torch.cat(y, 1))
注意❗:在4.2小节
中的yolo.py
文件中需要声明的模块名称为:GAMRepNCSPELAN4
。
4.2 修改yolo.py
此处需要修改的文件是models/yolo.py
yolo.py用于函数调用
,我们只需要将common.py
中定义的新的模块名添加到parse_model函数
下即可。
GAMAttention模块
以及GAMRepNCSPELAN4模块
添加后如下:
五、yaml模型文件
5.1 模型改进版本一
在代码配置完成后,配置模型的YAML文件。
此处以models/detect/yolov9-c.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-gam.yaml
。
将yolov9-c.yaml
中的内容复制到yolov9-c-gam.yaml
文件下,修改nc
数量等于自己数据中目标的数量。
在骨干网络的最后一层添加GAMAttention模块
,即下方代码中的第45行,只需要填入一个参数,通道数,和前一层通道数一致。