首页 > 其他分享 >【YOLOv8改进】DAT(Deformable Attention):可变性注意力 (论文笔记+引入代码)

【YOLOv8改进】DAT(Deformable Attention):可变性注意力 (论文笔记+引入代码)

时间:2024-06-06 20:55:51浏览次数:32  
标签:torch Deformable self Attention groups offset attn 可变性 ref

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

Transformers最近在各种视觉任务中展现出了优越的性能。较大甚至是全局的感受野赋予了Transformer模型比其卷积神经网络(CNN)对手更强的表征能力。然而,简单地扩大感受野也带来了几个问题。一方面,使用密集注意力(例如在ViT中)会导致过高的内存和计算成本,并且特征可能会受到兴趣区域之外的无关部分的影响。另一方面,PVT或Swin Transformer采用的稀疏注意力对数据不敏感,可能限制了建模长距离关系的能力。为了解决这些问题,我们提出了一种新型的可变形自注意力模块,其中在自注意力中键和值对的位置是以数据为基础选择的。这种灵活的方案使自注意力模块能够聚焦于相关区域并捕捉更多信息特征。在此基础上,我们提出了Deformable Attention Transformer,这是一种用于图像分类和密集预测任务的通用主干模型,具有可变形注意力。广泛的实验表明,我们的模型在综合基准测试中实现了持续改进的结果。代码可在https://github.com/LeapLabTHU/DAT获取。

基本原理

关键

  1. 数据依赖的位置选择:Deformable Attention允许在自注意力机制中以数据依赖的方式选择键和值对的位置,使模型能够根据输入数据动态调整注意力的焦点。
  2. 灵活的偏移学习:通过学习偏移量,Deformable Attention可以将关键点和值移动到重要区域,从而提高模型对关键特征的捕获能力。
  3. 全局键共享:Deformable Attention学习一组全局键,这些键在不同的视觉标记之间共享,有助于模型捕获长距离的相关性。
  4. 空间自适应机制:Deformable Attention可以根据输入数据的特征动态调整注意力模式,从而适应不同的视觉任务和场景。

通过相对于Swin-Transformer和PVT的改进,加入了可变形机制,同时控制网络不增加太多的计算量。作者认为,缩小q对应的k的范围,能够减少无关信息的干扰,增强信息的捕捉,于是引入了DCN机制到注意力模块中,提出了一种新的注意力模块:可变形多头注意力模块。该模块通过对k和v进行DCN偏移后再计算注意力,从而提升了性能。

在可变形多头注意力模块中,输入特征图像 $x \in \mathbb{R}^{H \times W \times C}$ 生成一个参考网格,其中参考点 $p \in \mathbb{R}^{H_G \times W_G \times 2}$。该网格是从输入特征图 $x$ 降采样而来,降采样系数为 $r$, $H_G = H / r, W_G = W / r$。参考点的值代表的是坐标值 $(0, 0), \ldots, (H_G - 1, W_G - 1)$,再归一化到 $[-1, +1]$。

输入特征图像 $x$ 通过线性投影得到 $q = x W_q$,再输入到一个轻量级子网络offset network,生成偏移量 $\Delta p = \theta_{\text{offset}}(q)$。为了稳定训练过程,使用了一些预定义的因子来衡量 $\Delta p$ 的振幅,以防止太大的offset,即 $\Delta p \leftarrow \text{sinh}(\Delta p)$。

然后将获得的offset作用在参考点上,获得变形点的位置,进行特征采样(双线性插值)得到 $\hat{x}$,再通过投影矩阵生成Key和Value, $\hat{k} = \hat{x} W_k, \hat{v} = \hat{x} W_v$。

$qkv$进行多头注意力计算,同时加入相对位置偏移嵌入。最后将获得的多头特征拼接起来,通过投影矩阵获得最终的注意力模块输出 $Z$。

yolov8 引入

