首页 > 其他分享 >LeViT(ICCV 2021)原理与代码解析

LeViT(ICCV 2021)原理与代码解析

时间:2024-06-05 21:28:59浏览次数:12  
标签:dim ICCV 10 self attention LeViT 2021 attn 12

paper:LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference

official implementation:https://github.com/facebookresearch/LeViT

third-party implementation:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/levit.py

本文的创新点

本文旨在设计一种新的图像分类架构,通过结合卷积网络的优势和转换器的优点,优化高效推理时的精度和速度权衡。具体如下

  1. 多阶段Transformer架构:提出了一种多阶段Transformer架构,使用注意力机制进行下采样。这种设计类似于传统卷积网络中的金字塔结构,使得特征图的分辨率逐步降低,提高了计算效率。
  2. 高效的patch descriptor:在模型的前几层中引入了计算效率高的Patch描述符,减少了特征数量,从而提升了网络的整体计算效率。
  3. 注意力偏置:引入了一种新的attention bias,代替了视觉转换器中的位置编码(Positional Encoding),实现了平移不变的空间信息编码,提升了模型的表现。
  4. 重新设计的Attention-MLP block:重新设计了注意力-MLP模块,提高了网络在给定计算时间内的容量,使得网络可以在相同的计算资源下获得更好的表现。
  5. 混合架构设计:通过结合卷积和Transformer的混合架构(grafting experiments),在相同的计算budget下实现了更好的精度和速度权衡。这种混合架构在训练初期表现出与卷积网络相似的快速收敛特性,同时在后期表现出Transformer的高精度。

总体来说,本文通过引入这些创新,提出了一种名为LeViT的混合神经网络,在ImageNet数据集上的实验结果表明,该模型在推理速度和精度上显著优于现有的卷积网络和视觉Transformer。

方法介绍

Vision Transformer中的patch projection层通过16x16 stride=16的卷积实现,引发了作者对卷积与Transformer之间联系的思考。在卷积中,mask的空间平滑性来自于卷积过程中卷积核的重叠:临近的像素接收到相似的梯度。而在ViT中平滑掩膜可能是由于数据增强造成的,当一个图像出现两次且发生微小的平移时,相同的梯度经过每个filter,所以它可以学习这种空间平滑性。因此尽管在Transformer架构中没有归纳偏置"inductive bias",训练确实产生了类似卷积层的filter。

作者首先用ResNet-50和DeiT-S进行了一个嫁接实验,结果如表1所示。

可以看到嫁接的结构比单独的ResNet-50和DeiT-S的效果都要好,其中精度最高同时参数量最小的组合是与两个stage的ResNet-50进行嫁接。

 

一个有趣的观察如图3所示,嫁接模型在训练早期的收敛性和卷积网络类似,然后切换到类似于DeiT-S的收敛速度。一种假设是,卷积由于其本身的inductive bias能力(平移不变性),使其可以在网络的浅层更有效的学习low-level information,它们快速找到有意义的patch embedding,这可以解释为什么在第一个epoch可以快速收敛。

基于上述观察, 作者认为在transformer下面插入卷积stage是有益的,大部分的处理仍然是在后续堆叠的transformer block中实现的,以获得嫁接结构精度最高的变体。因此接下来作者重点研究了如何降低transformer的计算成本,以及如何与卷积更紧密地结合而不仅仅是嫁接起来。

LeViT的完整结构如图4所示。具体的设计原则如下

Patch embedding

在LeViT中,作者采用4层stride=2的3x3卷积进行分辨率的下采样。对于(3, 224, 224)的输入,经过4层卷积后得到维度为(256, 14, 14)的输出进入接下来的transformer block中。

No classification token

为了使用BCHW的张量形式,作者删除了分类token,而是和卷积网络一样,在最后一个特征图上用全局平均池化来得到分类器用的embedding。对于蒸馏,分别训练不同的head进行分类和蒸馏任务。测试时,取这两个head输出的平均值。

Normalization layera and activations

