首页 > 其他分享 >Swin Transformer

Swin Transformer

时间:2023-09-13 15:44:06浏览次数:41  
标签:Transformer Swin SwinT 复杂度 patch 计算 窗口

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows使用移动窗口的分层视觉转换器阅读笔记

摘要:提出Swin Transformer,作为计算机视觉的通用主干网络。将Transformer应用到是视觉领域的挑战就是语言和视觉两个领域的差异。本文提出的分层transformer,它的表征用移动窗口计算,解决这个差异。通过将自注意计算机制限制在非重叠的本地窗口,允许跨窗口连接,移动的窗口方法带来了更高的效率层次结构具有在不同尺度上建模的灵活性,具有相对于图像的线性计算复杂度。分层设计和移动窗口方法被证明对所有mlp体系结构有益。

图一:a.提出的SwinT通过合并更深层次的图像patch构建分层特征映射,只在每个局部窗口内进行自注意力计算,因此具有关于输入图像大小线性计算复杂度,因此可以作为图像分类和密集识别任务的通用主干。b.以往的visionT产生单一低分辨率的特征映射,由于全局自注意力计算,输入图像大小的计算复杂度为2次方

1.介绍:

本文试图扩展transformer的适用性,使得它可以作为视觉的通用骨干,就像在NLP的作用一样。用到视觉领域限制主要是语言和视觉两种模式的差异。

一:涉及规模,语言方面的基本处理元素就是单词标记,单词token规模都是固定的,但是视觉元素的规模可以有很大的差异;

二:图像像素分辨率比文本的分辨率高很多。(自注意力的计算复杂度是图像大小的二次方) 。

为克服这个问题提出SwinT,构造分层特征映射,并具有与图像大小成线性的计算复杂度。SwinT从小的patch开始逐步在更深的transformer层中逐渐合并相邻的patch,构建分层表示。有了分层特征映射,SwinT模型可以方便的利用高级技术进行密集预测。线性计算复杂度是通过在分割图像的非重叠窗口内部计算自我注意来实现每个窗口的patch数量固定,因此复杂度和图像大小呈线性关系

SwinT的关键设计元素是在连续的自注意层之间移动窗口分区,如图二所示。移动的窗口连接上一层的窗口,提供了它们之间的连接,大大增强了建模能力(见表4)。这种策略在实际延迟方面也很有效:一个窗口中的所有patch共享set1,促进了硬件中的内存访问。移窗方法比滑动窗方法有更低的延迟,但在建模能力上是相似的(见表5和表6)。移窗方法也被证明对所有mlp体系结构有利。

图2。Swin Transformer体系结构中用于计算自我注意的移位窗口方法的说明。l层(左)采用常规窗口分区方案,在每个窗口内计算自注意。在下一层l + 1(右)中,窗口分区被移动,产生新的窗口。新窗口中的自我注意计算跨越了前一层窗口的边界,提供了它们之间的连接

2.相关工作:

基于自注意力机制的骨干结构

使用自注意力层来代替ResNet中的部分或者全部空间卷积。自注意力是计算在局部窗口内每一个像素来加快优化,实现了比对应Resnet结构稍好的精度。但是它访问昂贵,实际延迟大于卷积。本文建议在连续的层之间移动窗口而不是滑动窗口,这允许在更一般硬件中更有效地实现

使用自我注意力层或transformers来补充卷积体系

自我注意层可以通过提供编码远程依赖关系或异构交互的能力来补充骨干或头部网络。

基于transformers的视觉骨干架构:

VIT直接在不重叠的patch上运用transformers进行图像分类,但是VIT的特征图的分辨率较低,且计算复杂度随图像大小二次方增长,不适合做密集视觉任务或输入图像分辨率较高的通用骨干网络。SwinT是通用骨干网络,不专门针对分类,另一个并行的工作是探索在transformers上构建多分辨率特征映射的类似思路。

3.方法

总体架构:

首先通过patch分割模块将输入RGB图像分割成不重叠的patch(和VIT一致)。每个patch看作一个token,patch的特征被设置为原始像素RGB值的拼接。实验中使用4*4的patch大小,因此每个patch的特征维数是4*4*3=48。在原始特征值上应用一个线性嵌入层,以将其投影到任意维度(表示为C)。

在这些patch上应用几个SwinT(具有修正自我注意力计算的transformers块)。Transformer块维持token的数量(H/4*W/4),和线性嵌入层一起称为阶段一。