class DAttentionBaseline(nn.Module):

   def __init__(
       self, q_size, kv_size, n_heads, n_head_channels, n_groups,
       attn_drop, proj_drop, stride, 
       offset_range_factor, use_pe, dwc_pe,
       no_off, fixed_pe, ksize, log_cpb
   ):
       # 初始化函数,定义了所需的参数
       super().__init__()
       self.dwc_pe = dwc_pe  # 是否使用深度卷积位置编码
       self.n_head_channels = n_head_channels  # 每个头的通道数
       self.scale = self.n_head_channels ** -0.5  # 缩放因子,等于每个头的通道数的负0.5次方
       self.n_heads = n_heads  # 多头注意力机制中的头数
       self.q_h, self.q_w = q_size  # query的高和宽
       self.kv_h, self.kv_w = self.q_h // stride, self.q_w // stride  # 计算键值对的高和宽
       self.nc = n_head_channels * n_heads  # 总的通道数
       self.n_groups = n_groups  # 分组数
       self.n_group_channels = self.nc // self.n_groups  # 每组的通道数
       self.n_group_heads = self.n_heads // self.n_groups  # 每组的头数
       self.use_pe = use_pe  # 是否使用位置编码
       self.fixed_pe = fixed_pe  # 是否使用固定的位置编码
       self.no_off = no_off  # 是否禁用偏移
       self.offset_range_factor = offset_range_factor  # 偏移范围因子
       self.ksize = ksize  # 卷积核尺寸
       self.log_cpb = log_cpb  # 是否使用对数相对位置偏置
       self.stride = stride  # 步幅
       kk = self.ksize
       pad_size = kk // 2 if kk != stride else 0  # 计算填充大小

       # 定义卷积偏移网络
       self.conv_offset = nn.Sequential(
           nn.Conv2d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
           LayerNormProxy(self.n_group_channels),  # 使用LayerNormProxy进行归一化
           nn.GELU(),  # 使用GELU激活函数
           nn.Conv2d(self.n_group_channels, 2, 1, 1, 0, bias=False)  # 输出偏移量
       )
       if self.no_off:
           for m in self.conv_offset.parameters():
               m.requires_grad_(False)  # 如果不使用偏移,禁用偏移网络的参数更新

       # 定义投影层
       self.proj_q = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # query投影
       )

       self.proj_k = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # key投影
       )

       self.proj_v = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # value投影
       )

       self.proj_out = nn.Conv2d(
           self.nc, self.nc,
           kernel_size=1, stride=1, padding=0  # 输出投影
       )

       self.proj_drop = nn.Dropout(proj_drop, inplace=True)  # 投影层的Dropout
       self.attn_drop = nn.Dropout(attn_drop, inplace=True)  # 注意力层的Dropout

       # 相对位置嵌入的定义
       if self.use_pe and not self.no_off:
           if self.dwc_pe:
               self.rpe_table = nn.Conv2d(
                   self.nc, self.nc, kernel_size=3, stride=1, padding=1, groups=self.nc)  # 深度卷积位置编码
           elif self.fixed_pe:
               self.rpe_table = nn.Parameter(
                   torch.zeros(self.n_heads, self.q_h * self.q_w, self.kv_h * self.kv_w)
               )
               trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
           elif self.log_cpb:
               # 借用自Swin-V2
               self.rpe_table = nn.Sequential(
                   nn.Linear(2, 32, bias=True),
                   nn.ReLU(inplace=True),
                   nn.Linear(32, self.n_group_heads, bias=False)
               )
           else:
               self.rpe_table = nn.Parameter(
                   torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
               )
               trunc_normal_(self.rpe_table, std=0.01)  # 截断正态分布初始化
       else:
           self.rpe_table = None

   @torch.no_grad()
   def _get_ref_points(self, H_key, W_key, B, dtype, device):
       # 获取参考点
       ref_y, ref_x = torch.meshgrid(
           torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
           torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
           indexing='ij'  # 保持矩阵索引一致
       )
       ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
       ref[..., 1].div_(W_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref[..., 0].div_(H_key - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

       return ref
   
   @torch.no_grad()
   def _get_q_grid(self, H, W, B, dtype, device):
       # 获取query网格
       ref_y, ref_x = torch.meshgrid(
           torch.arange(0, H, dtype=dtype, device=device),
           torch.arange(0, W, dtype=dtype, device=device),
           indexing='ij'  # 保持矩阵索引一致
       )
       ref = torch.stack((ref_y, ref_x), -1)  # 堆叠y和x坐标
       ref[..., 1].div_(W - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref[..., 0].div_(H - 1.0).mul_(2.0).sub_(1.0)  # 归一化到[-1, 1]范围
       ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1)  # 扩展维度,适应批量和分组

       return ref

   def forward(self, x):
       # 前向传播函数
       B, C, H, W = x.size()  # 获取输入的尺寸
       dtype, device = x.dtype, x.device

       q = self.proj_q(x)  # 对输入x进行query投影
       q_off = einops.rearrange(q, 'b (g c) h w -> (b g) c h w', g=self.n_groups, c=self.n_group_channels)  # 重排列tensor的维度
       offset = self.conv_offset(q_off).contiguous()  # 计算偏移量
       Hk, Wk = offset.size(2), offset.size(3)  # 获取偏移量的高和宽
       n_sample = Hk * Wk  # 计算采样点数量

       if self.offset_range_factor >= 0 and not self.no_off:
           offset_range = torch.tensor([1.0 / (Hk - 1.0), 1.0 / (Wk - 1.0)], device=device).reshape(1, 2, 1, 1)
           offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

       offset = einops.rearrange(offset, 'b p h w -> b h w p')
       reference = self._get_ref_points(Hk, Wk, B, dtype, device)

       if self.no_off:
           offset = offset.fill_(0.0)

       if self.offset_range_factor >= 0:
           pos = offset + reference
       else:
           pos = (offset + reference).clamp(-1., +1.)

       if self.no_off:
           x_sampled = F.avg_pool2d(x, kernel_size=self.stride, stride=self.stride)
           assert x_sampled.size(2) == Hk and x_sampled.size(3) == Wk, f"Size is {x_sampled.size()}"
       else:
           x_sampled = F.grid_sample(
               input=x.reshape(B * self.n_groups, self.n_group_channels, H, W), 
               grid=pos[..., (1, 0)],  # y, x -> x, y
               mode='bilinear', align_corners=True)  # 进行双线性插值采样

       x_sampled = x_sampled.reshape(B, C, 1, n_sample)

       q = q.reshape(B * self.n_heads, self.n_head_channels, H * W)
       k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
       v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

       attn = torch.einsum('b c m, b c n -> b m n', q, k)  # 计算注意力权重
       attn = attn.mul(self.scale)

       if self.use_pe and (not self.no_off):
           if self.dwc_pe:
               residual_lepe = self.rpe_table(q.reshape(B, C, H, W)).reshape(B * self.n_heads, self.n_head_channels, H * W)
           elif self.fixed_pe:
               rpe_table = self.rpe_table
               attn_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
               attn = attn + attn_bias.reshape(B * self.n_heads, H * W, n_sample)
           elif self.log_cpb:
               q_grid = self._get_q_grid(H, W, B, dtype, device)
               displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(4.0)  # 计算位移
               displacement = torch.sign(displacement) * torch.log2(torch.abs(displacement) + 1.0) / np.log2(8.0)
               attn_bias = self.rpe_table(displacement)  # 计算相对位置嵌入偏置
               attn = attn + einops.rearrange(attn_bias, 'b m n h -> (b h) m n', h=self.n_group_heads)
           else:
               rpe_table = self.rpe_table
               rpe_bias = rpe_table[None, ...].expand(B, -1, -1, -1)
               q_grid = self._get_q_grid(H, W, B, dtype, device)
               displacement = (q_grid.reshape(B * self.n_groups, H * W, 2).unsqueeze(2) - pos.reshape(B * self.n_groups, n_sample, 2).unsqueeze(1)).mul(0.5)
               attn_bias = F.grid_sample(
                   input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                   grid=displacement[..., (1, 0)],
                   mode='bilinear', align_corners=True)  # 双线性插值计算相对位置偏置

               attn_bias = attn_bias.reshape(B * self.n_heads, H * W, n_sample)
               attn = attn + attn_bias

       attn = F.softmax(attn, dim=2)  # 对注意力权重进行softmax
       attn = self.attn_drop(attn)

       out = torch.einsum('b m n, b c n -> b c m', attn, v)  # 计算注意力输出

       if self.use_pe and self.dwc_pe:
           out = out + residual_lepe
       out = out.reshape(B, C, H, W)

       y = self.proj_drop(self.proj_out(out))  # 投影输出并进行Dropout

       return y, pos.reshape(B, self.n_groups, Hk, Wk, 2), reference.reshape(B, self.n_groups, Hk, Wk, 2)

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139193465

标签:torch,Deformable,self,Attention,groups,offset,attn,可变性,ref
From: https://www.cnblogs.com/banxia-frontend/p/18235996

相关文章

  • 【YOLOv8改进】D-LKA Attention:可变形大核注意力 (论文笔记+引入代码)
    YOLO目标检测创新改进与实战案例专栏专栏目录:YOLO有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLO基础解析+创新改进+实战案例摘要医学图像分割在Transformer模型的应用下取得了显著进步,这些模型在捕......
  • 基于双向长短时记忆神经网络结合多头注意力机制BiLSTM-Mutilhead-Attention实现柴油机
    %加载数据集和标签load(‘diesel_dataset.mat’);%假设数据集存储在diesel_dataset.mat文件中data=diesel_dataset.data;labels=diesel_dataset.labels;%数据预处理%这里假设你已经完成了数据的预处理,包括特征提取、归一化等步骤%划分训练集和测试集[tra......
  • 【YOLOv8改进】D-LKA Attention:可变形大核注意力 (论文笔记+引入代码)
    摘要医学图像分割在Transformer模型的应用下取得了显著进步,这些模型在捕捉远距离上下文和全局语境信息方面表现出色。然而,这些模型的计算需求随着token数量的平方增加,限制了其深度和分辨率能力。大多数现有方法以逐片处理三维体积图像数据(称为伪3D),这忽略了重要的片间信息,从而降低......
  • 【CNN分类】基于马尔可夫转移场卷积网络多头注意力机制 MTF-CNN-Mutilhead-Attention
    马尔可夫转移场卷积神经网络是在传统卷积神经网络的基础上,引入了马尔可夫随机场模型来捕捉特征之间的相关性。这种方法能够更好地提取特征并增强模型的学习能力。而多头注意力机制则可以进一步增强模型对关键特征的关注,提高故障诊断的准确性。下面是一个基于MATLAB的MTF-......
  • 基于 MATLAB 的麻雀算法 (SSA) 优化注意力机制卷积神经网络结合门控循环单元 (SSA-Att
    鱼弦:公众号【红尘灯塔】,CSDN博客专家、内容合伙人、新星导师、全栈领域优质创作者、51CTO(Top红人+专家博主)、github开源爱好者(go-zero源码二次开发、游戏后端架构https://github.com/Peakchen)基于MATLAB的麻雀算法(SSA)优化注意力机制卷积神经网络结合门控循环单元......
  • 基于GWO灰狼优化的CNN-GRU-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览优化前     优化后     2.算法运行软件版本matlab2022a 3.算法理论概述      时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNeuralNet......
  • 基于附带Attention机制的seq2seq模型架构实现英译法的案例
    模型架构先上图我们这里选用GRU来实现该任务,因此上图的十个方框框都是GRU块,如第二张图,放第一张图主要是强调编码器的输出是作用在解码器每一次输入的观点,具体的详细流程图将在代码实现部分给出。编码阶段1.准备工作要用到的数据集点此下载,备用地址,点击下载导入相关的......
  • 关于attention中对padding的处理:mask
    先问了下chatgpt:我正在学习torch.nn.multiheadattention,请告诉我att_mask和key_padding_mask这两个参数有什么不同,分别用于处理什么问题,以及输出有什么不同,并给出代码示例chatgpt的回答:torch.nn.MultiheadAttention中的attn_mask和key_padding_mask是两个非常重要的参数,......
  • 基于GWO灰狼优化的CNN-LSTM-Attention的时间序列回归预测matlab仿真
    1.算法运行效果图预览优化前    优化后     2.算法运行软件版本matlab2022a  3.算法理论概述       时间序列回归预测是数据分析的重要领域,旨在根据历史数据预测未来时刻的数值。近年来,深度学习模型如卷积神经网络(ConvolutionalNeuralN......
  • 探索大语言模型:理解Self Attention
    一、背景知识在ChatGPT引发全球关注之后,学习和运用大型语言模型迅速成为了热门趋势。作为程序员,我们不仅要理解其表象,更要探究其背后的原理。究竟是什么使得ChatGPT能够实现如此卓越的问答性能?自注意力机制的巧妙融入无疑是关键因素之一。那么,自注意力机制究竟是什么,它是如何创造......