首页 > 其他分享 >手写VIT

手写VIT

时间:2024-08-08 17:52:26浏览次数:11  
标签:dim head nn self patch VIT 手写 size

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

class pre_process(nn.Module):
    def __init__(self, image_size, patch_size, patch_dim, dim):
        """
        image_size:输入图像的大小,例如224
        patch_size:每个patch的大小,例如16
        patch_dim: 每个patch的特征维度,例如768
        dim: 模型的隐藏层维度,例如768
        """
        super().__init__()
        self.patch_size = patch_size
        self.dim = dim
        # 计算图像中patch的数量,并赋值给类属性 self.patch_num
        self.patch_num = (image_size // patch_size) ** 2
        # 初始化一个线性层 self.linear_embedding,用于将每个patch的特征维度从patch_dim映射到dim
        self.linear_embedding = nn.Linear(patch_dim, dim)
        self.position_embedding = nn.Parameter(torch.randn(1,self.patch_num+1, self.dim)) # 使用广播
        self.CLS_token = nn.Parameter(torch.randn(1, 1, self.dim)) # 别忘了维度要和(B,L,C)对齐

    def forward(self, x):
        """
        输入的x的数组表示为(B,C,H,W),需要将它划分为(B,L,C)
        输入形状- B:批次大小  C:通道数  H:高度 W:宽度
        输出形状- B:批次大小  L:patch的个数,等于(H/p1)*(W/p2)  C:每个patch的像素值数量,等于 p1*p2*channels
        """
        x = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = self.patch_size, p2 = self.patch_size) # (B, L, C)
        """

        """
        x = self.linear_embedding(x)
        b, l, c = x.shape
        # 复制后的CLS_token的形状为(b, 1, d),表示每个样本都拥有一个CLS Token
        CLS_token = repeat(self.CLS_token, '1 1 d -> b 1 d', b=b) # 位置编码复制 B 份
        # 级联CLS Token
        x = torch.concat((CLS_token, x), dim=1) 
        # 位置嵌入
        # self.position_embedding的形状是(1, patch_num+1, dim),x的形状是(batch_num, patch_length+1, dim).
        # 在进行加法时,self.position_embedding会被扩展到(batch_size, patch_num+1, dim),从而与x进行逐元素相加
        x = x + self.position_embedding #batch的数量没有发生变化,但是batch的

class Multihead_self_attention(nn.Module):
    def __init__(self, heads, head_dim, dim):
        super().__init__()
        """
        self.dim: 输入数据的维度  
        self.head_dim: 每个注意力头的维度
        self.heads: 注意力头的数量
        self.inner_dim: 多头自注意力最后的输出维度,等于 heads*head_dim
        """
        self.head_dim = head_dim # 每一个注意力头的维度
        self.heads = heads # 注意力头个数
        self.inner_dim = self.heads*self.head_dim #多头注意力最后输出维度
        self.scale = self.head_dim**-0.5 #正则化系数
        # 一个线性层,用于将输入数据映射到查询(Q),键(K)和值(V)三个矩阵,每个矩阵的维度为inner_dim
        self.to_qkv = nn.Linear(dim, self.inner_dim*3) # 生成qkv,每一个矩阵的维度由自注意力头的维度以及头的个数决定
        # 一个线性层,用于将多头自注意力输出结果映射回原始维度dim
        self.to_output = nn.Linear(self.inner_dim, dim)
        # 一个层归一化层,用于对输入数据进行归一化
        self.norm = nn.LayerNorm(dim)
        # 一个Softmax函数,用于对注意力得分进行归一化
        self.softmax = nn.Softmax(dim=-1)
    
    def forward(self, x):
        x = self.norm(x) # PreNorm
        # .chunk()方法是PyTorch中用于分割张量的函数;dim=-1表示沿着最后一个维度进行分割
        qkv = self.to_qkv(X).chunk(3, dim=-1) # 划分QKV, 返回一个列表, 其中就包含了QKV
        # 对 QKV 的多头映射进行拆分, 得到(B, head, L, head_dim)
        """
        map函数会对qkv列表中的每个张量应用lambda表达式
        'b l (h dim) -> b h l dim'表示将输入张量的维度从(batch_size, sequence_length, num_heads*head_dim)变换为(batch_size, num_heads, sequence_length, head_dim)
        dim=self.head_dim 是指定head_dim 的值,以便正确拆分
        """
        Q, K, V = map(lambda t: rearrange(t, 'b l (h dim) -> b h l dim', dim=self.head_dim), qkv)
        K_T = K.transpose(-1, -2) # 对K进行转置,用于计算自注意力
        att_score = Q@K_T*self.scale # 计算自注意力得分
        att = self.softmax(att_score) # softmax
        out = att@V # (B, H, L, dim):自注意力输出
        # 将输出张量的维度从(batch_size, num_heads, sequence_length, head_dim)变换为 (batch_size, sequence_length, num_heads*head_dim)
        out = rearrange(out, 'b h l dim -> b l (h dim)') # 拼接
        output = self.to_output(out) # 输出映射
        return output