为产生分层特征表示,随着网络的深入,通过patch合并层减少token的数量。第一个patch合并层连接2*2相邻的patch特征,并在4C维连续特征上应用线性层。 这将token的数量下采样(2x),设置输出维度为2C。之后使用SwinT,分辨率保持在H/8 * W/8。第一个patch合并层和特征变换(SwinT Block)表示阶段2。阶段3、4是阶段2的复制输出分辨率分别是是H/16 * W/16 、H/32*W/32。这些阶段端共同产生分成表示,具有与典型卷积网络相同的特征图分辨率。因此所提出的方法可以方便的替代现有方法中用于视觉任务主干。

图3。(a) Swin变压器(Swin- t)的结构;(b)两个连续的Swin变压器块(用公式(3)表示)。W-MSA和SW-MSA分别是具有规则和移位窗口结构的多头自注意模块。

SwinT block:SwinT 是通过标准的多头自注意力模块(MSA)替换为基于移动窗口的模块而构建,其他保持不变。SwinT由一个基于移动窗口的MSA模块组成随后是一个2层MLP(中间带有GELU非线性激活),在每个MSA模块和每个MLP之前应用层规范(LN层),每个模块应用残差连接

基于移动窗口的自注意力:

标准transformer架构和分类的适应都要进行全局自注意力(计算一个token和其他所有的token的关系)。全局计算导致token数量的平方复杂度。

非重叠窗口的自注意力机制:为了有效建模,建议在局部窗口内进行计算自注意力非重叠的方式均匀分割图像。假设每个窗口有M*M个patch,则全局MSA模块和基于h*w的patch图像的窗口计算复杂度:

前者是patch数量h*w的平方时,当M时固定的时候后者是线性的。全局自关注计算通常是无法负担的,而基于窗口的自关注是可扩展的。

连续块中的移动窗口分区:基于窗口的自注意力机制缺乏跨窗口的连接,这限制了它的建模能力。为了保持非重叠窗口的有效计算的同时引入跨窗口连接,我们提出一种移动窗口划分方法,该方法在连续的SwinT块中的两个划分配置之间

交替。

如图2所示,第一个模块使用从左上角像素开始的常规窗口划分8*8特征图被均匀分为2*2个大小为4*4(M=4)的窗口。下一个模块采用和前一层的窗口配置不同的窗口配置,通过将窗口从规则划分的窗口移位(M/2,M/2 [这里是下取整])个像素。

使用移位窗口分割方法,连续的SwinT块被计算为:

其中ˆzlzl表示块l(S)WMSA模块MLP模块输出特征。W-MSA和SW-MSA分别表示使用规则移位窗口划分配置基于窗口的多头自注意力移位窗口分割方法引入了前一层中相邻非重叠窗口之间的连接,并被发现在图像分类对象检测语义分割中有效,如表4所示。

图4.移位窗口分区中用于自我关注的高效批计算方法的说明。

移位配置的高效批处理计算:移动窗口的一个问题就是它导致产生更多的窗口,在移位的配置里窗口数由()变成(),并且一些窗口小于

一个简单的解决方法是将较小的窗口填充到的大小,并在计算注意力时屏蔽填充的值。如果窗口数量非常小的话,这种方法增加的计算量非常大(由2*2->3*3,计算量是原本的2.25倍)。因此提出一种更有效地方法,通过向左上方向循环移位,如图4所示。这种移位之后,一个批处理窗口可能由几个在特征图中不相邻的子窗口组成,因此采用一种掩蔽机制将自关注计算限制在每个子窗口内。通过循环移位批处理的数量常规窗口分区的数量相同,因此也是有效的。

相对位置偏差:

计算自注意力时,在每个计算相似度时引入相对位置偏差B():

  1. K、V()是查询、键、值矩阵;d是它们的维度,是窗口的patch的数量。每个轴的相对位置位于,参数
  2. 大小取值于B。

表4表明,与没有偏置项使用绝对位置嵌入的有明显改善。进一步向输入添加绝对位置嵌入会略微降低性能,因此本文未使用该项。

体系结构变体

构建一个称为Swin-B的基础模型,模型大小和计算复杂性与VITB,还介绍一个SWin-T、Swin-S和Swin-L,他们分别是模型大小和计算复杂度的0.25,0.5和2倍。Swin-T和Swin-S复杂度和ResNet50和ResNet 101的复杂性相似。默认情况下,窗口大小设置为

M= 7,每个头部的查询维度为d=32,每个MLP的扩展为α=4:

C是第一级的隐藏层的信道号。图像分类模型变量的模型大小、理论计算复杂性和吞吐量列于表中1.

5.结论

