首页 > 其他分享 >深度学习(VIT)

深度学习(VIT)

时间:2024-08-03 17:40:08浏览次数:13  
标签:__ dim nn self 学习 VIT 深度 embed size

将Transformer引入图像领域之作,学习一下。

网络结构:

VIT结构有几个关键的地方:

1. 图像分块:输入图像被划分为固定大小的非重叠小块(patches),每个小块被展平并线性嵌入到一个固定维度的向量中。这里是将32x32的图像划分成4x4的小块,总共会有16个小块,每个小块有64维向量。

2. 位置编码:由于Transformer不具备位置敏感性,需要添加位置编码来提供位置信息。每个图像块向量都会加上一个对应的可学习的位置编码,以保留图像空间信息。

3. Transformer编码:嵌入向量连同位置编码一起被输入到Transformer编码器中,编码器由多个相同自注意力层堆叠而成。

4. MLP分类。

测试代码如下: 

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
import torchvision.models

class EmbedLayer(nn.Module):
    def __init__(self,channels, embed_dim,img_size,patch_size):
        super().__init__()
        self.embed_dim = embed_dim
        self.conv1 = nn.Conv2d(channels, embed_dim, patch_size, patch_size)  
        self.pos_embedding = nn.Parameter(torch.zeros(1, (img_size // patch_size) ** 2,embed_dim), requires_grad=True)  # Positional Embedding

    def forward(self, x):
        x = self.conv1(x)       
        x = x.reshape([x.shape[0], self.embed_dim, -1]) 
        x = x.transpose(1, 2)  
        x = x + self.pos_embedding  
        return x


class SelfAttention(nn.Module):
    def __init__(self,embed_dim, heads):
        super().__init__()
        self.heads = heads
        self.embed_dim = embed_dim
        self.head_embed_dim = self.embed_dim // heads

        self.queries = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True)
        self.keys = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True)
        self.values = nn.Linear(self.embed_dim, self.head_embed_dim * heads, bias=True)

    def forward(self, x):
        m, s, e = x.shape 

        q = self.queries(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2) 
        k = self.keys(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2)   
        v = self.values(x).reshape(m, s, self.heads, self.head_embed_dim).transpose(1, 2)   

        q = q.reshape([-1, s, self.head_embed_dim]) 
        k = k.reshape([-1, s, self.head_embed_dim])
        v = v.reshape([-1, s, self.head_embed_dim]) 

        k = k.transpose(1, 2) 
        x_attention = q.bmm(k) 
        x_attention = torch.softmax(x_attention, dim=-1)

        x = x_attention.bmm(v)  
        x = x.reshape([-1, self.heads, s, self.head_embed_dim])  
        x = x.transpose(1, 2)  
        x = x.reshape(m, s, e)  
        return x


class Encoder(nn.Module):
    def __init__(self, embed_dim,heads):
        super().__init__()
        self.attention = SelfAttention(embed_dim,heads)  
        self.fc1 = nn.Linear(embed_dim, embed_dim * 2)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(embed_dim * 2,embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x + self.attention(self.norm1(x)) 
        x = x + self.fc2(self.activation(self.fc1(self.norm2(x))))  
        return x


class Classifier(nn.Module):
    def __init__(self, embed_dim,num_patches,classes):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim*num_patches, embed_dim)
        self.activation = nn.Tanh()
        self.fc2 = nn.Linear(embed_dim, classes)

    def forward(self, x):
        x = x.view(x.shape[0],-1)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class VisionTransformer(nn.Module):
    def __init__(self,channels, embed_dim,n_layers,heads,img_size,patch_size,classes):
        super().__init__()
        self.embedding = EmbedLayer(channels,embed_dim,img_size,patch_size)
        self.encoder = nn.Sequential(*[Encoder(embed_dim,heads) for _ in range(n_layers)], nn.LayerNorm(embed_dim))
        self.norm = nn.LayerNorm(embed_dim) 
        self.classifier = Classifier(embed_dim,(img_size//patch_size)**2,classes)

    def forward(self, x):
        x = self.embedding(x)
        x = self.encoder(x)
        x = self.norm(x)
        x = self.classifier(x)
        return x
    
if __name__ == '__main__':

    device = torch.device("cuda")
        
    trainTransforms = transforms.Compose([
                transforms.ToTensor()
                , transforms.RandomHorizontalFlip(p=0.5) 
                , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
            ])
    
    testTransforms = transforms.Compose([
                transforms.ToTensor()
                , transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  
            ])

    trainset = CIFAR10(root='./data', train=True, download=True, transform=trainTransforms)
    trainLoader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testset = CIFAR10(root='./data', train=False,download=False, transform=testTransforms)
    testLoader = torch.utils.data.DataLoader(testset, batch_size=128,shuffle=False)

    model = VisionTransformer(channels=3, embed_dim=128,n_layers=6,heads=8,img_size=32,patch_size=8,classes=10)
    
    # model = torchvision.models.resnet18(pretrained=True)
    # model.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)  
    # model.maxpool = nn.MaxPool2d(1, 1, 0)  
    # model.fc = nn.Linear(model.fc.in_features, 10)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-3)
    cos_decay = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100, verbose=True) 

    model.to(device)

    for epoch in range(50):
        print("epoch :",epoch)

        model.train()
        correct = 0
        total = 0

        for images, labels in trainLoader:

            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            print(loss.item(),f" train Accuracy: {(100 * correct / total):.2f}%")

        cos_decay.step()
  
        model.eval()
        with torch.no_grad():
            correct = 0
            total = 0
            for images,labels in testLoader:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)

                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
            print(f"test Accuracy: {(100 * correct / total):.2f}%")

    # 保存模型
    torch.save(model.state_dict(), 'vit.pth')

标签:__,dim,nn,self,学习,VIT,深度,embed,size
From: https://www.cnblogs.com/tiandsp/p/17935794.html

相关文章

  • Objective-C学习笔记(协议和代理)
    协议协议是多个类共享的一个方法列。协议中列出的方法没有相应的实现,计划由其他人来实现。可以定义这些方法为必须实现的,也可以为可选择实现的@protocal协议名//在此处添加必须实现的协议方法@optional//在此处添加可选择实现的协议方法@end遵循协议也符合继承关系......
  • 中级软件设计师---小白学习第一天:数据的表示和校验码
    计算机中只能识别的数据是二进制,低电平代表0,高电平代表1进制的符号表示:二进制B,十进制D,十六进制H真值:符合人类习惯的数字机器数:数字实际存到机器里面的形式,正负号需要被”数字化“15——1111+15——011118——1000-8——11000数据的表示:定点数与浮......
  • Python学习中最常见的10个列表操作问题
    列表是Python中使用最多的一种数据结果,如何高效操作列表是提高代码运行效率的关键,这篇文章列出了10个常用的列表操作,希望对你有帮助。1、迭代列表时如何访问列表下标索引普通版:items=[8,23,45]forindexinrange(len(items)):print(index,"-->",items[index])​......
  • VUE3学习路线
    以下是一份详细的Vue3学习路线,涵盖从基础到进阶的各个方面,以帮助你系统掌握Vue3开发。第一阶段:基础知识理解前端基础HTML:了解文档结构,常用标签,语义化HTML。CSS:学习选择器、布局、Flexbox和Grid,基本的样式应用。JavaScript:理解基本语法、DOM操作、事件处......
  • DeepViT 论文与代码解析
    paper:DeepViT:TowardsDeeperVisionTransformerofficialimplementation:https://github.com/zhoudaquan/dvit_repo出发点尽管浅层ViTs在视觉任务中表现优异,但随着网络深度增加,性能提升变得困难。研究发现,这种性能饱和的主要原因是注意力崩溃问题,即在深层变压器中,attentio......
  • 实现一个终端文本编辑器来学习golang语言:第一章项目构建
    欢迎!这个系列的博文会带你使用golang语言来编写一个你自己的文本编辑器。更多介绍见https://www.cnblogs.com/Ama2ingYJ/p/18340634这里我把我们的文本编辑器项目命名为zedterm。首先第一步自然是初始化golang工程gomodinitzedterm作为文本编辑器,其中重要的一个工作便是......
  • 实现一个终端文本编辑器来学习golang语言
    欢迎!这个系列的博文会带你使用golang语言来编写一个你自己的文本编辑器。首先想说说写这个系列文章的动机。其实作为校招生加入某头部互联网大厂一转眼已经快4年了。可以说该大厂算是比较早的用golang语言作为主要后端开发技术栈的公司了,绝大部分后端项目的语言选型都是golang......
  • Android Studio开发学习(二、注册存储)
    用户注册首先我们创建一个新的Activity,将他命名为RegisterActivity我们还是先设计注册界面布局(根据自身喜好),我这里延用了上一篇透明框布局bg_username、btn_left、btn_right上一篇我们已经简单介绍了LinearLayout、TextView、EditText功能,这里补充一下Button布局,它决定按钮......
  • 李沐动手学深度学习V2-chapter_convolutional-modern
    李沐动手学深度学习V2文章内容说明本文主要是自己学习过程中的随手笔记,需要自取课程参考B站:https://space.bilibili.com/1567748478?spm_id_from=333.788.0.0课件等信息原视频简介中有卷积神经网络经典卷积神经网络LeNet深度卷积神经网络AlexNetAlexNet与LeNet对比:1.......
  • 学生java学习路程-5
    ok,到了一周一次的总结时刻,我大致会有下面几个方面的论述:1.这周学习了Java的那些东西2.这周遇到了什么苦难3.未来是否需要改进方法等几个方面阐述我的学习路程。抽象类abstract接口interface,定义时加入注释解释接口含义String类String是不可变字符串,所有的替换,截取子字符串,......