首页 > 其他分享 >【CVPR2022】Shunted Self-Attention via Multi-Scale Token Aggregation

【CVPR2022】Shunted Self-Attention via Multi-Scale Token Aggregation

时间:2023-09-14 22:58:01浏览次数:49  
标签:dim Multi via heads nn Self drop num self

来自CVPR2022 基于多尺度令牌聚合的分流自注意力

论文地址:[2111.15193] Shunted Self-Attention via Multi-Scale Token Aggregation (arxiv.org)

项目地址:https://github.com/OliverRensu/Shunted-Transformer

一、Introduction

还是经典的ViT的历史遗留问题:ViT中的自注意力计算是针对一个固定的patch大小的token的,每一层内每个token特征的感受野是相似的。这样的约束不可避免地限制了每个自注意力层在捕获多尺度特征方面的能力,从而导致在处理不同尺度的多个对象的图像时的性能下降。为了解决这个问题,作者提出了一个新颖且通用的策略—shunted self-attention(SSA)。SSA 的关键思想是将异构感受野大小注入标记:在计算自注意力矩阵之前,它选择性地合并标记以表示更大的对象特征,同时保持某些标记以保留细粒度特征。这种新颖的合并方案使 self-attention 能够学习不同大小的对象之间的关系,同时降低令牌数量和计算成本。

二、Motivation

1.计算成本高。自注意力机制带来了昂贵的内存消耗成本。

2.ViT生成的特征图为单一尺度的,粗粒度的,这不可避免地限制了模型的性能。

3.之前的Transformer模型在很大程度上忽略了注意层中场景对象的多尺度特性,使它们在涉及不同大小对象的野外场景中变得脆弱。从技术上讲,这种无能归因于它们潜在的注意机制:现有的方法只依赖于一个注意层内令牌的静态接受域和统一的信息粒度,因此无法同时捕获不同尺度的特征。

三、Contribution

1.提出了SSA,将多尺度信息提取的功能集成在一个自注意力层中,SSA自适应地合并大对象上的令牌以提高计算效率,并保留小对象上的令牌来保留更多细节。SSA的多尺度注意机制是通过将多个注意头分成几个组来实现的。每组都考虑了专用的注意力粒度。对于细粒度组,SSA 学习聚合少量标记并保留更多局部细节。对于剩余的粗粒度头部组,SSA 学习聚合大量标记,从而减少计算成本,同时保留捕获大型对象的能力。多粒度组联合学习多粒度信息,使模型能够有效地对多尺度对象进行建模。

2.在此基础上,构建了一种能高效捕获多尺度目标,尤其是微小和远程的孤立目标的分流变压器。

四、Method

SSA块和ViT中传统的自注意块有两个主要区别:1)SSA为每个自注意层引入了一个分流注意机制,以捕获多粒度信息,更好地建模不同大小的对象,特别是小对象;2)通过增加交叉令牌交互,增强了点向前馈层提取局部信息的能力。此外,我们的分流变压器部署了一种新的补丁嵌入方法,用于为第一个注意块获得更好的输入特征映射。

 

 4.1 Shunted Transformer Block

为了降低计算成本,PVT引入了空间减少注意(spatialreduction attention, SRA)来取代原来的多头自注意(multiple -head self-attention, MSA)。然而,SRA倾向于在一个自注意力层中合并太多的令牌,并且只在单个尺度上提供标记特征。这些限制阻碍了模型捕获多尺度目标,特别是小尺度目标的能力。因此,我们通过在一个自注意层中并行学习多粒度引入分流自注意。整体结构后遵循了PVT的层级结构。

4.1.1 Shunted Self-Attention

如图5所示,我们的SSA与PVT的SRA不同之处在于,在同一自注意层的注意头上,K、V的长度不相同。相反,长度在不同的头中变化,以捕获不同粒度的信息。这提供了多尺度令牌聚合(MTA)。具体地,对于由i索引的不同头,K和V被下采样到不同大小,在时间上,下采样操作是由不同大小的卷积完成的,卷积核大小和步长为ri,在一层中有不同的ri,因此,K和V可以关注到不同的尺度,LE(·)是对V值进行深度卷积得到的MTA局部增强分量。与PVT中的SR相比,更多细粒度和低级的细节。

计算成本降低可能取决于 r 的值,因此,我们可以很好地定义模型和 r 以权衡计算成本和模型性能。当 r 变大时,K、V 中的更多令牌合并,K、V 的长度较短,因此计算成本较低,但仍保留捕获大对象的能力。相比之下,当 r 变小时,保留了更多细节,但带来了更多的计算成本。在一个自注意力层中集成各种 r 使其能够捕获多粒度特征。

 

代码部分:原理较为简单,核心是采用了分组的思想,用不同的卷积核和步长进行卷积操作完成下采样这一步,得到两个不同大小的结果,用于获取不同长度的K和V,对V值进行深度卷积得到的MTA局部增强分量与原始的V值相加,得到增强后的的V值,两组k和v分别表示前一半和后一半的head产生的,将q分为两组,然后分别进行两组自注意力的计算得到x1和x2,将x1和x2在维度上进行拼接,得到最终的x。

