首页 > 其他分享 >如何用pytorch调用预训练Swin Transformer中的一个Swin block模块

如何用pytorch调用预训练Swin Transformer中的一个Swin block模块

时间:2024-03-23 09:00:10浏览次数:27  
标签:Transformer Swin features pytorch bias 128 out True Linear

1,首先,我们需要知道的是,想要调用预训练的Swin Transformer模型,必须要安装pytorch2,因为pytorch1对应的torchvision中不包含Swin Transformer。

2,pytorch2调用预训练模型时,不建议使用pretrained=True,这个用法即将淘汰,会报警告。最好用如下方式:

from torchvision.models.swin_transformer import swin_b, Swin_B_Weights  
  
model = swin_b(weights=Swin_B_Weights.DEFAULT)  

这里调用的就是swin_b在imagenet上的预训练模型

3,swin_b的模型结构如下(仅展示到第一个patch merging部分),在绝大部分情况下,我们可能需要的不是整个模型,而是其中的一个模块,比如SwinTransformerBlock。

SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, out_features=128, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (1): SwinTransformerBlock(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttention(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (proj): Linear(in_features=128, out_features=128, bias=True)
        )
        (stochastic_depth): StochasticDepth(p=0.021739130434782608, mode=row)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=512, out_features=128, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (2): PatchMerging(
      (reduction): Linear(in_features=512, out_features=256, bias=False)
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )

那么如何调用其中的SwinTransformerBlock呢。

由于该模型是个嵌套结构,而不是类似vgg一样简单的结构,所以不能直接用layer0=model.SwinTransformerBlock调用。

因为SwinTransformerBlock是Sequential下的子模块,故正确的调用代码如下:

swinblock = model.features[1][0]

结果如下,调用成功:

标签:Transformer,Swin,features,pytorch,bias,128,out,True,Linear
From: https://blog.csdn.net/wwimhere/article/details/136953109

相关文章

  • 4.transformer
    建议直接看参考的知乎链接,我这是一坨1.encorder\[\mathrm{LayerNorm}\big(X+\mathrm{MultiHeadAttention}(X)\big)\]\[\mathrm{LayerNorm}\big(X+\mathrm{Feed}\mathrm{Forward}(X)\big)\]\[\mathrm{FeedForward}(X)=\max(0,XW_1+b_1)W_2+b_2\]做layernorm而不是batchnor......
  • VPCFormer:一个基于transformer的多视角指静脉识别模型和一个新基准
    文章目录VPCFormer:一个基于transformer的多视角指静脉识别模型和一个新基准总结摘要介绍相关工作单视角指静脉识别多视角指静脉识别Transformer数据库基本信息方法总体结构静脉掩膜生成VPC编码器视角内相关性的提取视角间相关关系提取输出融合IFFN近邻感知模块(NPM)p......
  • 遥感影像问题深度学习:PyTorch在气候变化研究中的应用
    我国高分辨率对地观测系统重大专项已全面启动,高空间、高光谱、高时间分辨率和宽地面覆盖于一体的全球天空地一体化立体对地观测网逐步形成,将成为保障国家安全的基础性和战略性资源。未来10年全球每天获取的观测数据将超过10PB,遥感大数据时代已然来临。随着小卫星星座的普及,......
  • ubuntu安装cuda和cudnn,并测试tensorflow和pytorch库的与cuda的兼容性(2023年版)
    lspci|grep-invidia查看nvidia设备,看到GPUgcc--version检查是否安装上gcc软件包根据官方文档指示,pipinstalltorch==1.13.1+cu117-fhttps://download.pytorch.org/whl/torch_stable.html,pipinstalltorchaudio==0.13.1+cu117-fhttps://download.pytorch.org/whl/torch......
  • GTC大会干货:8位大佬对Transformer起源和未来发展的探讨
      添加图片注释,不超过140字(可选) 在2024年的GTC大会上,黄仁勋特邀Transformer机器语言模型的七位创造者,共同探讨Transformer模型的过去、现在与未来。他们一致认为,尽管Transformer已经成为现代自然语言处理领域的基石,但这个世界仍然需要超越Transformer......
  • Pytorch学习笔记(一)
    一、Tensor1.1 基本概念Tensor,又名张量,是pytorch中重要的一种数据结构,从工程的角度上来说,可以很简单将其认为是与numpy的nadarray类似的数组,用来保存数据支持高效的科学计算。但是PyTorch中的Tensor支持cuda用GPU加速。1.2基本操作从接口的角度来说,对tensor的操作可以分......
  • 【论文阅读】SpectFormer: Frequency and Attention is what you need in a Vision Tr
    SpectFormer:FrequencyandAttentioniswhatyouneedinaVisionTransformer引用:PatroBN,NamboodiriVP,AgneeswaranVS.SpectFormer:FrequencyandAttentioniswhatyouneedinaVisionTransformer[J].arXivpreprintarXiv:2304.06446,2023.论文......
  • PyTorch张量
    目录基本创建方式  创建线性和随机张量张量元素类型转换 阿达玛积张量数值计算 ......
  • Python实战:PyTorch入门
    一、引言深度学习是近年来人工智能领域的热点之一,其在图像识别、语音识别、自然语言处理等领域取得了显著的成果。Python作为一门流行的编程语言,拥有丰富的深度学习框架,其中PyTorch是近年来备受关注的一个。本文将详细介绍PyTorch的基本概念、安装方法、基础知识以及实战项......
  • [基础] DiT: Scalable Diffusion Models with Transformers
    名称DiT:ScalableDiffusionModelswithTransformers时间:23/03机构:UCBerkeley&&NYUTL;DR提出首个基于Transformer的DiffusionModel,效果打败SD,并且DiT在图像生成任务上随着Flops增加效果会降低,比较符合scalinglaw。后续sora的DM也使用该网络架构。Method网络结构整......