class FeedForward(nn.Module):
    def __init__(self, dim, mlp_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, mlp_dim)
        self.fc2 = nn.Linear(mlp_dim, dim)
        self.norm = nn.LayerNorm(dim)
    
    def forward(self, x):
        x = self.norm(x)
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

class Transformer_block(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, depth, num_class):
        super().__init__()
        self.to_patch_embedding = pre_process(image_size=image_size, patch_size=patch_size)

标签:dim,head,nn,self,patch,VIT,手写,size
From: https://www.cnblogs.com/hitzzk/p/18349412

相关文章

  • vue手写转盘抽奖
    目前抽奖最常见的两种:九宫格抽奖和 转盘抽奖,但转盘抽奖的大多是采用的是将做好的图放在页面上,如果需要变动奖池里面的奖品或数量,就会让设计师重新出一张图片。分享我自己在自定义转盘抽奖的样式布局思路<template><divclass="main"><Topmassage="转盘抽奖"/>......
  • vue3+vite+ts 颜色选择器组件支持颜色吸取,透明度
    ciw-color-picker-vue:vue3+vite+ts颜色选择器,支持颜色吸取,透明度,与浏览器原生颜色选择器相似,与饿了么颜色选择器相似使用了ciw-color-picker-vuenpmi ciw-color-picker-vue 安装 npmiciw-color-picker-vue全局引入方式main.ts或main.js import'ci......
  • vite和webpack的区别
    内核区别最明显区别是及时编译和打包编译,开发编译速度上的区别。配置区别主要区别就是webpack有loader和plugins配置,vite直接是plugins为主体。viteplugins中配置编译器和插件:其他json配置大同小异。参考文章https://zxuqian.cn/difference-between-vite-and-webpac......
  • 中文手写体识别(ocr)测试
    记录一下,以下是测试中文手写体识别结果图展示(对于潦草的字迹效果一般),后期会开放模型,有java和python版本:......
  • 基于神经网络的手写数字识别及其ZYNQ实现
        基于MNIST数据集的手写数字识别是神经网络(NeuralNetwork)的经典应用。    本文将讨论一种名为“ZYNET”的全连接神经网络框架,它可以自动生成针对FPGA的硬件实现架构。我们以手写数字识别为例,在ZYNQ平台上对该架构进行验证。本章包括以下几个部分:1环境配......
  • 【Vitepress系列】-- 自定义组件及布局,配置tailwindcss、配置Markdown
    Vitepress自定义页面,以及配置tailwindcssvitepress中,除了使用一些配置项目,还可以通过写vue代码,来做一个定制化的UI。下面这个UI主页便是vue组件+tailwindcss做的一.自定义vitepress中,如果内置的home、doc、page不满足需求,还可以自己写vue代码进行自定义1.1自定义布......
  • 不销毁activity实现白天黑夜主题切换
    Androidactivity加载布局文件流程一.onCreate初始化AppCompatActivity.onCreate先调用getDelegate()创建AppCompatDelegateImplN(最终继承AppCompatDelegateImplV9->AppCompatDelegateImplBase)对象delegate,然后调用delegate.installViewFactory()AppCompatDelegateImplV......
  • Android activity主题设置
    主题配置<stylename="MainThemeCamera"parent="Theme.AppCompat.DayNight.NoActionBar"><itemname="android:windowBackground">@color/black</item><itemname="android:windowTranslucentStatu......
  • 2. vite 构建项目
    目录vite构建项目1安装node.js2创建项目3项目部署4vue-router组件4.1创建项目4.2router-link指定路由vite构建项目1安装node.js下载链接https://nodejs.p2hp.com/版本需大于18.0查看版本node-v添加环境变量npm(包管理工具,类似于python的pip)npm-g(查看包......
  • 深度学习(VIT)
    将Transformer引入图像领域之作,学习一下。网络结构:VIT结构有几个关键的地方:1.图像分块:输入图像被划分为固定大小的非重叠小块(patches),每个小块被展平并线性嵌入到一个固定维度的向量中。这里是将32x32的图像划分成4x4的小块,总共会有16个小块,每个小块有64维向量。2.位置编码:......