class Attention(nn.Module):     def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):         super().__init__()         assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim         self.num_heads = num_heads         head_dim = dim // num_heads         self.scale = qk_scale or head_dim ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)         self.proj = nn.Linear(dim, dim)         self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio         if sr_ratio > 1:             self.act = nn.GELU()             if sr_ratio==8:                 self.sr1 = nn.Conv2d(dim, dim, kernel_size=8, stride=8)                 self.norm1 = nn.LayerNorm(dim)                 self.sr2 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)                 self.norm2 = nn.LayerNorm(dim)             if sr_ratio==4:                 self.sr1 = nn.Conv2d(dim, dim, kernel_size=4, stride=4)                 self.norm1 = nn.LayerNorm(dim)                 self.sr2 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)                 self.norm2 = nn.LayerNorm(dim)             if sr_ratio==2:                 self.sr1 = nn.Conv2d(dim, dim, kernel_size=2, stride=2)                 self.norm1 = nn.LayerNorm(dim)                 self.sr2 = nn.Conv2d(dim, dim, kernel_size=1, stride=1)                 self.norm2 = nn.LayerNorm(dim)             self.kv1 = nn.Linear(dim, dim, bias=qkv_bias)             self.kv2 = nn.Linear(dim, dim, bias=qkv_bias)             self.local_conv1 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2)             self.local_conv2 = nn.Conv2d(dim//2, dim//2, kernel_size=3, padding=1, stride=1, groups=dim//2)         else:             self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)             self.local_conv = nn.Conv2d(dim, dim, kernel_size=3, padding=1, stride=1, groups=dim)     def forward(self, x, H, W):         B, N, C = x.shape         q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)         if self.sr_ratio > 1:                 x_ = x.permute(0, 2, 1).reshape(B, C, H, W)                 x_1 = self.act(self.norm1(self.sr1(x_).reshape(B, C, -1).permute(0, 2, 1)))                 x_2 = self.act(self.norm2(self.sr2(x_).reshape(B, C, -1).permute(0, 2, 1)))                 kv1 = self.kv1(x_1).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)                 kv2 = self.kv2(x_2).reshape(B, -1, 2, self.num_heads//2, C // self.num_heads).permute(2, 0, 3, 1, 4)                 k1, v1 = kv1[0], kv1[1] #B head N C                 k2, v2 = kv2[0], kv2[1]                 attn1 = (q[:, :self.num_heads//2] @ k1.transpose(-2, -1)) * self.scale                 attn1 = attn1.softmax(dim=-1)                 attn1 = self.attn_drop(attn1)                 v1 = v1 + self.local_conv1(v1.transpose(1, 2).reshape(B, -1, C//2).                                         transpose(1, 2).view(B,C//2, H//self.sr_ratio, W//self.sr_ratio)).\                     view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)                 x1 = (attn1 @ v1).transpose(1, 2).reshape(B, N, C//2)                 attn2 = (q[:, self.num_heads // 2:] @ k2.transpose(-2, -1)) * self.scale                 attn2 = attn2.softmax(dim=-1)                 attn2 = self.attn_drop(attn2)                 v2 = v2 + self.local_conv2(v2.transpose(1, 2).reshape(B, -1, C//2).                                         transpose(1, 2).view(B, C//2, H*2//self.sr_ratio, W*2//self.sr_ratio)).\                     view(B, C//2, -1).view(B, self.num_heads//2, C // self.num_heads, -1).transpose(-1, -2)                 x2 = (attn2 @ v2).transpose(1, 2).reshape(B, N, C//2)
                x = torch.cat([x1,x2], dim=-1)         else:             kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)             k, v = kv[0], kv[1]
            attn = (q @ k.transpose(-2, -1)) * self.scale             attn = attn.softmax(dim=-1)             attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, N, C) + self.local_conv(v.transpose(1, 2).reshape(B, N, C).                                         transpose(1, 2).view(B,C, H, W)).view(B, C, N).transpose(1, 2)         x = self.proj(x)         x = self.proj_drop(x)
        return x

4.1.2 Detail-specific Feedforward Layers

在传统的前馈层中,全连接层是逐点的,不能学习交叉标记信息。在这里,我们的目标是通过指定前馈层的细节来补充本地信息。如图 6 所示,我们通过在前馈层中的两个全连接层之间添加数据特定层来补充前馈层中的局部细节。实践中由深度卷积实现。

代码:

class Mlp(nn.Module):     def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):         super().__init__()         out_features = out_features or in_features         hidden_features = hidden_features or in_features         self.fc1 = nn.Linear(in_features, hidden_features)         self.dwconv = DWConv(hidden_features)         self.act = act_layer()         self.fc2 = nn.Linear(hidden_features, out_features)         self.drop = nn.Dropout(drop)
    def forward(self, x, H, W):         x = self.fc1(x)         x = self.act(x + self.dwconv(x, H, W))         x = self.drop(x)         x = self.fc2(x)         x = self.drop(x)         return x

4.2. Patch Embedding

