一、本文介绍
作为入门性篇章,这里介绍了A2-Nets网络注意力在YOLOv8中的使用。包含A2-Nets原理分析,A2-Nets的代码、A2-Nets的使用方法、以及添加以后的yaml文件及运行记录。
二、A2-Nets原理分析
A2-Nets官方论文地址:A2-Nets文章
A2-Nets注意力机制(双重注意力机制):它从输入图像/视频的整个时空空间中聚集和传播信息全局特征,使后续卷积层能够有效地从整个空间中访问特征。采用双注意机制(包括Spatial Attention和Channel Attention。Spatial Attention用于捕获图像中不同空间位置的重要性,而Channel Attention用于捕获图像中不同通道的重要性),分两步进行设计,第一步通过二阶注意池将整个空间的特征聚集成一个紧凑的集合,第二步通过另一个注意自适应地选择特征并将其分配到每个位置。
相关代码:
A2-Nets注意力的代码,如下。
import numpy as np
import torch
from torch import nn
from torch.nn import init
from torch.nn import functional as F
class DoubleAttention(nn.Module):
def __init__(self, in_channels, c_m=128, c_n=128, reconstruct=True):
super().__init__()
self.in_channels = in_channels
self.reconstruct = reconstruct
self.c_m = c_m
self.c_n = c_n
self.convA = nn.Conv2d(in_channels, c_m, 1)
self.convB = nn.Conv2d(in_channels, c_n, 1)
self.convV = nn.Conv2d(in_channels, c_n, 1)
if self.reconstruct:
self.conv_reconstruct = nn.Conv2d(c_m, in_channels, kernel_size=1)
self.init_weights()
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
init.constant_(m.weight, 1)
init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
init.normal_(m.weight, std=0.001)
if m.bias is not None:
init.constant_(m.bias, 0)
def forward(self, x):
b, c, h, w = x.shape
assert c == self.in_channels
A = self.convA(x) # b,c_m,h,w
B = self.convB(x) # b,c_n,h,w
V = self.convV(x) # b,c_n,h,w
tmpA = A.view(b, self.c_m, -1)
attention_maps = F.softmax(B.view(b, self.c_n, -1))
attention_vectors = F.softmax(V.view(b, self.c_n, -1), dim=-1)
# step 1: feature gating
global_descriptors = torch.bmm(tmpA, attention_maps.permute(0, 2, 1)) # b.c_m,c_n
# step 2: feature distribution
tmpZ = global_descriptors.matmul(attention_vectors) # b,c_m,h*w
tmpZ = tmpZ.view(b, self.c_m, h, w) # b,c_m,h,w
if self.reconstruct:
tmpZ = self.conv_reconstruct(tmpZ)
return tmpZ
四、YOLOv8中DoubleAttention使用方法
1.YOLOv8中添加DoubleAttention模块:
首先在ultralytics/nn/modules/conv.py最后添加DoubleAttention模块的代码。
2.在conv.py的开头__all__ = 内添加DoubleAttention模块的类别名(A2-Nets的类别名在本文中为DoubleAttention)
3.在同级文件夹下的__init__.py内添加A2-Nets的相关内容:(分别是from .conv import DoubleAttention ;以及在__all__内添加DoubleAttention)
4.在ultralytics/nn/tasks.py进行SK注意力机制的注册,以及在YOLOv8的yaml配置文件中添加DoubleAttention即可。
首先打开task.py文件,按住Ctrl+F,输入parse_model进行搜索。找到parse_model函数。在其最后一个else前面添加以下注册代码:(本文续接上篇文章,加在了CBAM、ECA、SKAttention的位置)
elif m in {CBAM,ECA,SKAttention,DoubleAttention}:#添加注意力模块,没有CBAM、ECA、SKAttention的,将CBAM、ECA、SKAttention删除即可
c1, c2 = ch[f], args[0]
if c2 != nc:
c2 = make_divisible(min(c2, max_channels) * width, 8)
args = [c1, *args[1:]]
然后,就是新建一个名为YOLOv8_DoubleAttention.yaml的配置文件:(路径:ultralytics/cfg/models/v8/YOLOv8_DoubleAttention.yaml)
# Ultralytics YOLO
标签:DoubleAttention,nn,Double,self,Networks,init,A2,Nets,注意力
From: https://blog.csdn.net/2301_79619145/article/details/142601192