一、本文介绍
本文记录的是基于AIFI模块的YOLOv9目标检测改进方法研究。AIFI
是RT-DETR
中高效混合编码器的一部分,利用其改进YOLOv9
模型,使网络在深层能够更好的捕捉到概念实体之间的联系,并有助于后续模块对对象进行定位和识别。
文章目录
二、AIFI设计原理
RT-DETR
模型结构:
AIFI(Attention-based Intra-scale Feature Interaction)
模块的相关信息如下:
2.1、设计原理
AIFI
是RT-DETR
中高效混合编码器的一部分。为了克服多尺度Transformer编码器
中存在的计算瓶颈,RT-DETR
对编码器结构进行了重新思考。
由于从低级特征中提取出的高级特征包含了关于对象的丰富语义信息,对级联的多尺度特征进行特征交互是冗余的。因此,AIFI
基于此设计,通过使用单尺度Transformer
编码器仅在S5特征层
上进行尺度内交互,进一步降低了计算成本。
对高级特征应用自注意力操作,能够捕捉到概念实体之间的联系,这有助于后续模块对对象进行定位和识别。而低级特征由于缺乏语义概念,且与高级特征交互存在重复和混淆的风险,因此其尺度内交互是不必要的。
2.2、优势
与基准模型相比,AIFI不仅显著降低了延迟(快35%),而且提高了准确性(AP高0.4%)。
论文:https://arxiv.org/abs/2304.08069
源码:https://github.com/lyuwenyu/RT-DETR
三、AIFI模块的实现代码
AIFI模块
的实现代码如下:
class TransformerEncoderLayer(nn.Module):
"""Defines a single layer of the transformer encoder."""
def __init__(self, c1, cm=2048, num_heads=8, dropout=0.0, act=nn.GELU(), normalize_before=False):
"""Initialize the TransformerEncoderLayer with specified parameters."""
super().__init__()
from ...utils.torch_utils import TORCH_1_9
if not TORCH_1_9:
raise ModuleNotFoundError(
"TransformerEncoderLayer() requires torch>=1.9 to use nn.MultiheadAttention(batch_first=True)."
)
self.ma = nn.MultiheadAttention(c1, num_heads, dropout=dropout, batch_first=True)
# Implementation of Feedforward model
self.fc1 = nn.Linear(c1, cm)
self.fc2 = nn.Linear(cm, c1)
self.norm1 = nn.LayerNorm(c1)
self.norm2 = nn.LayerNorm(c1)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.act = act
self.normalize_before = normalize_before
@staticmethod
def with_pos_embed(tensor, pos=None):
"""Add position embeddings to the tensor if provided."""
return tensor if pos is None else tensor + pos
def forward_post(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
"""Performs forward pass with post-normalization."""
q = k = self.with_pos_embed(src, pos)
src2 = self.ma(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.fc2(self.dropout(self.act(self.fc1(src))))
src = src + self.dropout2(src2)
return self.norm2(src)
def forward_pre(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
"""Performs forward pass with pre-normalization."""
src2 = self.norm1(src)
q = k = self.with_pos_embed(src2, pos)
src2 = self.ma(q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src2 = self.norm2(src)
src2 = self.fc2(self.dropout(self.act(self.fc1(src2))))
return src + self.dropout2(src2)
def forward(self, src, src_mask=None, src_key_padding_mask=None, pos=None):
"""Forward propagates the input through the encoder module."""
if self.normalize_before:
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
class AIFI(TransformerEncoderLayer):
"""Defines the AIFI transformer layer."""
def __init__(self, c1, cm=2048, num_heads=8, dropout=0, act=nn.GELU(), normalize_before=False):
"""Initialize the AIFI instance with specified parameters."""
super().__init__(c1, cm, num_heads, dropout, act, normalize_before)
def forward(self, x):
"""Forward pass for the AIFI transformer layer."""
c, h, w = x.shape[1:]
pos_embed = self.build_2d_sincos_position_embedding(w, h, c)
# Flatten [B, C, H, W] to [B, HxW, C]
x = super().forward(x.flatten(2).permute(0, 2, 1), pos=pos_embed.to(device=x.device, dtype=x.dtype))
return x.permute(0, 2, 1).view([-1, c, h, w]).contiguous()
@staticmethod
def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
"""Builds 2D sine-cosine position embedding."""
assert embed_dim % 4 == 0, "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
grid_w = torch.arange(w, dtype=torch.float32)
grid_h = torch.arange(h, dtype=torch.float32)
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
pos_dim = embed_dim // 4
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
omega = 1.0 / (temperature**omega)
out_w = grid_w.flatten()[..., None] @ omega[None]
out_h = grid_h.flatten()[..., None] @ omega[None]
return torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], 1)[None]
四、添加步骤
4.1 修改common.py
此处需要修改的文件是models/common.py
common.py中定义了网络结构的通用模块
,我们想要加入新的模块就只需要将模块代码放到这个文件内即可。
此时需要将上方实现的代码添加到common.py
中。
注意❗:在4.2小节
中的yolo.py
文件中需要声明的模块名称为:AIFI
。
4.2 修改yolo.py
此处需要修改的文件是models/yolo.py
yolo.py用于函数调用
,我们只需要将common.py
中定义的新的模块名添加到parse_model函数
下即可。
AIFI
模块添加后如下:
还需在此函数下添加如下代码:
elif m is AIFI:
args = [ch[f], *args]
五、yaml模型文件
5.1 模型改进⭐
在代码配置完成后,配置模型的YAML文件。
此处以models/detect/yolov9-c.yaml
为例,在同目录下创建一个用于自己数据集训练的模型文件yolov9-c-AIFI.yaml
。
将yolov9-c.yaml
中的内容复制到yolov9-c-AIFI.yaml
文件下,修改nc
数量等于自己数据中目标的数量。