paper:LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference
official implementation:https://github.com/facebookresearch/LeViT
third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/levit.py
本文的创新点
本文旨在设计一种新的图像分类架构,通过结合卷积网络的优势和转换器的优点,优化高效推理时的精度和速度权衡。具体如下
- 多阶段Transformer架构:提出了一种多阶段Transformer架构,使用注意力机制进行下采样。这种设计类似于传统卷积网络中的金字塔结构,使得特征图的分辨率逐步降低,提高了计算效率。
- 高效的patch descriptor:在模型的前几层中引入了计算效率高的Patch描述符,减少了特征数量,从而提升了网络的整体计算效率。
- 注意力偏置:引入了一种新的attention bias,代替了视觉转换器中的位置编码(Positional Encoding),实现了平移不变的空间信息编码,提升了模型的表现。
- 重新设计的Attention-MLP block:重新设计了注意力-MLP模块,提高了网络在给定计算时间内的容量,使得网络可以在相同的计算资源下获得更好的表现。
- 混合架构设计:通过结合卷积和Transformer的混合架构(grafting experiments),在相同的计算budget下实现了更好的精度和速度权衡。这种混合架构在训练初期表现出与卷积网络相似的快速收敛特性,同时在后期表现出Transformer的高精度。
总体来说,本文通过引入这些创新,提出了一种名为LeViT的混合神经网络,在ImageNet数据集上的实验结果表明,该模型在推理速度和精度上显著优于现有的卷积网络和视觉Transformer。
方法介绍
Vision Transformer中的patch projection层通过16x16 stride=16的卷积实现,引发了作者对卷积与Transformer之间联系的思考。在卷积中,mask的空间平滑性来自于卷积过程中卷积核的重叠:临近的像素接收到相似的梯度。而在ViT中平滑掩膜可能是由于数据增强造成的,当一个图像出现两次且发生微小的平移时,相同的梯度经过每个filter,所以它可以学习这种空间平滑性。因此尽管在Transformer架构中没有归纳偏置"inductive bias",训练确实产生了类似卷积层的filter。
作者首先用ResNet-50和DeiT-S进行了一个嫁接实验,结果如表1所示。
可以看到嫁接的结构比单独的ResNet-50和DeiT-S的效果都要好,其中精度最高同时参数量最小的组合是与两个stage的ResNet-50进行嫁接。
一个有趣的观察如图3所示,嫁接模型在训练早期的收敛性和卷积网络类似,然后切换到类似于DeiT-S的收敛速度。一种假设是,卷积由于其本身的inductive bias能力(平移不变性),使其可以在网络的浅层更有效的学习low-level information,它们快速找到有意义的patch embedding,这可以解释为什么在第一个epoch可以快速收敛。
基于上述观察, 作者认为在transformer下面插入卷积stage是有益的,大部分的处理仍然是在后续堆叠的transformer block中实现的,以获得嫁接结构精度最高的变体。因此接下来作者重点研究了如何降低transformer的计算成本,以及如何与卷积更紧密地结合而不仅仅是嫁接起来。
LeViT的完整结构如图4所示。具体的设计原则如下
Patch embedding
在LeViT中,作者采用4层stride=2的3x3卷积进行分辨率的下采样。对于(3, 224, 224)的输入,经过4层卷积后得到维度为(256, 14, 14)的输出进入接下来的transformer block中。
No classification token
为了使用BCHW的张量形式,作者删除了分类token,而是和卷积网络一样,在最后一个特征图上用全局平均池化来得到分类器用的embedding。对于蒸馏,分别训练不同的head进行分类和蒸馏任务。测试时,取这两个head输出的平均值。
Normalization layera and activations
ViT中的FC层等价于1x1卷积。ViT在每个attention层和mlp前都是用了LN。而对于LeViT,每个卷积后都加一个BN。DeiT的激活函数使用了GELU,而在LeViT中所有激活函数都采用Hardswish。
Multi-resolution pyramid
因为LeViT前面嫁接了ResNet的部分stage,因此形成了和卷积网络一样的金字塔结构,特征图的分辨率随着通道数的增加而降低。
下面是对attention block进行的一些修改,如图5所示。
Downsampling
在LeViT的stage之间, 通过一个shrinking attention block来减小激活图的大小,如图5右侧所示。具体来说,在Q变换之前进行一个subsampling,这将大小为 \((C,H,W)\) 的tensor映射为大小为 \((C',H/2,W/2)\) 的输出tensor,其中 \(C'>C\)。由于大小的变化,这个attention block没有使用residual connection。为了防止信息的损失,将attention heads的数量设置为 \(C/D\)。
Attention bias instead of a position embedding
之前的位置编码只包含在attention block的输入序列中,由于位置编码对于更高的层也很重要,所以作者的目标是在每个attention block中都提供位置信息,并显式地在注意力机制中注入相对位置信息:具体是通过在attention map上加上一个attention bias来实现的。两个像素 \((x,y)\in[H]\times [W]\) 和 \((x',y')\in[H]\times [W]\) 之间一个head \(h\in [N]\) 的标量attention value按下式计算
其中第一项就是普通的attention,第二项是平移不变的attention bias,每个head都有对应于不同像素offset的 \(H\times W\) 个参数。取绝对值是鼓励网络以flip invariance的方式进行训练。
Smaller keys
bias项减少了key对位置信息编码的压力,所以相比于 \(V\) 我们减小了key矩阵的大小。如果key的维度 \(D\in\{16,32\}\),\(V\) 的通道数为 \(2D\)。限制key的大小减少了计算 \(QK^T\) 的时间。
对于降采样层,其中没有residual connection,我们将 \(V\) 的维度设置为 \(4D\) 来防止信息损失。
Attention activation
在使用线性映射来组合不同head的输出之前,我们对乘积 \(A^hV\) 应用一个Hardswish。
Reducing the MLP blocks
通常ViT的MLP隐藏层维度的expansion ratio设置为4,对于LeViT,MLP由一个1x1卷积和一个BN组成,为了降低计算量,我们将expansion factor由4降低为2。
至此LeViT中的所有改进都介绍完了,在不同的计算量限制下,LeViT有一系列不同的变体,本文通过输入到第一个transformer block的通道数来进行区分,比如LeViT-256表示第一个transformer block的输入通道数为256。表2展示了不同LeViT变体的具体配置。
代码解析
这里以timm中的实现为例介绍一下代码。选择的模型是'levit_conv_128s',具体配置如下,可以看到和表2的第一列是对应的。
levit_128s=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), # 对应表2,key_dim就是D
首先是stem,如上所述就是4个3x3-s2的卷积,且每个卷积后都跟一个激活函数,上面也讲过,本文所有的激活函数都采用Hardswish
class Stem16(nn.Sequential):
def __init__(self, in_chs, out_chs, act_layer):
super().__init__()
self.stride = 16
self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1))
self.add_module('act1', act_layer())
self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1))
self.add_module('act2', act_layer())
self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
self.add_module('act3', act_layer())
self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))
然后就是transformer stage部分,定义如下。其中embed_dim=(128, 256, 384)表示stem的输出即transformer stage的输入的通道数为128。key_dim=16就是论文中的 \(D\)。attn_ratio都是2对应【smaller keys】部分提到的 “V的通道数为2D”。mlp_ratio也都是2对应【Reducing the MLP blocks】部分提到的将隐藏层的expansion ratio由4改为2。
in_dim = embed_dim[0]
stages = []
for i in range(num_stages):
stage_stride = 2 if i > 0 else 1
stages += [LevitStage(
in_dim,
embed_dim[i],
key_dim, # 16
depth=depth[i], # (2,3,4)
num_heads=num_heads[i], # (4,6,8)
attn_ratio=attn_ratio[i], # (2.0, 2.0, 2.0)
mlp_ratio=mlp_ratio[i], # (2.0, 2.0, 2.0)
act_layer=act_layer, # 'Hardswish'
attn_act_layer=attn_act_layer, # 'Hardswish'
resolution=resolution, # (14,14)
use_conv=use_conv, # True
downsample=down_op if stage_stride == 2 else '', # 'subsample'
drop_path=drop_path_rate # 0.0
)]
stride *= stage_stride
resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
in_dim = embed_dim[i]
self.stages = nn.Sequential(*stages)
LevitStage的定义如下,可以看到由一个downsample和blocks组成。先来看downsample部分,首先是num_heads=in_dim // key_dim=128//16=8对应【Downsampling】部分提到的“将attention heads的数量设置为 \(C/D\)”。然后attn_ratio=4对应【smaller keys】部分说的“降采样层,将 \(V\)的维度设置为4D”。
class LevitStage(nn.Module):
def __init__(
self,
in_dim,
out_dim,
key_dim,
depth=4,
num_heads=8,
attn_ratio=4.0,
mlp_ratio=4.0,
act_layer=nn.SiLU,
attn_act_layer=None,
resolution=14,
downsample='',
use_conv=False,
drop_path=0.,
):
super().__init__()
resolution = to_2tuple(resolution)
if downsample:
self.downsample = LevitDownsample(
in_dim, # 128
out_dim, # 256
key_dim=key_dim, # 16
num_heads=in_dim // key_dim, # 128//16=8, 这里就是C/D
attn_ratio=4., # 这里对应的是"we set the dimension of V to 4D to prevent loss of information."
mlp_ratio=2., # Reducing the MLP blocks. we reduce the expansion factor of the convolution from 4 to 2.
act_layer=act_layer,
attn_act_layer=attn_act_layer,
resolution=resolution,
use_conv=use_conv,
drop_path=drop_path,
)
resolution = [(r - 1) // 2 + 1 for r in resolution]
else:
assert in_dim == out_dim
self.downsample = nn.Identity()
blocks = []
for _ in range(depth):
blocks += [LevitBlock(
out_dim,
key_dim,
num_heads=num_heads,
attn_ratio=attn_ratio, # 2, 这里对应的是"V will have 2D channels"
mlp_ratio=mlp_ratio,
act_layer=act_layer,
attn_act_layer=attn_act_layer,
resolution=resolution,
use_conv=use_conv,
drop_path=drop_path,
)]
self.blocks = nn.Sequential(*blocks)
def forward(self, x):
x = self.downsample(x)
x = self.blocks(x)
return x
下面是降采样的具体实现,其中通过attention层来进行降采样上面也提过,对应图5右侧。在forward函数中可以看到attn_downsample部分没有进行residual connection,只在后面的mlp部分进行了残差连接。
class LevitDownsample(nn.Module):
def __init__(
self,
in_dim,
out_dim,
key_dim,
num_heads=8,
attn_ratio=4.,
mlp_ratio=2.,
act_layer=nn.SiLU,
attn_act_layer=None,
resolution=14,
use_conv=False,
use_pool=False,
drop_path=0.,
):
super().__init__()
attn_act_layer = attn_act_layer or act_layer
self.attn_downsample = AttentionDownsample(
in_dim=in_dim,
out_dim=out_dim,
key_dim=key_dim,
num_heads=num_heads,
attn_ratio=attn_ratio,
act_layer=attn_act_layer,
resolution=resolution,
use_conv=use_conv,
use_pool=use_pool,
)
self.mlp = LevitMlp(
out_dim,
int(out_dim * mlp_ratio),
use_conv=use_conv,
act_layer=act_layer
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x): # (1,128,14,14)
x = self.attn_downsample(x) # (1,256,7,7)
x = x + self.drop_path(self.mlp(x)) # (1,256,7,7)
# 文中说 "Due to the change in scale, this attention block is used without a residual connection",这里是经过attn_downsample
# 之后的x与mlp的输出做residual,而不是一开始输入x与mlp的输出做residual
return x
接下来我们再看AttentionDownsample的实现,首先key_dim=16,输入通道数为128,则有128/16=8个head,而 \(V\) 的维度是key的4倍即16x4=64。下面的代码进行了详细的注释,可以对照图5右侧看,需要提一下图5右侧中attention bias的维度是 N x (HW x HW),实际上应该是 N x (HW/4, HW)。下面的具体实现就是普通的attention,其中有两点区别,一个是对 \(Q\) 进行了降采样,这里也要提一下,原文说的是"subsampling"而不是"downsampling",代码中的降采样是一个kernel_size=1,stride=2的池化层,因为核大小为1所以实际上是每隔stride取一个值,而不是像通常的池化那样取 (k, k)中的均值或最大值。另一点就是位置编码用attention bias表示,并直接与attention map相加。
class AttentionDownsample(nn.Module):
attention_bias_cache: Dict[str, torch.Tensor]
def __init__(
self,
in_dim,
out_dim,
key_dim,
num_heads=8,
attn_ratio=2.0,
stride=2,
resolution=14,
use_conv=False,
use_pool=False,
act_layer=nn.SiLU,
):
super().__init__()
resolution = to_2tuple(resolution)
self.stride = stride # 2
self.resolution = resolution
self.num_heads = num_heads # 8
self.key_dim = key_dim # 16
self.key_attn_dim = key_dim * num_heads # 16x8=128
self.val_dim = int(attn_ratio * key_dim) # 4 * 16 = 64, "For downsampling layers, ..., we set the dimension of V to 4D"
self.val_attn_dim = self.val_dim * self.num_heads # 64x8=512
self.scale = key_dim ** -0.5
self.use_conv = use_conv
if self.use_conv:
ln_layer = ConvNorm # 用1x1卷积代替FC
sub_layer = partial(
nn.AvgPool2d,
kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
else:
ln_layer = LinearNorm
sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool)
self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim) # 128, 512+128=640
self.q = nn.Sequential(OrderedDict([
('down', sub_layer(stride=stride)), # a subsampling is applied before the Q transformation
# 注意这里AvgPool2d(kernel_size=1, stride=2, padding=0)中的kernel大小为1,相当于隔一个像素取一个值,而不是真正的平均池化
('ln', ln_layer(in_dim, self.key_attn_dim)) # 128,128
]))
self.proj = nn.Sequential(OrderedDict([
('act', act_layer()),
# "We apply a Hardswish activation to the product A^hV before the regular linear # projection is used to
# combine the output of the different heads"
('ln', ln_layer(self.val_attn_dim, out_dim))
]))
self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1])) # (8,196)
k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1) # (2,196)
# tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
# 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2,
# 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
# 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5,
# 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
# 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
# 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
# 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10,
# 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11,
# 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
# 12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3,
# 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7,
# 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1,
# 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5,
# 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
# 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
# 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3,
# 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7,
# 8, 9, 10, 11, 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
# 12, 13, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]])
q_pos = torch.stack(ndgrid(
torch.arange(0, resolution[0], step=stride),
torch.arange(0, resolution[1], step=stride)
)).flatten(1) # (2,49)
# tensor([[ 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4,
# 4, 4, 4, 6, 6, 6, 6, 6, 6, 6, 8, 8, 8, 8, 8, 8, 8, 10,
# 10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, 12],
# [ 0, 2, 4, 6, 8, 10, 12, 0, 2, 4, 6, 8, 10, 12, 0, 2, 4, 6,
# 8, 10, 12, 0, 2, 4, 6, 8, 10, 12, 0, 2, 4, 6, 8, 10, 12, 0,
# 2, 4, 6, 8, 10, 12, 0, 2, 4, 6, 8, 10, 12]])
rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs() # (2,49,1) - (2,1,196) -> (2,49,196)
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1] # (49,196),我他妈懂了,这里相当于将二维展平求两个点之间的距离,即y坐标的差表示差了几行,乘以每行的像素点数即宽即这里的resolution[1],然后再加上x坐标的差
self.register_buffer('attention_bias_idxs', rel_pos, persistent=False) # 通过register_buffer注册的张量不会被优化器更新。
# persistent控制缓冲区是否在模型的状态字典(state dictionary)中保存
self.attention_bias_cache = {} # per-device attention_biases cache
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.attention_bias_cache:
self.attention_bias_cache = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if torch.jit.is_tracing() or self.training:
# attention_biases是模型学习到的, 而attention_bias_idxs是两点之间的offset是存在缓存区的不会随网络进行更新,表示对于相对位置相同的pairs的bias是相同的,
# 比如(1,2)(3,4)与(11,2)(13,4)的offset都是(2,2),因此这两对的bias是相同的
# https://github.com/facebookresearch/LeViT/issues/9
return self.attention_biases[:, self.attention_bias_idxs] # (8,196),(49,196)
else:
device_key = str(device)
if device_key not in self.attention_bias_cache:
self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.attention_bias_cache[device_key]
def forward(self, x):
if self.use_conv:
B, C, H, W = x.shape # (1,128,14,14)
HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1 # 7,7
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
# key_dim=D=16, value_dim=4D=64, num_heads=8, (16+64)x8=640
# (1,640,14,14)->(1,8,80,196) -> (1,8,16,196), (1,8,64,196)
q = self.q(x).view(B, self.num_heads, self.key_dim, -1) # (1,128,7,7)->(1,8,16,49), q进行了下采样
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) # + (8,49,196)
# 这里q和k的seq_len不同,因为q进行了下采样,一个是7x7=49,一个是14x14=196。但dim一样都是16
# (1,8,49,16) @ (1,8,16,196) -> (1,8,49,196)
# 论文中图5b的attention bias的维度不对,不是(HWxHW),应该是(HW/4xHW)
attn = attn.softmax(dim=-1) # (1,8,49,196)
x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
# (1,8,64,196) @ (1,8,196,49) -> (1,8,64,49) -> (1,64x8,7,7)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
k = k.permute(0, 2, 3, 1) # BHCN
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
attn = q @ k * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
x = self.proj(x) # (1,256,7,7)
return x
关于attention bias详细解释一下,代码中的self.attention_biases是通过nn.Parameter定义的,是随着网络学习到的。而self.attention_bias_idxs是通过register_buffer注册到缓存区内的常量,在整个训练过程中保持不变。因此对q进行了降采样大小为 (7, 7),而k保持原始大小 (14, 14),attention_bias_idxs表示q的任意一点与k中任意一点的偏差,比如q中点 (1, 3) 与k中点 (11, 5)的offset为 (10, 2)。作者在文中提到到两点的位置偏差相同时,它们的bias也是相同的,比如 (3,5) 和 (13, 7) 的偏差也是 (10, 2),则这一对点之间的bias和上一对的bias值相等,具体的大小是网络学习到的。
对于特征图(7, 7)的q和(14, 14)的k,任意两点之间的offset的矩阵的维度应该是(2, 49, 196),其中2表示x坐标和y坐标,但这样就有很多重复的,因为上面提到只有两点的offset相同则bias也是相同的。在官方实现中是通过一个字典的key来保证offset的唯一性,如下所示
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
在timm的实现中,在得到了(2, 49, 196)的rel_pos后还有一行如下,这里是将两点之间的xy坐标的偏差转换为两点之间按照先行后列的像素距离,比如(1, 2)和(5, 6)之间的距离为(6-2) * 14+(5-1)=60,表示点(1,2)按照先行后列的顺序移动60个像素到达(5,6)位置处,这样rel_pos的维度就从(2, 49, 196)变成了(49, 196)。
rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
至此,AttentionDownsample就讲完了。而普通的attention和这里类似,区别一个是q不进行降采样,则attention_bias_idx的大小是(196, 196)。以及v的维度是2D而不是4D。还有就是因为没有进行降采样,在attention部分加上了residual connection。
实验结果
和其它模型的speed-accuracy tradeoff对比如表3所示,LeViT-384和DeiT-small的精度相当,但FLOPs只有后者的一半。LeViT-128和DeiT-tiny的精度相当,FLOPs只有后者的1/4。
和其它SOTA模型的对比
标签:dim,ICCV,10,self,attention,LeViT,2021,attn,12 From: https://blog.csdn.net/ooooocj/article/details/139482058