首页 > 其他分享 >(即插即用模块-Attention部分) 十七、(CVPR 2022) HiLo Attention

(即插即用模块-Attention部分) 十七、(CVPR 2022) HiLo Attention

时间:2024-11-24 14:33:42浏览次数:9  
标签:dim heads self Attention CVPR ws 2022 attn Fi

在这里插入图片描述

文章目录

paper:Fast Vision Transformers with HiLo Attention

Code:https://github.com/ziplab/LITv2


1、HiLo Attention

论文中指出 多头自注意力(MSA) 在高分辨率图像上存在巨大的计算开销。为解决这一问题,本文引入一种 HiLo Attention 来提高速度和准确性。HiLo Attention 通过将注意力层分为高频和低频两部分,分别捕捉图像中的局部细节和全局结构。HiLo的动机在于自然图像包含丰富的频率,其中高、低频率在编码图像图案中分别代表 局部精细细节 和 全局结构。核心思想则是将特征图中的高频和低频信息进行解耦,再分别使用不同的注意力机制进行处理,从而提高视觉 Transformer 在高分辨率图像上的效率。

HiLo Attention 通过将 MSA 分成两条路径,其中一条路径通过 局部自注意力 利用相对高分辨率的特征图来编码高频交互,而另一条路径通过 全局注意力 利用下采样特征图来编码低频交互,这导致了效率的极大提高。对于一个输入特征而言:

  1. Head Splitting:首先根据设定的比例 α 将 MSA 层的头分为两组,一组用于Hi-Fi,另一组用于Lo-Fi。
  2. High Frequency Attention (Hi-Fi):在上方的路径中,通过将一组头部分配给高频注意力(Hi-Fi),再通过局部窗口自注意来捕获细粒度的高频(例如,2 × 2窗口)。Hi-Fi 专注于图像的局部细节,适用于处理高分辨率特征图。
  3. Low Frequency Attention (Lo-Fi):而在下方的路径则是用来实现低频注意力(Lo-Fi),首先对每个窗口应用平均汇集以获得低频信号。然后分配给 Lo-Fi 剩余头部,以建模输入特征映射中的每个查询位置与来自每个窗口的平均池低频键和值之间的关系。Lo-Fi关注于图像的全局结构,适用于处理下采样后的低分辨率特征图。
  4. 输出:最后将细化后的 Hi-Fi 与 Lo-Fi 的结果连接起来。这种设计不仅提高了效率,还通过减少键(keys)和值(values)的长度,实现了显著的复杂度降低。

HiLo Attention 结构图:
在这里插入图片描述

2、LIT v2

在 HiLo Attention 的基础上,论文提出了一种新的 ViTs 架构 LITv2。其整体结构与LIT v1基本类似,不同之处在于LIT v2通过使用 3x3 深度可分离卷积层代替了原来的相对位置编码,将位置信息隐式地学习到零填充中,从而提高速度并扩大早期 MLP 块的感受野,有效地提高了视觉 Transformer 的效率和性能。


LIT v2 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
from einops.einops import rearrange


