首页 > 其他分享 >CeiT(ICCV 2021, SenseTime)论文与代码解析

CeiT(ICCV 2021, SenseTime)论文与代码解析

时间:2024-07-23 15:53:40浏览次数:17  
标签:dim mathbf self SenseTime times ICCV 2021 attn qkv

paper:Incorporating Convolution Designs into Visual Transformers

official implementation:GitHub - coeusguo/ceit

背景

近年来,Transformer在自然语言处理(NLP)任务中取得了巨大的成功,并且开始有一些尝试将其应用于视觉领域。然而,纯Transformer架构在视觉任务中通常需要大量的训练数据或额外的监督才能获得与卷积神经网络(CNN)相当的性能。例如,ViT(Vision Transformer)直接从NLP领域借用Transformer架构用于图像分类,但严重依赖于大规模数据集(如JFT-300M)进行训练。为了缓解这一问题,DeiT(Data-efficient image Transformer)引入了知识蒸馏技术,通过一个高性能的CNN模型作为教师模型来提升ViT的性能。

出发点

在DeiT中,使用CNN教师蒸馏的Transformer效果更好,这可能是“Transformer通过蒸馏继承的归纳偏置”的功劳。此外作者重新审视了Transformer,总结了纯Transformer架构在视觉任务中的几个主要问题:

  • 难以从原始图像中提取低级特征(如边缘和角点)。
  • 忽略了空间维度中的局部性。
  • 需要大量的训练数据和计算资源。

这些问题促使研究者重新思考是否应该完全移除卷积操作,或者是否应当保留卷积中的某些归纳偏置(如平移不变性和局部性)来改善视觉Transformer的性能。

创新点

CeiT(Convolution-enhanced image Transformer)结合了CNN和Transformer的优点,提出了以下创新点:

  1. Image-to-Tokens(I2T)模块
    设计一个轻量级的I2T模块,从生成的低级特征中提取图像块,而不是直接从原始图像中提取。这一模块利用卷积层和最大池化层来生成特征图,从而降低了嵌入层的训练难度。
  2. Locally-enhanced Feed-Forward(LeFF)层
    替换每个编码器块中的前馈网络(FFN)层,增强空间维度中邻近token之间的相关性。LeFF层通过深度卷积在恢复到原始位置的“图像”上执行操作,从而增强局部特征的提取。
  3. Layer-wise Class token Attention(LCA)
    在Transformer的顶层附加LCA模块,利用多级表示来改进最终的图像表示。

方法介绍

Image-to-Tokens with Low-level Features

为了解决tokenization难以提取低级特征的问题,作者提出了一个有效的模块Image-to-Tokens(I2T)从特征图中提取patch而不是原始图片,如图2所示。

I2T是一个轻量的模块,由一个卷积层和一个最大池化层组成,此外在卷积层后面还加了一个BN,如下

其中 \(\mathbf{x}'\in \mathbb{R}^{\frac{H}{S}\times \frac{W}{S}\times D}\),\(S\) 是相对于原始输入图片的步长,\(D\) 是扩展后的通道数。然后和原始ViT中的patch embedding层一样从特征图 \(\mathbf{x}'\) 中提取patch得到一个序列,为了和ViT的token数量保持一致,patch的分辨率缩小到 \((\frac{P}{S},\frac{P}{S})\),实际应用中设置 \(S=4\)。

I2T充分利用了CNN在提取低级特征方面的优势,此外减小patch size也降低了embedding的训练难度。

Locally-Enhanced Feed-Forward Network

为了结合CNN提取局部信息的优势和Transformer建立长距离依赖的能力,作者提出了一个局部增强前馈网络(Locally-enhanced Feed-Forward Network,LeFF),如图3所示。

