首页 > 其他分享 >基于Vision Transformer的mini_ImageNet图片分类实战

基于Vision Transformer的mini_ImageNet图片分类实战

时间:2024-07-22 11:00:23浏览次数:8  
标签:mini Transformer self torch embedding ImageNet embed size

【图书推荐】《PyTorch深度学习与计算机视觉实践》-CSDN博客

PyTorch计算机视觉之Vision Transformer 整体结构-CSDN博客

mini_ImageNet数据集简介与下载

mini_ImageNet数据集节选自ImageNet数据集。ImageNet是一个非常有名的大型视觉数据集,它的建立旨在促进视觉识别研究。ImageNet为超过1400万幅图像进行了注释,而且给至少100万幅图像提供了边框。同时,ImageNet包含2万多个类别,比如“气球”“轮胎”和“狗”等类别,ImageNet的每个类别均不少于500幅图像。

训练这么多图像需要消耗大量的资源,为了节约资源,后续的研究者在全ImageNet的基础上提取出了mini_ImageNet数据集。Mini_ImageNet包含100类共60000幅彩色图片,其中每类有600个样本,每幅图片的规格为84×84。通常而言,这个数据集的训练集和测试集的类别划分为80:20。相比于CIFAR-10数据集,mini_ImageNet数据集更加复杂,但更适合进行原型设计和实验研究。

mini_ImageNet的下载也很容易,读者可以使用提供的库包完成对应的下载操作,安装命令如下:

pip install MLclf

Vision Transformer模型设计

下面就是对训练过程的Vision Transformer进行模型设计,在11.1.4节完成的Vision Transformer模型的设计,针对的是224维度大小的图片,而此时使用的是mini版本的ImageNet,因此在维度上会有所变换。本例Vision Transformer模型的完整代码如下:

import torch
from vit import PatchEmbed,Block

class VisionTransformer(torch.nn.Module):
    def __init__(self,num_patches = 1,image_size = 84,patch_size = 14,embed_dim = 588,num_heads = 6,
                 qkv_bias = True,depth = 3,num_class = 64):
        super().__init__()

        #初始化PatchEmbed层
        self.patch_embed  = PatchEmbed(img_size = image_size,patch_size=patch_size,embed_dim=embed_dim)
        #增加一个作为标志物的参数
        self.cls_token = torch.nn.Parameter(torch.zeros(1, 1, embed_dim))

        #建立位置向量,计算embedding的长度
        self.num_tokens = (image_size * image_size) // (patch_size * patch_size)
        self.pos_embed = torch.nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))

        #这里在使用block模块时采用了指针的方式,注意*号
        self.blocks = torch.nn.Sequential(
            *[Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=4.0, qkv_bias=qkv_bias) for _ in range(depth)]
        )
        #最终的logits推断层
        self.logits_layer = torch.nn.Sequential(torch.nn.Linear(embed_dim, 512),torch.nn.GELU(),torch.nn.Linear(512, num_class))

    def forward(self,x):

        embedding = self.patch_embed(x)

        #添加标志物
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        embedding = torch.cat((cls_token, embedding), dim=1)  #[B, 197, 768]
        embedding += self.pos_embed

        embedding = self.blocks(embedding)

        embedding = embedding[:,0]
        embedding = torch.nn.Dropout(0.1)(embedding)
        logits = self.logits_layer(embedding)
        return logits

if __name__ == '__main__':
    image = torch.randn(size=(2,3,84,84))
    VisionTransformer()(image)

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

标签:mini,Transformer,self,torch,embedding,ImageNet,embed,size
From: https://blog.csdn.net/brucexia/article/details/140604145