class HiLo(nn.Module):
    """
    这个模块 要求输入 的维度是 [B, N, C]  N=H*W
    所以,对于 [B, C, H, W]的张量,需要先转换 维度 ,再进行处理
    转换维度:
    from einops.einops import rearrange
    [B, C, H, W]->[B, H*W, C] : rearrange(x, 'b c h w -> b (h w) c')
    [B, H*W, C]->[B, C, H, W] : rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2,
                 alpha=0.5):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        head_dim = int(dim / num_heads)
        self.dim = dim

        # self-attention heads in Lo-Fi
        self.l_heads = int(num_heads * alpha)
        # token dimension in Lo-Fi
        self.l_dim = self.l_heads * head_dim

        # self-attention heads in Hi-Fi
        self.h_heads = num_heads - self.l_heads
        # token dimension in Hi-Fi
        self.h_dim = self.h_heads * head_dim

        # local window size. The `s` in our paper.
        self.ws = window_size

        if self.ws == 1:
            # ws == 1 is equal to a standard multi-head self-attention
            self.h_heads = 0
            self.h_dim = 0
            self.l_heads = num_heads
            self.l_dim = dim

        self.scale = qk_scale or head_dim ** -0.5

        # Low frequence attention (Lo-Fi)
        if self.l_heads > 0:
            if self.ws != 1:
                self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
            self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
            self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
            self.l_proj = nn.Linear(self.l_dim, self.l_dim)

        # High frequence attention (Hi-Fi)
        if self.h_heads > 0:
            self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
            self.h_proj = nn.Linear(self.h_dim, self.h_dim)

    def hifi(self, x):
        B, H, W, C = x.shape
        h_group, w_group = H // self.ws, W // self.ws

        total_groups = h_group * w_group

        x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)

        qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1,
                                                                                                              4, 2, 5)
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, hw, n_head, ws*ws, head_dim

        attn = (q @ k.transpose(-2, -1)) * self.scale  # B, hw, n_head, ws*ws, ws*ws
        attn = attn.softmax(dim=-1)
        attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
        x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)

        x = self.h_proj(x)
        return x

    def lofi(self, x):
        B, H, W, C = x.shape

        q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)

        if self.ws > 1:
            x_ = x.permute(0, 3, 1, 2)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
        x = self.l_proj(x)
        return x

    def forward(self, x, H, W):
        B, N, C = x.shape

        x = x.reshape(B, H, W, C)

        if self.h_heads == 0:
            x = self.lofi(x)
            return x.reshape(B, N, C)

        if self.l_heads == 0:
            x = self.hifi(x)
            return x.reshape(B, N, C)

        hifi_out = self.hifi(x)
        lofi_out = self.lofi(x)

        x = torch.cat((hifi_out, lofi_out), dim=-1)
        x = x.reshape(B, N, C)

        return x

    def flops(self, H, W):
        # pad the feature map when the height and width cannot be divided by window size
        Hp = self.ws * math.ceil(H / self.ws)
        Wp = self.ws * math.ceil(W / self.ws)

        Np = Hp * Wp

        # For Hi-Fi
        # qkv
        hifi_flops = Np * self.dim * self.h_dim * 3
        nW = (Hp // self.ws) * (Wp // self.ws)
        window_len = self.ws * self.ws
        # q @ k and attn @ v
        window_flops = window_len * window_len * self.h_dim * 2
        hifi_flops += nW * window_flops
        # projection
        hifi_flops += Np * self.h_dim * self.h_dim

        # for Lo-Fi
        # q
        lofi_flops = Np * self.dim * self.l_dim
        kv_len = (Hp // self.ws) * (Wp // self.ws)
        # k, v
        lofi_flops += kv_len * self.dim * self.l_dim * 2
        # q @ k and attn @ v
        lofi_flops += Np * self.l_dim * kv_len * 2
        # projection
        lofi_flops += Np * self.l_dim * self.l_dim

        return hifi_flops + lofi_flops


if __name__ == '__main__':
    H, W = 16, 16
    x = torch.randn(4, 512, 16, 16).cuda()
    x = rearrange(x, 'b c h w -> b (h w) c')
    model = HiLo(512).cuda()
    out = model(x, 16, 16)
    out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)
    print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

标签:dim,heads,self,Attention,CVPR,ws,2022,attn,Fi
From: https://blog.csdn.net/wei582636312/article/details/143999807

相关文章

  • 大模型学习笔记:attention 机制
    UnderstandingQuery,Key,ValueinTransformersandLLMsThisself-attentionprocessisatthecoreofwhatmakestransformerssopowerful.Theyalloweveryword(ortoken)todynamicallyadjustitsimportancebasedonthesurroundingcontext,leadingt......
  • YOLOv10改进,YOLOv10添加DynamicConv(动态卷积),CVPR2024,二次创新C2f结构
    摘要大规模视觉预训练显著提高了大规模视觉模型的性能。现有的低FLOPs模型无法从大规模预训练中受益。在本文中,作者提出了一种新的设计原则,称为ParameterNet,旨在通过最小化FLOPs的增加来增加大规模视觉预训练模型中的参数数量。利用DynamicConv动态卷积将额外的参......
  • COCI2021-2022#4 Šarenlist
    luogu。问题描述:有\(k\)种颜色对一棵树的所有边进行染色,给定\(m\)条限制,每条限制要求\(u,v\)路径上的所有边至少有两种颜色,问染色的方案总数。注意到数据范围:\(m\le15\),明显的一个经典容斥。如何求钦定一些路径颜色全部相同的方案数?对于要求颜色相同的边用并查集并起来,......
  • YOLOv11改进策略【Head】| 结合CVPR-2024 中的DynamicConv 动态卷积 改进检测头, 优化
    一、本文介绍本文记录的是利用DynamicConv优化YOLOv11的目标检测网络模型。在大规模训练中,模型的参数量越多,FLOPs也越高,但在一些对计算资源有限制的场景下,需要低FLOPs的模型同时又希望模型能从大规模预训练中受益。传统的方法很难在增加参数的同时保持低FLOPs,因此Dynamic......
  • [JOISC2022] 洒水器
    [JOISC2022]洒水器题目描述JOI君有多年在自家菜园种植蔬菜的经验,现在他计划管理IOI农场。IOI农场由NNN块土地组成。土地间有......
  • 【NLP自然语言处理】Attention机制原理揭秘:赋予神经网络‘聚焦’与‘理解’的神奇力量
    目录......
  • CRC32爆破脚本 + [MoeCTF 2022]cccrrc 题解
    CRC32爆破原理介绍:CRC(循环冗余校验)是一种用于检测数据传输错误的技术。CRC算法生成一个校验值(校验和),这个值可以附加到数据后面,在数据接收方重新计算校验值并与附加的校验值进行比较,以此来确定数据是否在传输过程中发生了错误CRC32是一种常用的CRC算法,它的校验值长度固定为3......
  • P8814 [CSP-J 2022] 解密 题解
    解方程$题目中说,n=pq,ed=(p-1)(q-1)+1,m=n-ed+2.$$把ed的式子展开,得到:$$ed=p(q-1)-(q-1)+1$$ed=pq-p-q+2$$再把展开后的式子带入m中,得:$$m=n-(pq-p-q+2)+2.$$m=n-pq+p+q-2+2$$\becausen=pq$$\thereforem=pq-pq+p+q-2+2$$m=p+q.$$如果想要求出p和q的值,那么可以再......
  • 基于FFT + CNN - BiGRU-Attention 时域、频域特征注意力融合的电能质量扰动识别模型
    往期精彩内容:Python-电能质量扰动信号数据介绍与分类-CSDN博客Python电能质量扰动信号分类(一)基于LSTM模型的一维信号分类-CSDN博客Python电能质量扰动信号分类(二)基于CNN模型的一维信号分类-CSDN博客Python电能质量扰动信号分类(三)基于Transformer的一维信号分类模型-......
  • 轴承故障诊断 (12)基于交叉注意力特征融合的VMD+CNN-BiLSTM-CrossAttention故障识别模
    往期精彩内容:Python-凯斯西储大学(CWRU)轴承数据解读与分类处理Pytorch-LSTM轴承故障一维信号分类(一)-CSDN博客Pytorch-CNN轴承故障一维信号分类(二)-CSDN博客Pytorch-Transformer轴承故障一维信号分类(三)-CSDN博客三十多个开源数据集|故障诊断再也不用担心数据集了!P......