首页 > 其他分享 >【YOLOv8改进】HAT(Hybrid Attention Transformer,)混合注意力机制 (论文笔记+引入代码)

【YOLOv8改进】HAT(Hybrid Attention Transformer,)混合注意力机制 (论文笔记+引入代码)

时间:2024-06-10 22:13:13浏览次数:22  
标签:dim Transformer self Attention Hybrid patch num embed size

YOLO目标检测创新改进与实战案例专栏

专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例

专栏链接: YOLO基础解析+创新改进+实战案例

摘要

基于Transformer的方法在低级视觉任务中表现出色,例如图像超分辨率。然而,通过归因分析,我们发现这些网络只能利用输入信息的有限空间范围。这表明Transformer在现有网络中的潜力尚未完全发挥。为了激活更多的输入像素以获得更好的重建效果,我们提出了一种新颖的混合注意力Transformer(Hybrid Attention Transformer, HAT)。它结合了通道注意力和基于窗口的自注意力机制,从而利用了它们能够利用全局统计信息和强大的局部拟合能力的互补优势。此外,为了更好地聚合跨窗口信息,我们引入了一个重叠交叉注意模块,以增强相邻窗口特征之间的交互。在训练阶段,我们还采用了同任务预训练策略,以进一步挖掘模型的潜力。大量实验表明了所提模块的有效性,我们进一步扩大了模型规模,证明了该任务的性能可以大幅提高。我们的方法整体上显著优于最先进的方法,超过了1dB。

创新点

  1. 更多像素的激活:通过结合不同的注意力机制,HAT能够激活更多的输入像素,这在图像超分辨率领域尤为重要,因为它直接关系到重建图像的细节和质量。

  2. 交叉窗口信息的有效聚合:通过重叠交叉注意力模块,HAT模型能够更有效地聚合跨窗口的信息,避免了传统Transformer模型中窗口间信息隔离的问题。

  3. 针对图像超分辨率优化的预训练策略:HAT采用的同任务预训练策略针对性强,能够更有效地利用大规模数据预训练的优势,提高模型在特定超分辨率任务上的表现。

HAT模型整体架构:

  1. 浅层特征提取(Shallow Feature Extraction)
    输入图像首先通过一个浅层特征提取模块,该模块使用卷积操作提取初始的低层次特征。

  2. 深层特征提取(Deep Feature Extraction)
    提取的特征输入到多个Residual Hybrid Attention Groups (RHAG)中进行深层特征提取。每个RHAG由多个Hybrid Attention Blocks (HAB)和一个Overlapping Cross-Attention Block (OCAB)组成。

  3. 图像重建(Image Reconstruction)
    最后,经过深层特征提取的特征通过一个图像重建模块,将高层次特征转化为输出的超分辨率图像。

yolov8 引入

