首页 > 其他分享 >DeepViT 论文与代码解析

DeepViT 论文与代码解析

时间:2024-08-03 16:25:52浏览次数:11  
标签:DeepViT map attention dim self 论文 attn 解析 注意力

paper:DeepViT: Towards Deeper Vision Transformer

official implementation:https://github.com/zhoudaquan/dvit_repo

出发点

尽管浅层ViTs在视觉任务中表现优异,但随着网络深度增加,性能提升变得困难。研究发现,这种性能饱和的主要原因是注意力崩溃问题,即在深层变压器中,attention map逐渐变得相似,导致feature map在顶层趋于一致,从而限制了模型的表示学习能力。本文旨在研究如何有效地加深ViT模型,并提出了一种新的自注意力机制Re-attention来解决这个问题。

创新点

  • 注意力崩溃问题的提出与分析:首次提出并深入分析了注意力崩溃问题,发现这是导致深层ViT模型性能饱和的主要原因。
  • Re-attention机制:提出了一种简单但有效的Re-attention机制,通过在不同注意力头之间交换信息,以增加不同层的注意力图的多样性。该方法在计算和内存开销上几乎可以忽略不计。
  • 性能提升:通过替换现有ViT模型中的多头自注意力(MHSA)模块,成功训练了具有32个Transformer block的深层ViT模型,在ImageNet上的Top-1分类准确率提高了1.6%。

方法介绍

由于deep CNNs的成功,作者也系统研究了随深度变化ViT性能的变化,其中hidden dimension和head数量分别固定为384和12,然后堆叠不同数量的Transformer block(从12到32),结果如图1所示,可以看到,随着模型深度的增加,分类精度提升缓慢,饱和速度较快,且达到24个block后,性能不再有提升。

之前在CNN中也存在这个问题,但随着残差连接的提出,该问题得到了解决。而ViT和CNN的最大区别就在于self-attention机制,因此作者研究了自注意力或者更具体的说是生成的attention map随着网络深度的增加是如何变化的。作者计算了不同层的attention map之间的相似性来衡量注意力图的变化,如下

 

其中 \(M^{p,q}\) 是 \(p\) 层和 \(q\) 层注意力图之间的余弦相似度矩阵,每个元素 \(M^{p,q}_{h,t}\) 表示head \(h\) 和 token \(t\) 的相似度。

根据式(2),作者在ImageNet上训练了一个包含32个block的ViT,并研究了attention map之间的相似度,结果如图3(a)所示,可以看到,在第17个block之后,注意力图之间的相似度超过了比例超过了90%。这表示后面学习到的attention map是相似度,Transformer block可能退化为一个MLP。

为了理解attention collapse是如何影响ViT的性能的,作者进一步研究了它是如何影响更深层网络的特征学习的。因此作者也绘制出了随网络深度变化feature map之间的相似度变化曲线,如图4(left)所示,可以看到feature map的变化曲线和attention map的变化曲线比较相似,这一结果表明,注意力崩溃是导致ViT模型non-scalable的原因。

 

Re-Attention

在实验过程中,作者发现来自同一block不同head之间的attention map的相似度很小,如图3(c)所示。这表明来自同一自注意力层的不同head关注输入token的不同方面。基于此观察,作者提出建立cross-head通信来重新生成attention map。

具体来说,通过动态地聚合来自不同head的注意力图来生成一组新的注意力图。作者定义了一个可学习的变换矩阵 \(\Theta \in\mathbb{R}^{H\times H}\) 并用它来混合不同head的注意力图,具体如下

其中 \(\Theta\) 和注意力图 \(\mathbf{A}\) 沿head维度相乘,Norm是归一化函数用来减少层之间的方差,\(\Theta\) 是端到端可学习的。

实验结果

如图1所示,在将ViT中的self-attention换成Re-Attention后得到的DeepViT,随着网络深度的增加并没有像ViT那样过早的出现性能饱和,而是继续提升。

如图8(a)所示,Re-Attention的相邻block注意力图的相似度显著降低。

作者定义了DeepViT-S和DeepViT-L,具体配置如下,其中split ratio表示不用Re-Attention和使用Re-Attention的block数的比例,如图3(a)所示,只有在网络的深层注意力图和特征图之间的相似度才会变高,因此没必要在所有层的block中都使用Re-Attention。 

和其它SOTA模型在ImageNet上的性能对比如下所示

 

代码解析

Re-Attention的实现如下,其中 \(\Theta\) 是通过卷积定义的,归一化采用的BN。

