paper:Incorporating Convolution Designs into Visual Transformers
official implementation:GitHub - coeusguo/ceit
背景
近年来,Transformer在自然语言处理(NLP)任务中取得了巨大的成功,并且开始有一些尝试将其应用于视觉领域。然而,纯Transformer架构在视觉任务中通常需要大量的训练数据或额外的监督才能获得与卷积神经网络(CNN)相当的性能。例如,ViT(Vision Transformer)直接从NLP领域借用Transformer架构用于图像分类,但严重依赖于大规模数据集(如JFT-300M)进行训练。为了缓解这一问题,DeiT(Data-efficient image Transformer)引入了知识蒸馏技术,通过一个高性能的CNN模型作为教师模型来提升ViT的性能。
出发点
在DeiT中,使用CNN教师蒸馏的Transformer效果更好,这可能是“Transformer通过蒸馏继承的归纳偏置”的功劳。此外作者重新审视了Transformer,总结了纯Transformer架构在视觉任务中的几个主要问题:
- 难以从原始图像中提取低级特征(如边缘和角点)。
- 忽略了空间维度中的局部性。
- 需要大量的训练数据和计算资源。
这些问题促使研究者重新思考是否应该完全移除卷积操作,或者是否应当保留卷积中的某些归纳偏置(如平移不变性和局部性)来改善视觉Transformer的性能。
创新点
CeiT(Convolution-enhanced image Transformer)结合了CNN和Transformer的优点,提出了以下创新点:
- Image-to-Tokens(I2T)模块:
设计一个轻量级的I2T模块,从生成的低级特征中提取图像块,而不是直接从原始图像中提取。这一模块利用卷积层和最大池化层来生成特征图,从而降低了嵌入层的训练难度。 - Locally-enhanced Feed-Forward(LeFF)层:
替换每个编码器块中的前馈网络(FFN)层,增强空间维度中邻近token之间的相关性。LeFF层通过深度卷积在恢复到原始位置的“图像”上执行操作,从而增强局部特征的提取。 - Layer-wise Class token Attention(LCA):
在Transformer的顶层附加LCA模块,利用多级表示来改进最终的图像表示。
方法介绍
Image-to-Tokens with Low-level Features
为了解决tokenization难以提取低级特征的问题,作者提出了一个有效的模块Image-to-Tokens(I2T)从特征图中提取patch而不是原始图片,如图2所示。
I2T是一个轻量的模块,由一个卷积层和一个最大池化层组成,此外在卷积层后面还加了一个BN,如下
其中 \(\mathbf{x}'\in \mathbb{R}^{\frac{H}{S}\times \frac{W}{S}\times D}\),\(S\) 是相对于原始输入图片的步长,\(D\) 是扩展后的通道数。然后和原始ViT中的patch embedding层一样从特征图 \(\mathbf{x}'\) 中提取patch得到一个序列,为了和ViT的token数量保持一致,patch的分辨率缩小到 \((\frac{P}{S},\frac{P}{S})\),实际应用中设置 \(S=4\)。
I2T充分利用了CNN在提取低级特征方面的优势,此外减小patch size也降低了embedding的训练难度。
Locally-Enhanced Feed-Forward Network
为了结合CNN提取局部信息的优势和Transformer建立长距离依赖的能力,作者提出了一个局部增强前馈网络(Locally-enhanced Feed-Forward Network,LeFF),如图3所示。
LeFF的步骤如下:首先给定前一个MSA的输出tokens \(\mathbf{x}_t^h\in\mathbb{R}^{(N+1)\times C}\),将其split成patch tokens \(\mathbf{x}^h_p\in \mathbb{R}^{N\times C}\) 和class token \(\mathbf{x}_c^h\in\mathbb{R}^{C}\),然后通过一个线性层将patch tokens映射到一个更高的维度 \(\mathbf{x}_p^{l_1}\in \mathbb{R}^{N\times(e\times C)}\),其中 \(e\) 是expand ratio。然后将patch tokens在空间维度reshape回“images”得到 \(\mathbf{x}_p^s\in\mathbb{R}^{\sqrt{N}\times \sqrt{N}\times (e\times C)}\),然后经过一个kernel大小为 \(k\) 的深度卷积增强与临近 \(k^2-1\) okens表示的相关性,得到 \(\mathbf{x}_p^d\in\mathbb{R}^{\sqrt{N}\times \sqrt{N}\times (e\times C)}\)。然后再flatten成序列得到 \(\mathbf{x}_p^f\in\mathbb{R}^{N\times (e\times C)}\)。最后再通过一个线性层映射回原始维度 \(\mathbf{x}_p^{l_2}\in\mathbb{R}^{N\times C}\),并与class token拼接起来,得到 \(\mathbf{x}_t^{h+1}\in\mathbb{R}^{(N+1)\times C}\)。在每个线性层和深度卷积后都有一个BN层和GELU激活函数,整个过程如下
Layer-wise Class-Token Attention
在网络中不同层的特征表示是不同的,为了整合不同层的信息,作者设计了一个Layer-wise Class-Token Attention模块(LCA)。如图4所示,LCA的输入为来自不同层的class token,表示为 \(\mathbf{X}_c=[\mathbf{x}_c^{(1)},...,\mathbf{x}_c^{(l)},...,\mathbf{x}_c^{(L)}]\),其中 \(l\) 表示层数。LCA和标准的Transformer block一样包含一个MSA和一个FFN,但是它只计算最后一层即 \(L\) 层的class token \(\mathbf{x}_c^{(L)}\) 和其它层class tokens的单向相似性,这将attention的计算复杂度从 \(O(n^2)\) 降低到了 \(O(n)\)。
实验结果
作者设计了三种不同大小的CeiT,具体配置如下
在ImageNet数据集上和其它模型的对比如表4所示
消融实验
不同类型的I2T的对比如表5所示,可以看到不用max pooling,无论是直接用步长为4的卷积还是两个步长为2的卷积性能都有所下降,maxpool和BN都对性能的提升有帮助。
在LeFF模块中,卷积核的大小代表了建立局部相关性的区域大小,不同卷积核的对比如表7所示,可以看到随着卷积核的增大精度也跟着提升,当使用BN时精度进一步得到提升,基于参数和精度的权衡考虑,最终采用3x3 conv+BN的配置。
代码解析
I2T和LeFF的代码很简单就不讲了,这里只讲一下LCA,因为上面提到过这里只计算最后一层的class token和其它层class token之间的单向相似性。下面的代码中类Attention是常规的自注意力,AttentionLCA继承了Attention,forward中可以看到所谓的单向注意力其实就是query只包含最后一层的class token,而key和value包含了所有层的class token,因此复杂度是 \(O(n)\),其它就没什么了。
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.attention_map = None
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
# self.attention_map = attn
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class AttentionLCA(Attention):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super(AttentionLCA, self).__init__(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
self.dim = dim
self.qkv_bias = qkv_bias
def forward(self, x):
q_weight = self.qkv.weight[:self.dim, :]
q_bias = None if not self.qkv_bias else self.qkv.bias[:self.dim]
kv_weight = self.qkv.weight[self.dim:, :]
kv_bias = None if not self.qkv_bias else self.qkv.bias[self.dim:]
B, N, C = x.shape
_, last_token = torch.split(x, [N-1, 1], dim=1)
q = F.linear(last_token, q_weight, q_bias)\
.reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# (B,1,C)->(B,1,C)->(B,1,h,C/h)->(B,h,1,C/h)
kv = F.linear(x, kv_weight, kv_bias)\
.reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# (B,N,C)->(B,N,2C)->(B,N,2,h,C/h)->(2,B,h,N,C/h)
k, v = kv[0], kv[1] # (B,h,N,C/h)
attn = (q @ k.transpose(-2, -1)) * self.scale # (B,h,1,N)
attn = attn.softmax(dim=-1)
# self.attention_map = attn
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (B,h,1,C/h)->(B,1,h,C/h)->(B,1,C)
x = self.proj(x)
x = self.proj_drop(x)
return x
标签:dim,mathbf,self,SenseTime,times,ICCV,2021,attn,qkv
From: https://blog.csdn.net/ooooocj/article/details/140638082