ViT中的FC层等价于1x1卷积。ViT在每个attention层和mlp前都是用了LN。而对于LeViT,每个卷积后都加一个BN。DeiT的激活函数使用了GELU,而在LeViT中所有激活函数都采用Hardswish。

Multi-resolution pyramid

因为LeViT前面嫁接了ResNet的部分stage,因此形成了和卷积网络一样的金字塔结构,特征图的分辨率随着通道数的增加而降低。

下面是对attention block进行的一些修改,如图5所示。

Downsampling

在LeViT的stage之间, 通过一个shrinking attention block来减小激活图的大小,如图5右侧所示。具体来说,在Q变换之前进行一个subsampling,这将大小为 \((C,H,W)\) 的tensor映射为大小为 \((C',H/2,W/2)\) 的输出tensor,其中 \(C'>C\)。由于大小的变化,这个attention block没有使用residual connection。为了防止信息的损失,将attention heads的数量设置为 \(C/D\)。

Attention bias instead of a position embedding

之前的位置编码只包含在attention block的输入序列中,由于位置编码对于更高的层也很重要,所以作者的目标是在每个attention block中都提供位置信息,并显式地在注意力机制中注入相对位置信息:具体是通过在attention map上加上一个attention bias来实现的。两个像素 \((x,y)\in[H]\times [W]\) 和 \((x',y')\in[H]\times [W]\) 之间一个head \(h\in [N]\) 的标量attention value按下式计算

其中第一项就是普通的attention,第二项是平移不变的attention bias,每个head都有对应于不同像素offset的 \(H\times W\) 个参数。取绝对值是鼓励网络以flip invariance的方式进行训练。

Smaller keys

bias项减少了key对位置信息编码的压力,所以相比于 \(V\) 我们减小了key矩阵的大小。如果key的维度 \(D\in\{16,32\}\),\(V\) 的通道数为 \(2D\)。限制key的大小减少了计算 \(QK^T\) 的时间。

对于降采样层,其中没有residual connection,我们将 \(V\) 的维度设置为 \(4D\) 来防止信息损失。

Attention activation

在使用线性映射来组合不同head的输出之前,我们对乘积 \(A^hV\) 应用一个Hardswish。

Reducing the MLP blocks

通常ViT的MLP隐藏层维度的expansion ratio设置为4,对于LeViT,MLP由一个1x1卷积和一个BN组成,为了降低计算量,我们将expansion factor由4降低为2。


至此LeViT中的所有改进都介绍完了,在不同的计算量限制下,LeViT有一系列不同的变体,本文通过输入到第一个transformer block的通道数来进行区分,比如LeViT-256表示第一个transformer block的输入通道数为256。表2展示了不同LeViT变体的具体配置。

代码解析

这里以timm中的实现为例介绍一下代码。选择的模型是'levit_conv_128s',具体配置如下,可以看到和表2的第一列是对应的。 

levit_128s=dict(
        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),  # 对应表2,key_dim就是D

首先是stem,如上所述就是4个3x3-s2的卷积,且每个卷积后都跟一个激活函数,上面也讲过,本文所有的激活函数都采用Hardswish

class Stem16(nn.Sequential):
    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.stride = 16

        self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1))
        self.add_module('act1', act_layer())
        self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1))
        self.add_module('act2', act_layer())
        self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
        self.add_module('act3', act_layer())
        self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))

然后就是transformer stage部分,定义如下。其中embed_dim=(128, 256, 384)表示stem的输出即transformer stage的输入的通道数为128。key_dim=16就是论文中的 \(D\)。attn_ratio都是2对应【smaller keys】部分提到的 “V的通道数为2D”。mlp_ratio也都是2对应【Reducing the MLP blocks】部分提到的将隐藏层的expansion ratio由4改为2。

