一、本文介绍
作为入门性篇章,这里介绍了EMA注意力在YOLOv8中的使用。包含EMA原理分析,EMA的代码、EMA的使用方法、以及添加以后的yaml文件及运行记录。
二、EMA原理分析
EMA官方论文地址:EMA文章
EMA代码:EMA代码
EMA注意力机制(高效的多尺度注意力):通过重塑部分通道到批次维度,并将通道维度分组为多个子特征,以保留每个通道的信息并减少计算开销。EMA模块通过编码全局信息来重新校准每个并行分支中的通道权重,并通过跨维度交互来捕获像素级别的关系。
相关代码:
EMA注意力的代码,如下:
class EMA_attention(nn.Module):
def __init__(self, channels, c2=None, factor=32):
super(EMA_attention, self).__init__()
self.groups = factor
assert channels // self.groups > 0
self.softmax = nn.Softmax(-1)
self.agp = nn.AdaptiveAvgPool2d((1, 1))
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))
self.gn = nn.GroupNorm(channels // self.groups, channels // self.groups)
self.conv1x1 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=1, stride=1, padding=0)
self.conv3x3 = nn.Conv2d(channels // self.groups, channels // self.groups, kernel_size=3, stride=1, padding=1)
def forward(self, x):
b, c, h, w = x.size()
group_x = x.reshape(b * self.groups, -1, h, w) # b*g,c//g,h,w
x_h = self.pool_h(group_x)
x_w = self.pool_w(group_x).permute(0, 1, 3, 2)
hw = self.conv1x1(torch.cat([x_h, x_w], dim=2))
x_h, x_w = torch.split(hw, [h, w], dim=2)
x1 = self.gn(group_x * x_h.sigmoid() * x_w.permute(0, 1, 3, 2).sigmoid())
x2 = self.conv3x3(group_x)
x11 = self.softmax(self.agp(x1).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x12 = x2.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
x21 = self.softmax(self.agp(x2).reshape(b * self.groups, -1, 1).permute(0, 2, 1))
x22 = x1.reshape(b * self.groups, c // self.groups, -1) # b*g, c//g, hw
weights = (torch.matmul(x11, x12) + torch.matmul(x21, x22)).reshape(b * self.groups, 1, h, w)
return (group_x * weights.sigmoid()).reshape(b, c, h, w)
四、YOLOv8中EMA使用方法
1.YOLOv8中添加EMA模块:
首先在ultralytics/nn/modules/conv.py最后添加EMA模块的代码。
2.在conv.py的开头__all__ = 内添加EMA模块的类别名EMA_attention:
3.在同级文件夹下的__init__.py内添加EMA的相关内容:(分别是from .conv import EMA_attention ;以及在__all__内添加EMA_attention)
4.在ultralytics/nn/tasks.py进行EMA_attention注意力机制的注册,以及在YOLOv8的yaml配置文件中添加EMA_attention即可。
首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面(或者直接加载含有c2f模块的注册代码内,注册方式一样)添加以下注册代码:
if m in (EMA_attention):
c1, c2 = ch[f], args[0]
然后,就是新建一个名为YOLOv8_EMA.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_EMA.yaml)其中参数中nc,由自己的数据集决定。本文测试,采用的coco8数据集,有80个类别。
# Ultralytics YOLO
标签:EMA,nn,self,attention,YOLOv8,groups,注意力
From: https://blog.csdn.net/2301_79619145/article/details/142734096