I. Introduction
As an introductory article, this post covers how to use the ShuffleAttention attention mechanism in YOLOv8. It includes an analysis of how ShuffleAttention works, the ShuffleAttention code, how to add and use it, the resulting yaml configuration file, and a training run record.
II. How ShuffleAttention Works
ShuffleAttention official paper: [link]
ShuffleAttention official code: [link]
The ShuffleAttention mechanism uses a Shuffle unit to combine two types of attention efficiently. It first splits the channel dimension into multiple sub-features and processes them in parallel. For each sub-feature, a Shuffle Unit captures feature dependencies along both the spatial and the channel dimension. All sub-features are then aggregated, and a "channel shuffle" operator enables information exchange between the different sub-features.
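The "channel shuffle" step above can be illustrated on a toy tensor. This is a minimal sketch of the operator in isolation (the standalone `channel_shuffle` function here mirrors the static method in the full module below): with 2 groups, channels `[0, 1, 2, 3]` are interleaved into `[0, 2, 1, 3]`, so information from the two sub-feature halves mixes.

```python
import torch

# Toy demonstration of the "channel shuffle" operator: reshape into
# (batch, groups, channels_per_group, H, W), swap the group and
# channel axes, then flatten back. With 2 groups, channels from the
# two halves end up interleaved.
def channel_shuffle(x, groups):
    b, c, h, w = x.shape
    x = x.reshape(b, groups, c // groups, h, w)
    x = x.permute(0, 2, 1, 3, 4)  # swap the group and channel axes
    return x.reshape(b, c, h, w)

x = torch.arange(4).float().view(1, 4, 1, 1)  # channels [0, 1, 2, 3]
y = channel_shuffle(x, 2)
print(y.flatten().tolist())  # [0.0, 2.0, 1.0, 3.0]
```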
III. The Code
The ShuffleAttention code is as follows.
import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)
        # flatten
        x = x.reshape(b, -1, h, w)
        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # group into subfeatures
        x = x.view(b * self.G, -1, h, w)  # bs*G, c//G, h, w
        # channel split
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G, c//(2*G), h, w
        # channel attention
        x_channel = self.avg_pool(x_0)  # bs*G, c//(2*G), 1, 1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G, c//(2*G), 1, 1
        x_channel = x_0 * self.sigmoid(x_channel)
        # spatial attention
        x_spatial = self.gn(x_1)  # bs*G, c//(2*G), h, w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G, c//(2*G), h, w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G, c//(2*G), h, w
        # concatenate along channel axis
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G, c//G, h, w
        out = out.contiguous().view(b, -1, h, w)
        # channel shuffle between the two branches
        out = self.channel_shuffle(out, 2)
        return out
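A quick sanity check shows that the module is shape-preserving, so it can be inserted into a YOLOv8 backbone without altering downstream channel counts. For the demo to be self-contained, the snippet restates the class above in condensed form; `channel=64` and `G=8` are illustrative values (`channel` must be divisible by `2*G`):

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

# Condensed restatement of the ShuffleAttention class above, kept
# self-contained so this shape check runs on its own.
class ShuffleAttention(nn.Module):
    def __init__(self, channel=512, reduction=16, G=8):
        super().__init__()
        self.G = G
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w).permute(0, 2, 1, 3, 4)
        return x.reshape(b, -1, h, w)

    def forward(self, x):
        b, c, h, w = x.size()
        x = x.view(b * self.G, -1, h, w)          # group into subfeatures
        x_0, x_1 = x.chunk(2, dim=1)              # split each group in two
        x_channel = x_0 * self.sigmoid(self.cweight * self.avg_pool(x_0) + self.cbias)
        x_spatial = x_1 * self.sigmoid(self.sweight * self.gn(x_1) + self.sbias)
        out = torch.cat([x_channel, x_spatial], dim=1).contiguous().view(b, -1, h, w)
        return self.channel_shuffle(out, 2)

# channel must be divisible by 2*G; 64 and G=8 are illustrative values.
sa = ShuffleAttention(channel=64, G=8)
x = torch.randn(2, 64, 32, 32)
print(sa(x).shape)  # torch.Size([2, 64, 32, 32]) -- same as the input
```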
IV. Using ShuffleAttention in YOLOv8
1. Add the ShuffleAttention module to YOLOv8:
First, append the ShuffleAttention code above to the end of ultralytics/nn/modules/conv.py.
2. Add the ShuffleAttention class name to the __all__ list at the top of conv.py.
3. In the __init__.py of the same folder, add the corresponding entries for ShuffleAttention (that is, `from .conv import ShuffleAttention`, and add ShuffleAttention to `__all__`).
4. Register the ShuffleAttention mechanism in ultralytics/nn/tasks.py, then add ShuffleAttention to a YOLOv8 yaml configuration file.
Open tasks.py, press Ctrl+F and search for parse_model to locate the parse_model function. Add the following registration code before its last else branch:
        elif m in {CBAM, ECA, ShuffleAttention}:  # attention modules; if your codebase has no CBAM or ECA, remove them from this set
            c1, c2 = ch[f], args[0]
            if c2 != nc:
                c2 = make_divisible(min(c2, max_channels) * width, 8)
            args = [c1, *args[1:]]
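The `make_divisible` call in the registration code snaps the width-scaled channel count to a multiple of 8. A minimal reimplementation of that behavior (assuming the ceil-based rounding commonly used in YOLO codebases, not copied from a specific ultralytics version):

```python
import math

# Minimal sketch of make_divisible: round x up to the nearest
# multiple of `divisor`. This is how the registration code keeps
# scaled channel counts aligned to 8.
def make_divisible(x, divisor):
    return math.ceil(x / divisor) * divisor

# e.g. a width-0.5 model requesting 100 channels:
print(make_divisible(100 * 0.5, 8))  # 56
```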
Then create a new configuration file named YOLOv8_ShuffleAttention.yaml (path: ultralytics/cfg/models/v8/YOLOv8_ShuffleAttention.yaml):
# Ultralytics YOLO
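The rest of the yaml file was not captured in this extract. As an illustrative sketch only (the insertion point and channel argument are assumptions, not the author's verified configuration), a ShuffleAttention layer could be placed after the SPPF stage of the stock yolov8.yaml backbone like this:

```yaml
# Illustrative fragment only -- not the author's verified file.
backbone:
  # [from, repeats, module, args]
  # ... preceding stock yolov8.yaml backbone layers ...
  - [-1, 1, SPPF, [1024, 5]]            # 9
  - [-1, 1, ShuffleAttention, [1024]]   # 10: shape-preserving attention
```

Note that inserting a layer shifts the indices of all subsequent layers, so any `from` indices in the head section of the stock yaml must be updated to match.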
From: https://blog.csdn.net/2301_79619145/article/details/142930590