in_dim = embed_dim[0]
stages = []
for i in range(num_stages):
    stage_stride = 2 if i > 0 else 1
    stages += [LevitStage(
        in_dim,
        embed_dim[i],
        key_dim,  # 16
        depth=depth[i],  # (2,3,4)
        num_heads=num_heads[i],  # (4,6,8)
        attn_ratio=attn_ratio[i],  # (2.0, 2.0, 2.0)
        mlp_ratio=mlp_ratio[i],  # (2.0, 2.0, 2.0)
        act_layer=act_layer,  # 'Hardswish'
        attn_act_layer=attn_act_layer,  # 'Hardswish'
        resolution=resolution,  # (14,14)
        use_conv=use_conv,  # True
        downsample=down_op if stage_stride == 2 else '',  # 'subsample'
        drop_path=drop_path_rate  # 0.0
    )]
    stride *= stage_stride
    resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
    self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
    in_dim = embed_dim[i]
self.stages = nn.Sequential(*stages)

LevitStage的定义如下,可以看到由一个downsample和blocks组成。先来看downsample部分,首先是num_heads=in_dim // key_dim=128//16=8对应【Downsampling】部分提到的“将attention heads的数量设置为 \(C/D\)”。然后attn_ratio=4对应【smaller keys】部分说的“降采样层,将 \(V\)的维度设置为4D”。