class HAT(nn.Module):
   r"""混合注意力变换器 (Hybrid Attention Transformer)
       该PyTorch实现基于 `Activating More Pixels in Image Super-Resolution Transformer`。
       部分代码基于SwinIR。
   参数:
       img_size (int | tuple(int)): 输入图像大小。默认值64
       patch_size (int | tuple(int)): Patch大小。默认值1
       in_chans (int): 输入图像通道数。默认值3
       embed_dim (int): Patch嵌入维度。默认值96
       depths (tuple(int)): 每个Swin Transformer层的深度。
       num_heads (tuple(int)): 不同层的注意力头数量。
       window_size (int): 窗口大小。默认值7
       mlp_ratio (float): MLP隐藏层维度与嵌入维度的比例。默认值4
       qkv_bias (bool): 如果为True,为查询、键和值添加可学习的偏差。默认值True
       qk_scale (float): 如果设置,覆盖head_dim ** -0.5的默认qk缩放比例。默认值None
       drop_rate (float): Dropout率。默认值0
       attn_drop_rate (float): 注意力dropout率。默认值0
       drop_path_rate (float): 随机深度率。默认值0.1
       norm_layer (nn.Module): 规范化层。默认值nn.LayerNorm。
       ape (bool): 如果为True,为Patch嵌入添加绝对位置嵌入。默认值False
       patch_norm (bool): 如果为True,在Patch嵌入后添加规范化。默认值True
       use_checkpoint (bool): 是否使用checkpointing来节省内存。默认值False
       upscale: 上采样因子。用于图像SR的2/3/4/8,1用于降噪和压缩伪影减少
       img_range: 图像范围。1或255。
       upsampler: 重建模块。'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
       resi_connection: 残差连接前的卷积块。'1conv'/'3conv'
   """

   def __init__(self,
                in_chans=3,
                img_size=64,
                patch_size=1,
                embed_dim=96,
                depths=(6, 6, 6, 6),
                num_heads=(6, 6, 6, 6),
                window_size=7,
                compress_ratio=3,
                squeeze_factor=30,
                conv_scale=0.01,
                overlap_ratio=0.5,
                mlp_ratio=4.,
                qkv_bias=True,
                qk_scale=None,
                drop_rate=0.,
                attn_drop_rate=0.,
                drop_path_rate=0.1,
                norm_layer=nn.LayerNorm,
                ape=False,
                patch_norm=True,
                use_checkpoint=False,
                upscale=2,
                img_range=1.,
                upsampler='',
                resi_connection='1conv',
                **kwargs):
       super(HAT, self).__init__()

       self.window_size = window_size
       self.shift_size = window_size // 2
       self.overlap_ratio = overlap_ratio

       num_in_ch = in_chans
       num_out_ch = in_chans
       num_feat = 64
       self.img_range = img_range
       if in_chans == 3:
           rgb_mean = (0.4488, 0.4371, 0.4040)
           self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
       else:
           self.mean = torch.zeros(1, 1, 1, 1)
       self.upscale = upscale
       self.upsampler = upsampler

       # 相对位置索引
       relative_position_index_SA = self.calculate_rpi_sa()
       relative_position_index_OCA = self.calculate_rpi_oca()
       self.register_buffer('relative_position_index_SA', relative_position_index_SA)
       self.register_buffer('relative_position_index_OCA', relative_position_index_OCA)

       # ------------------------- 1,浅层特征提取 ------------------------- #
       self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)

       # ------------------------- 2,深层特征提取 ------------------------- #
       self.num_layers = len(depths)
       self.embed_dim = embed_dim
       self.ape = ape
       self.patch_norm = patch_norm
       self.num_features = embed_dim
       self.mlp_ratio = mlp_ratio

       # 将图像分割为非重叠patch
       self.patch_embed = PatchEmbed(
           img_size=img_size,
           patch_size=patch_size,
           in_chans=embed_dim,
           embed_dim=embed_dim,
           norm_layer=norm_layer if self.patch_norm else None)
       num_patches = self.patch_embed.num_patches
       patches_resolution = self.patch_embed.patches_resolution
       self.patches_resolution = patches_resolution

       # 将非重叠的patch合并为图像
       self.patch_unembed = PatchUnEmbed(
           img_size=img_size,
           patch_size=patch_size,
           in_chans=embed_dim,
           embed_dim=embed_dim,
           norm_layer=norm_layer if self.patch_norm else None)

       # 绝对位置嵌入
       if self.ape:
           self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
           trunc_normal_(self.absolute_pos_embed, std=.02)

       self.pos_drop = nn.Dropout(p=drop_rate)

       # 随机深度
       dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # 随机深度衰减规则

       # 构建残差混合注意力组 (RHAG)
       self.layers = nn.ModuleList()
       for i_layer in range(self.num_layers):
           layer = RHAG(
               dim=embed_dim,
               input_resolution=(patches_resolution[0], patches_resolution[1]),
               depth=depths[i_layer],
               num_heads=num_heads[i_layer],
               window_size=window_size,
               compress_ratio=compress_ratio,
               squeeze_factor=squeeze_factor,
               conv_scale=conv_scale,
               overlap_ratio=overlap_ratio,
               mlp_ratio=self.mlp_ratio,
               qkv_bias=qkv_bias,
               qk_scale=qk_scale,
               drop=drop_rate,
               attn_drop=attn_drop_rate,
               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # 对SR结果无影响
               norm_layer=norm_layer,
               downsample=None,
               use_checkpoint=use_checkpoint,
               img_size=img_size,
               patch_size=patch_size,
               resi_connection=resi_connection)
           self.layers.append(layer)
       self.norm = norm_layer(self.num_features)

       # 构建深层特征提取中的最后一个卷积层
       if resi_connection == '1conv':
           self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
       elif resi_connection == 'identity':
           self.conv_after_body = nn.Identity()

       # ------------------------- 3,高质量图像重建 ------------------------- #
       if self.upsampler == 'pixelshuffle':
           # 用于经典SR
           self.conv_before_upsample = nn.Sequential(
               nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
           self.upsample = Upsample(upscale, num_feat)
           self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

       self.apply(self._init_weights)

task与yaml配置

详见:https://blog.csdn.net/shangyanaf/article/details/139142532

标签:dim,Transformer,self,Attention,Hybrid,patch,num,embed,size
From: https://www.cnblogs.com/banxia-frontend/p/18241118

相关文章

  • 【YOLOv8改进】EMA(Efficient Multi-Scale Attention):基于跨空间学习的高效多尺度注意力
    YOLO目标检测创新改进与实战案例专栏专栏目录:YOLO有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLO基础解析+创新改进+实战案例摘要通道或空间注意力机制在许多计算机视觉任务中表现出显著的效果,可以......
  • 【YOLOv8改进】ACmix(Mixed Self-Attention and Convolution) (论文笔记+引入代码)
    YOLO目标检测创新改进与实战案例专栏专栏目录:YOLO有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLO基础解析+创新改进+实战案例摘要卷积和自注意力是两个强大的表示学习技术,通常被认为是彼此独立的两......
  • attention机制、LSTM二者之间,是否存在attention一定优于LSTM的关系呢?
    这里没有严格的论证,只是自己的一些理解。attention机制下的Transformer确实是当前AI技术中最为火热的,基于其构建的大语言模型可以说是AI技术至今最强的技术之一了,但是attention是否真的的一定优于LSTM呢?其实,attention的效果或者说Transformer的效果是和数据量的多少有关系的,如......
  • 深度学习面试问题总结 | Transformer面试问题总结(二)
    本文给大家带来的百面算法工程师是深度学习Transformer的面试总结,文章内总结了常见的提问问题,旨在为广大学子模拟出更贴合实际的面试问答场景。在这篇文章中,我们还将介绍一些常见的深度学习算法工程师面试问题,并提供参考的回答及其理论基础,以帮助求职者更好地准备面试。通过......
  • 【Pytorch】一文向您详细介绍 nn.MultiheadAttention() 的作用和用法
    【Pytorch】一文向您详细介绍nn.MultiheadAttention()的作用和用法 下滑查看解决方法......
  • 深入对比:Transformer 与 RNN 的详细解析
    在自然语言处理(NLP)和机器学习领域,模型的选择对任务的成败至关重要。Transformer和RNN(递归神经网络)是两种流行但截然不同的模型架构。本文将深入探讨这两种架构的特点、优势、劣势,并通过实际案例进行比较。1.RNN(递归神经网络)1.1RNN简介RNN是一种处理序列数据的神经......
  • 【transformer】安装
    pipinstalltransformers==4.28.1下载pyyaml>=5.1时候遇到网络问题下载不下来pippyyaml==5.3-ihttps://pypi.tuna.tsinghua.edu.cn/simplepipinstalltransformers==4.28.1-ihttps://pypi.tuna.tsinghua.edu.cn/simple ---hugging-face下载bert库exportHF_ENDPOINT......
  • 【YOLOv8改进】CPCA(Channel prior convolutional attention)中的通道注意力,增强特征
    YOLO目标检测创新改进与实战案例专栏专栏目录:YOLO有效改进系列及项目实战目录包含卷积,主干注意力,检测头等创新机制以及各种目标检测分割项目实战案例专栏链接:YOLO基础解析+创新改进+实战案例摘要医学图像通常展示出低对比度和显著的器官形状变化等特征。现有注意......
  • Block Transformer:通过全局到局部的语言建模加速LLM推理
    在基于transformer的自回归语言模型(LMs)中,生成令牌的成本很高,这是因为自注意力机制需要关注所有之前的令牌,通常通过在自回归解码过程中缓存所有令牌的键值(KV)状态来解决这个问题。但是,加载所有先前令牌的KV状态以计算自注意力分数则占据了LMs的推理的大部分成本。在这篇论文中,作者......
  • Pyramid Vision Transformer, PVT(ICCV 2021)原理与代码解读
    paper:PyramidVisionTransformer:AVersatileBackboneforDensePredictionwithoutConvolutionsofficialimplementation:GitHub-whai362/PVT:OfficialimplementationofPVTseries存在的问题现有的VisionTransformer(ViT)主要设计用于图像分类任务,难以直接用......