首页 > 其他分享 >ResT(NeurIPS 2021)论文解读

ResT(NeurIPS 2021)论文解读

时间:2024-07-26 19:24:38浏览次数:14  
标签:Transformer 卷积 ResT times token 2021 NeurIPS 输入

paper:ResT: An Efficient Transformer for Visual Recognition

official implementation:https://github.com/wofmanaf/ResT

存在的问题

作者指出现有Transformer在视觉识别任务中存在以下几个问题:

  1. 低级特征提取困难:Transformer直接对原始输入图像的patch进行标记化处理,这使得它难以提取构成图像基础结构的低级特征,例如角落和边缘。
  2. 内存和计算资源的二次方增长:Transformer block中的多头自注意力(MSA)在空间或嵌入维度(即通道数)上的内存和计算需求随着输入标记的增加而呈二次方增长,导致训练和推理过程中的大量开销。
  3. 多头注意力(MHSA)的性能受限:MHSA中每个头只负责部分嵌入维度,这可能会损害网络的性能,特别是当每个头的标记嵌入维度较小时,query和key的点积可能无法构成一个信息丰富的函数。
  4. 固定尺度的输入token和位置编码:现有Transformer骨干网络中的输入token和位置编码都是固定尺度的,这不适合需要密集预测的视觉任务,这些任务通常需要多尺度的特征图表示。

创新点

为了解决上述问题,作者提出了以下改进:

  1. 作者在ResT中设计了一种重叠卷积操作的patch embedding模块,替代了传统的tokenization方法。这种嵌入方法可以更有效地捕捉低级特征信息(如边缘和角点),提高了特征提取的能力。
  2. 作者提出了内存高效的多头自注意力(EMSA)模块,通过简单的深度卷积操作压缩空间维度,从而减少计算成本。
  3. 作者在EMSA中引入了跨头维度的交互投影,从而保持多头注意力的多样性和信息丰富性。通过在注意力head之间投射交互,EMSA能够使每个head依赖于所有的key和query,从而增强了网络的性能。同时,加入了实例归一化(IN)操作以恢复多头的多样性能力。
  4. 作者设计了一种新的灵活位置编码方法,称为空间注意力位置编码(PA),使其能够处理不同大小的输入图像。PA模块采用3×3深度卷积操作来获取像素级权重,并通过sigmoid函数进行缩放。这种方法使得位置编码更加灵活,可以适应不同大小的输入图像,而无需插值或微调。

方法介绍

Efficient Transformer Block

为了解决上述Transformer的两个缺点:一是计算量随输入token的增加呈二次方增长关系,二是每个head只负责一个embedding维度的子集,可能会影响网络性能。作者提出了一个高效的多头自注意力,如图3所示,具体如下:

(1)和MSA一样,EMSA首先通过一组投影得到query \(\mathbf{Q}\)

(2)为了减少内存,2D输入token \(\mathrm{x} \in \mathbb{R}^{n \times d_m}\) 沿空间维度reshape回3D得到 \(\mathrm{\hat{x}}\in\mathbb{R}^{d_m\times h\times w}\),然后通过一个深度卷积将宽高降低一个比例 \(s\),则卷积的kernel size、stride和padding分别为 \(s+1,s,s/2\)。