ViT直接将输入图像分割成16 × 16个不重叠的补丁。最近的研究发现,在补丁嵌入中使用卷积可以提供更高质量的令牌序列,并帮助Transformer“看到更好”比传统的大步幅非重叠补丁嵌入。因此,一些作品 使用 7 × 7 卷积进行重叠的补丁嵌入。在我们的模型中,我们根据模型大小采用具有不同重叠的卷积层。我们采用步长为 2 和零填充的 7 × 7 卷积层作为补丁嵌入中的第一层,并根据模型大小添加步长为 1 的额外 3 × 3 卷积层。最后,步幅为 2 的非重叠投影层以生成大小为 H/4 × W/4 的输入序列。(CVT等等一系列工作都是用卷积生成token。

五、Conclusion

本文提出了一种新颖的分流自注意力 (SSA) 方案来明确解释多尺度特征。与之前只关注一个注意力层中的静态特征图的工作不同,在一个自注意力层中保持关注多尺度对象的各种尺度特征图。

 

标签:dim,Multi,via,heads,nn,Self,drop,num,self
From: https://www.cnblogs.com/yeonni/p/17482948.html

相关文章

  • 多主架构:VLDB技术论文《Taurus MM: bringing multi-master to the cloud》解读
    本文分享自华为云社区《多主创新,让云数据库性能更卓越》,作者:GaussDB数据库。华为《TaurusMM:bringingmulti-mastertothecloud》论文被国际数据库顶会VLDB2023录用,这篇论文里讲述了符合云原生数据库特点的超燃技术。介绍了如何通过各种黑科技减少云原生数据库的网络消耗,......
  • 多主架构:VLDB技术论文《Taurus MM: bringing multi-master to the cloud》解读
    本文分享自华为云社区《多主创新,让云数据库性能更卓越》,作者:GaussDB数据库。华为《TaurusMM:bringingmulti-mastertothecloud》论文被国际数据库顶会VLDB2023录用,这篇论文里讲述了符合云原生数据库特点的超燃技术。介绍了如何通过各种黑科技减少云原生数据库的网络消耗,进......
  • Transformer-empowered Multi-scale Contextual Matching and Aggregation for
    Transformer-empoweredMulti-scaleContextualMatchingandAggregationforMulti-contrastMRISuper-resolution(阅读文献)10.12基于变压器的磁共振多对比度超分辨率多尺度背景匹配与聚合摘要:MRI可以显示相同解剖结构的多对比图像,使多对比超分辨率(SR)技术成为可能。和使用单一......
  • 【学习笔记】Self-attention
    最近想学点NLP的东西,开始看BERT,看了发现transformer知识丢光了,又来看self-attention;看完self-attention发现还得再去学学wordembedding...推荐学习顺序是:wordembedding、self-attention/transformer、BERT(后面可能还会补充新的)我是看的李宏毅老师的课程+pdf,真的很爱他的课........
  • Graph transduction via alternating minimization
    目录概符号说明GTAM交替优化求解WangJ.,JebaraT.andChangS.Graphtransductionviaalternatingminimization.ICML,2008.概一种对类别不均更鲁棒的半监督算法.符号说明\(\mathcal{X}_l=\{\mathbf{x}_1,\cdots,\mathbf{x}_l\}\),labeledinputs;\(\mathcal......
  • Python PIL 远程命令执行漏洞(via Ghostscript)
    目录1.1、漏洞描述1.2、漏洞等级1.3、影响版本1.4、漏洞复现1、基础环境2、漏洞分析3、漏洞验证1.5、深度利用1、反弹Shell说明内容漏洞编号PIL-CVE-2018-16509漏洞名称PythonPIL远程命令执行漏洞漏洞评级影响范围漏洞描述修复方案1.1、漏洞......
  • C++ STL之map、multimap
    map和multimap是C++STL(StandardTemplateLibrary)中的关联容器,它们提供键值对的存储和访问。map是一个有序关联容器,它存储一组键值对,其中每个键都是唯一的。map中的键值对按照键的升序排序。用户可以通过键来访问、修改和删除对应的值。map的实现通常使用平衡二叉搜索树(如红黑树......
  • Cousleur (ICPC 青岛) (值域主席树 + 逆序对 + multiset +mp)
    题目大意:给一个序列 n会有n次操作,每次都会删除一个数这个数是连续子序列里面最大的逆序对的个数^Q[i],q[i]给出思路:启发式拆分,每次选择长度小的序列来进行处理数学化:rev(逆序对个数)   rev(x+1,r)=rev(l,r)-rev(l,x-1)-(一个元素......
  • MultipartFile转File
    总有些奇奇怪怪的转换~publicstaticFileconvertMultipartFileToFile(MultipartFilemultipartFile)throwsIOException{Filefile=newFile(multipartFile.getOriginalFilename());//创建一个新的File对象try(FileOutputStreamfos=newFileO......
  • 自定义配置文件参数在application可以直接识别Not registered via @EnableConfigurati
    自定义配置文件参数在application可以直接识别Notregisteredvia@EnableConfigurationPropertiesormarkedasSpringcomponent看见很多开源项目的配置文件可以直接配置在application.yaml中,自己也想弄一个,怎么弄呢?这是我的demo,你正常ConfigurationProperties会报错Notregi......