class LevitStage(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            depth=4,
            num_heads=8,
            attn_ratio=4.0,
            mlp_ratio=4.0,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            downsample='',
            use_conv=False,
            drop_path=0.,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        if downsample:
            self.downsample = LevitDownsample(
                in_dim,  # 128
                out_dim,  # 256
                key_dim=key_dim,  # 16
                num_heads=in_dim // key_dim,  # 128//16=8, 这里就是C/D
                attn_ratio=4.,  # 这里对应的是"we set the dimension of V to 4D to prevent loss of information."
                mlp_ratio=2.,  # Reducing the MLP blocks. we reduce the expansion factor of the convolution from 4 to 2.
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )
            resolution = [(r - 1) // 2 + 1 for r in resolution]
        else:
            assert in_dim == out_dim
            self.downsample = nn.Identity()

        blocks = []
        for _ in range(depth):
            blocks += [LevitBlock(
                out_dim,
                key_dim,
                num_heads=num_heads,
                attn_ratio=attn_ratio,  # 2, 这里对应的是"V will have 2D channels"
                mlp_ratio=mlp_ratio,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x

下面是降采样的具体实现,其中通过attention层来进行降采样上面也提过,对应图5右侧。在forward函数中可以看到attn_downsample部分没有进行residual connection,只在后面的mlp部分进行了残差连接。

class LevitDownsample(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            mlp_ratio=2.,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            use_conv=False,
            use_pool=False,
            drop_path=0.,
    ):
        super().__init__()
        attn_act_layer = attn_act_layer or act_layer

        self.attn_downsample = AttentionDownsample(
            in_dim=in_dim,
            out_dim=out_dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            act_layer=attn_act_layer,
            resolution=resolution,
            use_conv=use_conv,
            use_pool=use_pool,
        )

        self.mlp = LevitMlp(
            out_dim,
            int(out_dim * mlp_ratio),
            use_conv=use_conv,
            act_layer=act_layer
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):  # (1,128,14,14)
        x = self.attn_downsample(x)  # (1,256,7,7)
        x = x + self.drop_path(self.mlp(x))  # (1,256,7,7)
        # 文中说 "Due to the change in scale, this attention block is used without a residual connection",这里是经过attn_downsample
        # 之后的x与mlp的输出做residual,而不是一开始输入x与mlp的输出做residual
        return x

接下来我们再看AttentionDownsample的实现,首先key_dim=16,输入通道数为128,则有128/16=8个head,而 \(V\) 的维度是key的4倍即16x4=64。下面的代码进行了详细的注释,可以对照图5右侧看,需要提一下图5右侧中attention bias的维度是 N x (HW x HW),实际上应该是 N x (HW/4, HW)。下面的具体实现就是普通的attention,其中有两点区别,一个是对 \(Q\) 进行了降采样,这里也要提一下,原文说的是"subsampling"而不是"downsampling",代码中的降采样是一个kernel_size=1,stride=2的池化层,因为核大小为1所以实际上是每隔stride取一个值,而不是像通常的池化那样取 (k, k)中的均值或最大值。另一点就是位置编码用attention bias表示,并直接与attention map相加。

class AttentionDownsample(nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=2.0,
            stride=2,
            resolution=14,
            use_conv=False,
            use_pool=False,
            act_layer=nn.SiLU,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        self.stride = stride  # 2
        self.resolution = resolution
        self.num_heads = num_heads  # 8
        self.key_dim = key_dim  # 16
        self.key_attn_dim = key_dim * num_heads  # 16x8=128
        self.val_dim = int(attn_ratio * key_dim)  # 4 * 16 = 64, "For downsampling layers, ..., we set the dimension of V to 4D"
        self.val_attn_dim = self.val_dim * self.num_heads  # 64x8=512
        self.scale = key_dim ** -0.5
        self.use_conv = use_conv

        if self.use_conv:
            ln_layer = ConvNorm  # 用1x1卷积代替FC
            sub_layer = partial(
                nn.AvgPool2d,
                kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
        else:
            ln_layer = LinearNorm
            sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool)

        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim)  # 128, 512+128=640
        self.q = nn.Sequential(OrderedDict([
            ('down', sub_layer(stride=stride)),  # a subsampling is applied before the Q transformation
            # 注意这里AvgPool2d(kernel_size=1, stride=2, padding=0)中的kernel大小为1,相当于隔一个像素取一个值,而不是真正的平均池化
            ('ln', ln_layer(in_dim, self.key_attn_dim))  # 128,128
        ]))
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            # "We apply a Hardswish activation to the product A^hV before the regular linear # projection is used to
            # combine the output of the different heads"
            ('ln', ln_layer(self.val_attn_dim, out_dim))
        ]))

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))  # (8,196)
        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)  # (2,196)
        # tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,
        #           1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,
        #           2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,
        #           3,  3,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,
        #           5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,
        #           6,  6,  6,  6,  6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
        #           7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
        #           9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10,
        #          10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11,
        #          11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
        #          12, 12, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13],
        #         [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,
        #           4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,
        #           8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
        #          12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,
        #           2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,
        #           6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9,
        #          10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,
        #           0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,
        #           4,  5,  6,  7,  8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,
        #           8,  9, 10, 11, 12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11,
        #          12, 13,  0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13]])
        q_pos = torch.stack(ndgrid(
            torch.arange(0, resolution[0], step=stride),
            torch.arange(0, resolution[1], step=stride)
        )).flatten(1)  # (2,49)
        # tensor([[ 0,  0,  0,  0,  0,  0,  0,  2,  2,  2,  2,  2,  2,  2,  4,  4,  4,  4,
        #           4,  4,  4,  6,  6,  6,  6,  6,  6,  6,  8,  8,  8,  8,  8,  8,  8, 10,
        #          10, 10, 10, 10, 10, 10, 12, 12, 12, 12, 12, 12, 12],
        #         [ 0,  2,  4,  6,  8, 10, 12,  0,  2,  4,  6,  8, 10, 12,  0,  2,  4,  6,
        #           8, 10, 12,  0,  2,  4,  6,  8, 10, 12,  0,  2,  4,  6,  8, 10, 12,  0,
        #           2,  4,  6,  8, 10, 12,  0,  2,  4,  6,  8, 10, 12]])
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()  # (2,49,1) - (2,1,196) -> (2,49,196)
        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]  # (49,196),我他妈懂了,这里相当于将二维展平求两个点之间的距离,即y坐标的差表示差了几行,乘以每行的像素点数即宽即这里的resolution[1],然后再加上x坐标的差
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)  # 通过register_buffer注册的张量不会被优化器更新。
        # persistent控制缓冲区是否在模型的状态字典(state dictionary)中保存

        self.attention_bias_cache = {}  # per-device attention_biases cache

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            # attention_biases是模型学习到的, 而attention_bias_idxs是两点之间的offset是存在缓存区的不会随网络进行更新,表示对于相对位置相同的pairs的bias是相同的,
            # 比如(1,2)(3,4)与(11,2)(13,4)的offset都是(2,2),因此这两对的bias是相同的
            # https://github.com/facebookresearch/LeViT/issues/9
            return self.attention_biases[:, self.attention_bias_idxs]  # (8,196),(49,196)
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):
        if self.use_conv:
            B, C, H, W = x.shape  # (1,128,14,14)
            HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1  # 7,7
            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
            # key_dim=D=16, value_dim=4D=64, num_heads=8, (16+64)x8=640
            # (1,640,14,14)->(1,8,80,196) -> (1,8,16,196), (1,8,64,196)
            q = self.q(x).view(B, self.num_heads, self.key_dim, -1)  # (1,128,7,7)->(1,8,16,49), q进行了下采样

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)  # + (8,49,196)
            # 这里q和k的seq_len不同,因为q进行了下采样,一个是7x7=49,一个是14x14=196。但dim一样都是16
            # (1,8,49,16) @ (1,8,16,196) -> (1,8,49,196)
            # 论文中图5b的attention bias的维度不对,不是(HWxHW),应该是(HW/4xHW)
            attn = attn.softmax(dim=-1)  # (1,8,49,196)

            x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
            # (1,8,64,196) @ (1,8,196,49) -> (1,8,64,49) -> (1,64x8,7,7)
        else:
            B, N, C = x.shape
            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
            k = k.permute(0, 2, 3, 1)  # BHCN
            v = v.permute(0, 2, 1, 3)  # BHNC
            q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
        x = self.proj(x)  # (1,256,7,7)
        return x