(3)新的token map \(\mathrm{\hat{x}}\in\mathbb{R}^{d_m\times h/s\times w/s}\) 再reshape回2D,即 \(\mathrm{\hat{x}}\in\mathbb{R}^{n'\times d_m},n'=h/s\times w/s\)。然后 \(\mathrm{\hat{x}}\) 再通过两组投影得到key \(\mathbf{K}\) 和 value \(\mathbf{V}\)。

(4)然后再通过下式计算注意力

  

这里 \(Conv(\cdot)\) 是一个标准的1x1卷积,用来建模不同head之间的交互。这样每个head的注意力都可以依赖于所有的keys和queries。但是这会削弱MSA联合关注来自不同位置不同表示子集信息的能力。为了恢复这种多样性能力,作者为点积矩阵(softmax之后)添加了一个instance normalization,即 \(IN(\cdot)\)。

(5)最后,每个head的输出拼接起来并通过一个线性投影得到最终输出。

和传统的Transforme block一样,作者还加入了FFN和residual connection,完整的efficient Transformer block表示如下

Patch Embedding

传统的Transformer接收token序列作为输入,以ViT为例,输入被切分成不重叠的patch,这些patches展平成2D并通过线性投影得到固定的维度。但是这种直接的tokenization未能捕捉低级特征信息例如边缘和角点。此外不同block中token的长度是固定的,使得它不适合需要多尺度特征表示的下游任务如检测和分割。

本文提出了一个新的多尺度骨干网络ResT用于密集预测任务,在每个stage通过一个patch embedding module来逐步减低空间分辨率,增大通道数。

为了有效的用少量参数捕捉低级特征,作者引入了一种简单有效的方法,即堆叠三个3x3卷积,padding都为1,步长分别为2、1、2。前两层还加入了BN和ReLU。在stage2、3、4中,patch embedding module用来将空间维度降低4x并将通道维度增加2x。如图2所示

Position Encoding

位置编码是利用序列顺序的关键。ViT中一组可学习的参数被加入到输入token中来编码位置,设 \(\mathrm{x}\in\mathbb{R}^{n\times c}\) 表示输入,\(\theta\in\mathbb{R}^{n\times c}\) 表示位置参数,则编码的输入可以表示如下

但是位置编码的长度和输入token的长度必须完全相等,这限制了实际应用场景。

式(6)中的求和操作可以看做是为输入分配一个pixel-wise的权重,假设 \(\theta\) 与输入相关,即 \(\theta=GL(x)\),其中 \(GL(\cdot)\) 是group linear操作,group数为 \(c\)。这样式(6)就可以修改如下

除了式(7),\(\theta\) 可以通过更灵活的空间注意力机制得到。这里作者提出了一个简单有效的空间注意力模块PA(pixel-attention)来编码位置。PA通过一个3x3深度卷积得到pixel-wise权重然后通过一个sigmoid函数 \(\sigma(\cdot)\) 进行缩放。用PA模块编码位置信息可以表示如下 

由于每个stage的输入token也可以通过卷积获得,因此我们可以将位置编码嵌入到patch embedding模块模块中,stage \(i\) 的整体结构如图4所示 

网络最后通过一个全局平均池化和一个线性分类层得到预测结果。作者设计了四种不同大小的ResT,具体结构参数如表1所示

实验结果

和其它模型在ImageNet上的结果如表2所示,可以看到ResT取得了显著的改进。例如对于较小的模型,ResT-Small(79.6%)超过了复杂度相似的PVT-T(75.1%)4.5%。对于较大的模型,ResT-Base(81.6%)超过了Swin-T(81.3%)0.3%,ResT-Large(83.6%)超过了Swin-S(83.3%)0.3%。

在COCO数据集上目标检测的结果如下表所示,其中检测模型采用RetinaNet。可以看到,相同的计算成本下,ResT-Small的AP比PVT-T高了3.6(40.3 vs. 36.7)。对于更大的模型,ResT-Base超过了PVT-S 1.6 AP。

在COCO数据集上的实例分割结果如表4所示,其中模型选择了Mask R-CNN,可以看到对应大小的ResT也超过了PVT和ResNet。

 

标签:Transformer,卷积,ResT,times,token,2021,NeurIPS,输入
From: https://blog.csdn.net/ooooocj/article/details/140711466

相关文章

  • 我如何在 Django Rest 框架中过滤当前用户的查询集
    classSalonCarDetailsSerializer(serializers.ModelSerializer):salon=PrimaryKeyRelatedField(queryset=Salon.objects.filter(owner=?))classMeta:model=SalonCarDetailsfields=["salon","car","price&qu......
  • 如何使用REST查询sys_user表?
    我正在使用PyPi的servicenowv2.0.1与ServiceNow交互。我需要能够在sys_user表中查找用户,以便找到分配给他们的sys_id。如果有人有一些示例代码,他们可以分享,我们将不胜感激。我尝试使用以下内容但没有成功:fromservicenowimportConnectionfromservicen......
  • 使用 Python 构建一个简单的 REST API
    使用Python构建一个简单的RESTAPI简介本文档将引导您使用Python和Flask框架构建一个简单的RESTAPI。我们将创建一个API,用于管理一个虚拟的书籍数据库。准备工作Python环境:确保您的系统上安装了Python3.x。Flask框架:使用pip安装Flask:pipinstallFla......
  • 如何从 Firebase 保存和显示 Firestore 数据库中的图像
    我正在使用Flask使用Python编写一个用于IT研究的应用程序。我使用FirestoreDatabase作为数据库。一切都很好,但我想知道是否可以将照片保存到给定的集合并从网站上的集合中读取/显示这张照片?我的端点可以更改我想要的用户设置上传这张照片,当然,在正确上传显示之后:@b......
  • Velero backup and restore k8s cluster
    Velero部署及使用示例Velero是用于备份和恢复Kubernetes集群资源和PV的开源项目。基于VeleroCRD创建备份(Backup)和恢复作业(Restore)可以备份或恢复集群中的几乎所有对象,也可以按类型、名称空间或标签过滤对象支持基于文件系统备份(FileSystemBackup,简称FSB)备份Pod卷中的数......
  • ResNet strikes back(NeurIPS 2021,Meta)论文解读
    paper:ResNetstrikesback:Animprovedtrainingprocedureintimmofficialimplementation:https://github.com/huggingface/pytorch-image-models背景ResNet(残差网络)架构自He等人引入以来,一直在各种科学出版物中占据重要地位,并作为新模型的基准。然而,自2015年ResNet问世......
  • @RestController注解
    1.引言在现代的JavaWeb开发中,Spring框架因其简洁、高效和强大的功能而受到广泛欢迎。SpringMVC是Spring框架的一个重要组成部分,用于构建Web应用程序。@RestController注解是SpringMVC提供的一个关键注解,用于简化RESTfulWeb服务的开发。本文将详细讲解@RestController......
  • 在 Gerrit 的 REST API 中,如何查找补丁集 ID 值?我有 url.../details Json 但在那里找
    我正在使用request.get来获取Json,并想使用最新的补丁集和扩展名/revisions/(patchSetNumber)/files再次执行此操作以查找所有修改的文件。我无法弄清楚如何通过请求找到补丁集ID。我尝试通过url/details扩展进行搜索,但无法找到修订选项卡或补丁集选项卡是对的,Ger......
  • DRF入门规范,API接口,接口测试工具,restful规范,序列化和反序列化,drf安装和快速使用
    ⅠDRF入门规范【一】Web应用模式在开发Web应用中,有两种应用模式:【1】前后端不分离【2】前后端分离【3】前后端开发模式#1前后端混合开发-不少公司在用-flask混合-django混合-例如最简单的bbs项目-模板:dtl语法:djangotemplatelanguage模板语......
  • RestSharp编写api接口测试,并实现异步调用(不卡顿)
    首先,确保你已经安装了RestSharpNuGet包。如果没有安装,可以通过以下命令安装:bashInstall-PackageRestSharp然后,在你的C#代码中,你可以按照以下步骤操作:引用RestSharp命名空间。创建一个RestClient实例。创建一个RestRequest实例,并设置请求方法和URL。执行异步POST请求。......