LeFF的步骤如下:首先给定前一个MSA的输出tokens \(\mathbf{x}_t^h\in\mathbb{R}^{(N+1)\times C}\),将其split成patch tokens \(\mathbf{x}^h_p\in \mathbb{R}^{N\times C}\) 和class token \(\mathbf{x}_c^h\in\mathbb{R}^{C}\),然后通过一个线性层将patch tokens映射到一个更高的维度 \(\mathbf{x}_p^{l_1}\in \mathbb{R}^{N\times(e\times C)}\),其中 \(e\) 是expand ratio。然后将patch tokens在空间维度reshape回“images”得到 \(\mathbf{x}_p^s\in\mathbb{R}^{\sqrt{N}\times \sqrt{N}\times (e\times C)}\),然后经过一个kernel大小为 \(k\) 的深度卷积增强与临近 \(k^2-1\) okens表示的相关性,得到 \(\mathbf{x}_p^d\in\mathbb{R}^{\sqrt{N}\times \sqrt{N}\times (e\times C)}\)。然后再flatten成序列得到 \(\mathbf{x}_p^f\in\mathbb{R}^{N\times (e\times C)}\)。最后再通过一个线性层映射回原始维度 \(\mathbf{x}_p^{l_2}\in\mathbb{R}^{N\times C}\),并与class token拼接起来,得到 \(\mathbf{x}_t^{h+1}\in\mathbb{R}^{(N+1)\times C}\)。在每个线性层和深度卷积后都有一个BN层和GELU激活函数,整个过程如下

Layer-wise Class-Token Attention

在网络中不同层的特征表示是不同的,为了整合不同层的信息,作者设计了一个Layer-wise Class-Token Attention模块(LCA)。如图4所示,LCA的输入为来自不同层的class token,表示为 \(\mathbf{X}_c=[\mathbf{x}_c^{(1)},...,\mathbf{x}_c^{(l)},...,\mathbf{x}_c^{(L)}]\),其中 \(l\) 表示层数。LCA和标准的Transformer block一样包含一个MSA和一个FFN,但是它只计算最后一层即 \(L\) 层的class token \(\mathbf{x}_c^{(L)}\) 和其它层class tokens的单向相似性,这将attention的计算复杂度从 \(O(n^2)\) 降低到了 \(O(n)\)。

实验结果

作者设计了三种不同大小的CeiT,具体配置如下

在ImageNet数据集上和其它模型的对比如表4所示

消融实验

不同类型的I2T的对比如表5所示,可以看到不用max pooling,无论是直接用步长为4的卷积还是两个步长为2的卷积性能都有所下降,maxpool和BN都对性能的提升有帮助。

在LeFF模块中,卷积核的大小代表了建立局部相关性的区域大小,不同卷积核的对比如表7所示,可以看到随着卷积核的增大精度也跟着提升,当使用BN时精度进一步得到提升,基于参数和精度的权衡考虑,最终采用3x3 conv+BN的配置。

代码解析

I2T和LeFF的代码很简单就不讲了,这里只讲一下LCA,因为上面提到过这里只计算最后一层的class token和其它层class token之间的单向相似性。下面的代码中类Attention是常规的自注意力,AttentionLCA继承了Attention,forward中可以看到所谓的单向注意力其实就是query只包含最后一层的class token,而key和value包含了所有层的class token,因此复杂度是 \(O(n)\),其它就没什么了。

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.attention_map = None

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # self.attention_map = attn
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class AttentionLCA(Attention):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super(AttentionLCA, self).__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.dim = dim
        self.qkv_bias = qkv_bias
        
    def forward(self, x):

        q_weight = self.qkv.weight[:self.dim, :]
        q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim]
        kv_weight = self.qkv.weight[self.dim:, :]
        kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:]
        
        B, N, C = x.shape
        _, last_token = torch.split(x, [N-1, 1], dim=1)
        
        q = F.linear(last_token, q_weight, q_bias)\
             .reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # (B,1,C)->(B,1,C)->(B,1,h,C/h)->(B,h,1,C/h)
        kv = F.linear(x, kv_weight, kv_bias)\
              .reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # (B,N,C)->(B,N,2C)->(B,N,2,h,C/h)->(2,B,h,N,C/h)
        k, v = kv[0], kv[1]  # (B,h,N,C/h)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B,h,1,N)
        attn = attn.softmax(dim=-1)
        # self.attention_map = attn
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)  # (B,h,1,C/h)->(B,1,h,C/h)->(B,1,C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

标签:dim,mathbf,self,SenseTime,times,ICCV,2021,attn,qkv
From: https://blog.csdn.net/ooooocj/article/details/140638082