关于attention bias详细解释一下,代码中的self.attention_biases是通过nn.Parameter定义的,是随着网络学习到的。而self.attention_bias_idxs是通过register_buffer注册到缓存区内的常量,在整个训练过程中保持不变。因此对q进行了降采样大小为 (7, 7),而k保持原始大小 (14, 14),attention_bias_idxs表示q的任意一点与k中任意一点的偏差,比如q中点 (1, 3) 与k中点 (11, 5)的offset为 (10, 2)。作者在文中提到到两点的位置偏差相同时,它们的bias也是相同的,比如 (3,5) 和 (13, 7) 的偏差也是 (10, 2),则这一对点之间的bias和上一对的bias值相等,具体的大小是网络学习到的。

对于特征图(7, 7)的q和(14, 14)的k,任意两点之间的offset的矩阵的维度应该是(2, 49, 196),其中2表示x坐标和y坐标,但这样就有很多重复的,因为上面提到只有两点的offset相同则bias也是相同的。在官方实现中是通过一个字典的key来保证offset的唯一性,如下所示

attention_offsets = {}
idxs = []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

在timm的实现中,在得到了(2, 49, 196)的rel_pos后还有一行如下,这里是将两点之间的xy坐标的偏差转换为两点之间按照先行后列的像素距离,比如(1, 2)和(5, 6)之间的距离为(6-2) * 14+(5-1)=60,表示点(1,2)按照先行后列的顺序移动60个像素到达(5,6)位置处,这样rel_pos的维度就从(2, 49, 196)变成了(49, 196)。

rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]

至此,AttentionDownsample就讲完了。而普通的attention和这里类似,区别一个是q不进行降采样,则attention_bias_idx的大小是(196, 196)。以及v的维度是2D而不是4D。还有就是因为没有进行降采样,在attention部分加上了residual connection。

实验结果

和其它模型的speed-accuracy tradeoff对比如表3所示,LeViT-384和DeiT-small的精度相当,但FLOPs只有后者的一半。LeViT-128和DeiT-tiny的精度相当,FLOPs只有后者的1/4。

和其它SOTA模型的对比

 

 

标签:dim,ICCV,10,self,attention,LeViT,2021,attn,12
From: https://blog.csdn.net/ooooocj/article/details/139482058