提出的SwinT能产生分层的特征表示,具有与输入图像大小的线性计算复杂度。SwinT的关键元素,基于移动窗口的自我注意被证明在视觉问题上是有效和高效的。期待用于NLP。

标签:Transformer,Swin,SwinT,复杂度,patch,计算,窗口
From: https://www.cnblogs.com/wangzhe52xia/p/17699864.html

相关文章

  • TensorFlow PyTorch Transformer --GPT
    你是机器学习专家,请为我解答如下疑问:tensorflow,pytorch分别是什么?他们跟numpy的区别是什么?什么是深度学习,深度学习框架中的框架指的是什么?什么是机器学习,什么是transformer?transformer和tensorflow,pytorch的关系是什么?ChatGPTTensorFlow和PyTorchTensorFlow:这是一个由......
  • Swing中的组件
    1.顶级容器JFrame(框架):表示主程序窗口JDialog(对话框):每个JDialog对象表示一个对话框,对话框属于二级窗口JApplet(小程序):在浏览器内显示一个小程序界面2.布局方式1、边界布局(BorderLayout)2、流式布局(FlowLayout)3、网格布局(GridLayout)4、盒子布局(BoxLaYout)5、空布局(null)3.......
  • Spikformer: When Spiking Neural Network Meets Transformer
    郑重声明:原文参见标题,如有侵权,请联系作者,将会撤销发布!PublishedasaconferencepaperatICLR2023(同大组工作) ABSTRACT我们考虑了两种生物学合理的结构,脉冲神经网络(SNN)和自注意机制。前者为深度学习提供了一种节能且事件驱动的范式,而后者则能够捕获特征依赖性,使Trans......
  • ViTPose+:迈向通用身体姿态估计的视觉Transformer基础模型 | 京东探索研究院
    身体姿态估计旨在识别出给定图像中人或者动物实例身体的关键点,除了典型的身体骨骼关键点,还可以包括手、脚、脸部等关键点,是计算机视觉领域的基本任务之一。目前,视觉transformer已经在识别、检测、分割等多个视觉任务上展现出来很好的性能。在身体姿态估计任务上,使用CNN提取的特征,结......
  • CMT:卷积与Transformers的高效结合
    论文提出了一种基于卷积和VIT的混合网络,利用Transformers捕获远程依赖关系,利用cnn提取局部信息。构建了一系列模型cmt,它在准确性和效率方面有更好的权衡。CMT:体系结构CMT块由一个局部感知单元(LPU)、一个轻量级多头自注意模块(LMHSA)和一个反向残差前馈网络(IRFFN)组成。 ......
  • ICML 2023 | 神经网络大还是小?Transformer模型规模对训练目标的影响
    前言 本文研究了Transformer类模型结构(configration)设计(即模型深度和宽度)与训练目标之间的关系。结论是:token级的训练目标(如maskedtokenprediction)相对更适合扩展更深层的模型,而sequence级的训练目标(如语句分类)则相对不适合训练深层神经网络,在训练时会遇到over-smoothin......
  • 【ICML2022】Understanding The Robustness in Vision Transformers
    来自NUS&NVIDIA文章地址:[2204.12451]UnderstandingTheRobustnessinVisionTransformers(arxiv.org)项目地址:https://github.com/NVlabs/FAN一、MotivationCNN使用滑动窗的策略来处理输入,ViT将输入划分成一系列的补丁,随后使用自注意力层来聚合补丁并产生他们的表示,ViT的......
  • ChatGLM2 源码解析:`GLMTransformer`
    #编码器模块,包含所有GLM块classGLMTransformer(torch.nn.Module):"""Transformerclass."""def__init__(self,config:ChatGLMConfig,device=None):super(GLMTransformer,self).__init__()self.fp32_residual_co......
  • 基于Swing实现的PDFViewer
    最近因项目需求,需要使用Swing实现PDFViewer,并且需要鼠标拖动,放大缩小等操作,一开始在网上也找到了PDF-Renderer,但是一看原理,不也就是将PDF文件转化为image而已,目前解决掉了拖动以及放大缩小的BUG问题。如下使用apache-pdfbox转换的PDF,当然也可以替换为iText或者别的依赖代码如下:......
  • Java Swing查看字体和设置全局字体
    查看支持的字体以下代码用于运行时在控制台打印支持的字体GraphicsEnvironmentgEnv=GraphicsEnvironment.getLocalGraphicsEnvironment();finalStringAvailableFontFamilyNames[]=gEnv.getAvailableFontFamilyNames();Stream.of(AvailableFontFamilyNames).forEach(Sys......