相关文章

  • ICS_S7comm协议分析-2021工业互联网内部预选赛
    根据功能可以看出是设置密码拆分data数据,用前两个与0x55进行异或接着对剩下的数操作,操作为与自己距离为-2的数进行异或list=[0x26,0x62,0x10,0x42,0x37,0x7e,0x16,0x52]passwd=[]foriinrange(0,len(list)):ifi==0ori==1:passwd.append(chr(list[i]^0x55))......
  • MViT:性能杠杠的多尺度ViT | ICCV 2021
    论文提出了多尺度视觉Transformer模型MViT,将多尺度层级特征的基本概念与Transformer模型联系起来,在逐层扩展特征复杂度同时降低特征的分辨率。在视频识别和图像分类的任务中,MViT均优于单尺度的ViT。来源:晓飞的算法工程笔记公众号论文:MultiscaleVisionTransformers论文......
  • 2021 ICPC 网络赛 第二场 L Euler Function(势能线段树,欧拉函数,状态压缩)
    2021ICPC网络赛第二场LEulerFunction题意给定序列,定义两个操作\(l,r,x\)对区间\([l,r]\)的数乘\(x\)\(l,r\)求\(\sum\phi{a}_{i}\)思路注意欧拉函数的性质,若\(i\bmodp=0\),\(\phi(i*p)=p*\phi(i)\),否则\(\phi(i*p)=(p-1)*\phi(i)\)因为\(x,w\)的......
  • python 解题 洛谷B2021到B2025
    B2021输出保留3位小数的浮点数n=float(input())n=n-0.000000000000001print('%.3f'%n)B2022输出保留12位小数的浮点数m=float(input())print('%.12f'%m)B2023空格分隔输出a=input()b=int(input())c=float(input())d=float(input())print(a,"",b,"......
  • LeViT:Facebook提出推理优化的混合ViT主干网络 | ICCV 2021
    论文提出了用于快速图像分类推理的混合神经网络LeVIT,在不同的硬件平台上进行不同的效率衡量标准的测试。总体而言,LeViT在速度/准确性权衡方面明显优于现有的卷积神经网络和ViT,比如在80%的ImageNettop-1精度下,LeViT在CPU上比EfficientNet快5倍来源:晓飞的算法工程笔记公众号论......
  • 助力智慧交通,基于YOLO家族最新端到端实时目标检测算法YOLOv10全系列【n/s/m/b/l/x】参
    交通标志检测是交通标志识别系统中的一项重要任务。与其他国家的交通标志相比,中国的交通标志有其独特的特点。卷积神经网络(CNN)在计算机视觉任务中取得了突破性进展,在交通标志分类方面取得了巨大的成功。CCTSDB数据集是由长沙理工大学的相关学者及团队制作而成的,其有交通标志样......
  • Project2007-2021安装包分享:附网盘地址+安装步骤
    不得不承认,Project是从事项目管理人员最常用的软件之一,它不仅可以提高项目的效率,缩短项目开发周期,操作难度相对来说也比较小。也可以说,Project是一款专注于项目管理的桌面应用软件。它可以帮助用户制定项目计划、分派任务、管理资源、跟踪进度以及生成汇报等。MicrosoftProj......
  • 【题解】 [CSP-J 2021 T1] 分糖果
    题目描述题目大意给定正整数\(n\)、\(L\)、\(R\),求\(\max_{i\in\left[L,R\right]}{i\bmodn}\)。思路题目主要考察:分类讨论。众所周知,对于\(\forallx\),有$(x\bmodn)\in\left[0,n-1\right]$。可以分为两种情况讨论:如果\(\left\lfloor\frac{L......
  • [CSP-S 2021] 廊桥分配
    戳我跳转题面题意一共有n个廊桥,全部分配给互相独立的两组。第一组有$m1$个区间$[l_i,r_i]$,第二组有$m2$个区间$[a_i,b_i]$(互相独立),一旦有廊桥空着,就会将$i$区间覆盖于总区间。问一共能满足多少个区间。思路45pts由于两组的处理方法几乎一样,在这里只举第一组的例......
  • 2021杭电多校10 D.Pty hates prime numbers题解
    前言暑期第三次组队赛是选的21年杭电多校10,遗憾爆0,被对面队打爆,赛后狠狠补题。这道题的题解,以及网上搜到的其他题解看了好久没看懂,在问了队里大腿多次后,总算磨出来了,这里讲一下我的理解。题意多次询问,每次给定\(n\)和\(k\),如果一个数的质因数里包括前\(k\)个质数,则这个数......