相关文章

  • 打卡信奥刷题(52)用Scratch图形化工具信奥P7909 [普及组] [CSP-J 2021] 分糖果
    [CSP-J2021]分糖果题目背景红太阳幼儿园的小朋友们开始分糖果啦!题目描述红太阳幼儿园有nnn个小朋友,你是其中之一。保证......
  • P8125 [BalticOI 2021 Day2] The short shank 题解
    首先会发现若\(t_i<=T\)的话那么他最终一定会造反。我们只考虑\(t_i>T\)的情况。设\(lst_i\)表示\(i\)左边第一个可以影响(使他造反)到\(i\)的位置,那么我们一定要在\([lst_i,i]\)这个区间中的某一个位置放上床垫才能使\(i\)不造反。这样有一个\(O(nd)\)的dp,但......
  • 我见我思之hvv偷师学艺——Vmware vcenter未授权任意文件上传(CVE-2021-21972)
    本文为个人整理内容,大部分东西都是参考其它师傅的文章,具体如下:https://blog.csdn.net/qq_37602797/article/details/114109428https://blog.csdn.net/tigerman20201/article/details/129098137常见告警特征:漏洞类型:文件上传。poc利用接口为:/ui/vropspluginui/rest/servic......
  • 2021新书Python程序设计 人工智能案例实践 Python编程人工智能基本描述统计集中趋势和
    书:pan.baidu.com/s/1owku2NBxL7GdW59zEi20AA?pwd=suov​提取码:suov我的阅读笔记:图像识别:使用深度学习框架(如TensorFlow、PyTorch)创建图像分类模型。探索迁移学习,使用预训练模型进行定制。自然语言处理(NLP):构建一个情感分析模型,用于分析文本中的情感。实现一个文本生成模型,......
  • 2024年计算机视觉、设计与算法国际会议( ICCVDA 2024)
    2024年计算机视觉、设计与算法国际会议( ICCVDA2024)会议简介本次大会旨在建立一个国际性的学术交流和合作平台,重点关注计算机视觉领域的最新进展、设计与算法的创新应用,分享前沿研究成果,并探索未来发展趋势。我们诚挚邀请全球各地的学者、专家、企业代表及感兴趣的个人积......
  • 20211215-sdf测试2-openssl
    以下是按照Markdown格式整理的你所需要的代码和操作过程,使用中文描述:任务详情在openEuler(推荐)、Ubuntu或Windows(不推荐)中完成以下任务。参考网内容以及AI给出的详细过程,否则不得分。0.根据gmt0018标准,如何调用接口实现基于SM3求你的学号姓名的SM3值?#include"sd......
  • 2021 NOIP
    廊桥分配1.错误想法:让当前飞机停到右端点最小的廊桥,但是当两个区间右端点都小于当前飞机左端点,选择最小的么?显然不是,应该选择序号最小的廊桥,这样不影响下一个飞机继续放置(左端点从小打到排序的)。这样,当只能有i个廊桥(枚举国内廊桥)的时候,也是可以取得最大值的。最后前缀和。错误......
  • 20211317李卓桐 Exp8 Web安全 实验报告
    Exp8Web安全实验报告实践内容(1)Web前端HTMLWeb前端HTML(2)Web前端javascipt理解JavaScript的基本功能,理解DOM。在(1)的基础上,编写JavaScript验证用户名、密码的规则。在用户点击登陆按钮后回显“欢迎+输入的用户名”尝试注入攻击:利用回显用户名注入HTML及JavaScript。(3......
  • 基于稀疏辅助信号平滑的心电信号降噪方法(Matlab R2021B)
    基于形态成分分析理论(MCA)的稀疏辅助信号分解方法是由信号的形态多样性来分解信号中添加性的混合信号成分,它最早被应用在图像处理领域,后来被引入到一维信号的处理中。在基于MCA稀疏辅助的信号分析模型中,总变差方法TV是其中一个原型,稀疏辅助平滑方法结合并统一了传统的LTI低通滤......
  • HumanEval (2021年)
    HumanEval:Hand-WrittenEvaluationSetHumanEval是一个OpenAI在2021年构造的代码生成LLM评估数据集。数据格式所有数据放在一个json文件中,每条数据包含提示词,有效代码示例,多个测试代码。下面是截取的第一条数据{"task_id":"HumanEval/0","prompt":"fromtypingimport......