相关文章

  • 无法再现 Transformer 库的 ViTImageProcessor 预处理
    我正在编写一个用于预处理图像的独立代码image_processing_vit但是,我的结果与图书馆的结果不同。以下代码包含两部分:(1)不使用变压器,(2)使用变压器。我不知道我错过了什么。我已阅读代码+询问副驾驶,但它不能解决问题。请帮助我!谢谢。importcv2importnump......
  • LAVIS库学习及MiniGPT4-Qwen中的实现
    目录LAVIS库一、lavis库介绍二、体验示例ImageCaptioningVisualquestionanswering(VQA)UnifiedFeatureExtractionInterface加载数据集在任务数据集上评估预训练模型微调BLIP在COCO-Captioning数据集深度剖析模型配置数据集配置三、lavis自定义模块3.1自定义数据集Datase......
  • MiniAuth 一个轻量 ASP.NET Core Identity Web 后台管理中间插件
    MiniAuth一个轻量ASP.NETCoreIdentityWeb后台管理中间插件「一行代码」为「新、旧项目」添加Identity系统跟用户、权限管理网页后台系统开箱即用,避免打掉重写或是严重耦合情况Github:https://github.com/mini-software/MiniAuth,Gitee:https://gitee.com/shps9510......
  • Transformer多头自注意力及掩码机制详解
    系列文章目录文章目录系列文章目录@[TOC](文章目录)前言一、self-attention1.注意力机制2.自注意力机制3.代码实现二、掩码机制1.原理介绍2.代码实现三、多头注意力模块1.原理介绍2.代码实现前言在本文中我们重点介绍Transformer中的掩码机制及多头自注......
  • ChatGPT如何开启使用gpt-4o mini模型?
    OpenAI发布了新的LLM大模型:gpt-4omini。gpt-3.5现在已经取消掉了,用gpt-4omini代替且gpt-4omini是免费的。根据OpenAI官方介绍,GPT-4omini在学术测试中表现优异,超越了GPT-3.5Turbo等小型模型。它在文本智能、多模态推理和语言支持方面水平与GPT-4o相当。在函数调用方面表现......
  • GraphRAG参数与使用步骤 | 基于GPT-4o-mini实现更便宜的知识图谱RAG
    首先给兄弟朋友们展示一下结论,一个文本18万多字,txt文本大小185K,采用GraphRAG,GPT-4o-mini模型,索引耗时差不多5分钟,消耗API价格0.15美元GraphRAG介绍GraphRAG是微软最近开源的一款基于知识图谱技术的框架,主要应用于问答、摘要和推理等方面。它的核心特点是将大型语言模型(LL......
  • 如何在 8 个 GPU 上并行化 Transformer 模型进行机器翻译?
    我正在尝试使用变压器模型以几乎与原始文章相同的方式执行机器翻译。虽然该模型运行得相当好,但它需要更多的计算资源。为了解决这个问题,我在一台具有8个GPU处理器的计算机上运行了该模型,但我缺乏这方面的经验。我尝试对并行化进行必要的调整:transformer=nn.DataParallel......
  • transformer model architecture
    transformermodelarchitecturehttps://www.datacamp.com/tutorial/how-transformers-work 动手写https://www.datacamp.com/tutorial/building-a-transformer-with-py-torch Attentionhttps://www.cnblogs.com/jins-note/p/13056604.html人类的视觉注意力从注意力......
  • 【独家首发】Matlab实现淘金优化算法GRO优化Transformer-LSTM实现负荷数据回归预测
    %导入数据集load(‘load_data.mat’);%假设负荷数据保存在load_data.mat文件中%数据预处理%这里省略了数据预处理的步骤,包括数据归一化、特征提取等%构建Transformer-LSTM模型model=create_transformer_lstm_model();%自定义创建Transformer-LSTM模型的函数......
  • 【独家首发】Matlab实现狮群优化算法LSO优化Transformer-LSTM实现负荷数据回归预测
    %导入数据集load(‘load_data.mat’);%假设负荷数据保存在load_data.mat文件中%数据预处理%这里省略了数据预处理的步骤,包括数据归一化、特征提取等%构建Transformer-LSTM模型model=create_transformer_lstm_model();%自定义创建Transformer-LSTM模型的函数......