class ReAttention(nn.Module):
    """
    It is observed that similarity along same batch of data is extremely large. 
    Thus can reduce the bs dimension when calculating the attention map.
    """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., expansion_ratio=3,
                 apply_transform=True, transform_scale=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.apply_transform = apply_transform
        
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5
        if apply_transform:
            self.reatten_matrix = nn.Conv2d(self.num_heads, self.num_heads, 1, 1)
            self.var_norm = nn.BatchNorm2d(self.num_heads)
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
            self.reatten_scale = self.scale if transform_scale else 1.0
        else:
            self.qkv = nn.Linear(dim, dim * expansion_ratio, bias=qkv_bias)
        
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, atten=None):
        B, N, C = x.shape
        # x = self.fc(x)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        if self.apply_transform:
            attn = self.var_norm(self.reatten_matrix(attn)) * self.reatten_scale
        attn_next = attn
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_next

标签:DeepViT,map,attention,dim,self,论文,attn,解析,注意力
From: https://blog.csdn.net/ooooocj/article/details/140867020

相关文章

  • Android开发 - (适配器)ArrayObjectAdapter类与Presenter实现类关联的作用解析
    ListRowPresenterArrayObjectAdapteradapter=newArrayObjectAdapter(newListRowPresenter());用途:用于展示ListRow中的水平滚动列表项ImageCardViewPresenterArrayObjectAdapteradapter=newArrayObjectAdapter(newImageCardViewPresenter());用途:用于显示带......
  • Python中15个递归函数经典案例解析
    1.阶乘计算阶乘是一个常见的递归应用,定义为n!=n*(n-1)*…*1。deffactorial(n):ifn==0:return1else:returnn*factorial(n-1)print(factorial(5))#输出:1202.斐波那契数列斐波那契数列的每一项都......
  • SSM大学生兼职推荐系统4ozlb 本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表系统内容:企业,学生,企业招聘,应聘信息,录取通知,工作评价,防骗指南开题报告内容一、课题背景与意义随着大学生群体对兼职需求的日益增长,传统的兼职信息获取方......
  • 招商银行笔试题目答案解析
    招商银行笔试题目:笔试题库:1.测试流程,静态测试和动态测试的区别?静态测试和动态测试的区别:是否执行代码,执行代码是动态,不执行是静态人工检查2.http协议的标识符有哪些?什么含义,TCP连接两台设备间通过一一连接,TCP报文头部相关的,结构,长度,2层循环计算题,除法取余数运算?UR......
  • ssm+vue服装店管理系统【开题+程序+论文】-计算机毕业设计
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着现代商业的快速发展,传统服装店的管理方式面临着前所未有的挑战。传统的手工记录和简单的电子表格已难以满足日益增长的库存控制、商品追踪及顾客......
  • ssm+vue电影推荐系统【开题+程序+论文】-计算机毕业设计
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着互联网技术的飞速发展,电影作为一种重要的文化娱乐形式,其传播与消费方式发生了深刻变革。在线视频平台如雨后春笋般涌现,为用户提供了海量的电影资......
  • ssm+vue的校园后台报修管理系统设计与实现【开题+程序+论文】-计算机毕业设计
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着教育信息化的不断推进,校园内各类设施设备的数量与复杂度日益增长,如何高效管理这些设备的维护与报修工作成为了学校管理的一大挑战。传统的报修方......
  • ssm+vue高校家教平台【开题+程序+论文】-计算机毕业设计
    本系统(程序+源码)带文档lw万字以上文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着教育行业的蓬勃发展与信息技术的飞速进步,高校家教平台作为连接学生与家长、教师之间的桥梁,其重要性日益凸显。当前,高校学生在学业上寻求额外辅导......
  • Android开发 - ListRow类解析
    ListRow是什么ListRow是AndroidTV开发中的一个类,用于在应用的用户界面中显示水平滚动的项(如卡片、图像等)列表。它通常在一个BrowseFragment或RowsFragment中使用,以组织和显示内容//创建一个BrowseFragment实例BrowseFragmentbrowseFragment=newBrowseFragment......
  • Android开发 - Presenter抽象类解析
    Presenter是什么职责:Presenter的主要职责是管理视图(通常是用户界面组件)的显示和行为它不处理数据的逻辑,而是专注于如何展示数据在Leanback库中的作用:Leanback库是为AndroidTV设计的一个库,提供了一些特殊的UI组件,比如BrowseFragment。Presenter在L......