首页 > 其他分享 >Vision Transformer 的学习与实现

Vision Transformer 的学习与实现

时间:2022-10-14 21:36:54浏览次数:48  
标签:dim Transformer heads self 学习 token num Vision size

 

Vision Transformer 的学习与实现

Transformer最初被用于自然语言处理领率,具体可见论文Attention Is All You Need。后来被用于计算机视觉领域,也取得了十分惊艳的结果(An Image Is Worth 16x16 Words: Transformers For Image Recognition At Scale),以至于现在的transformer大行其道。现在就来学习一下。

Vision Transformer 主要应用了NLP Transformer中的Encoder模块,而且总体的结构和处理过程基本保持了一致,成功将图像问题转换成序列问题。

下图是Vision Transfomer的具体结构和流程

ViT的结构

Vision Transformer 的结构总体而言可以分为3个部分

  1. Linear Projection of Flattened Patches (对切分的图像Patch,进行Embedding)

  2. Transformer Encoder(Transformer的主要结构)

  3. MLP Head (MLP分类头)

接下来我们将介绍这三个结构的一些细节

Linear Projection of Flattened Patches

在标准的Transformer模型中,处理的输入是token序列,是(num_token, num_dim)的二维向量。而图像的表示一般都是(H, W, C)的三维矩阵。因此ViT会先将图像切成若干个大小相等的Patch。

以宽高都为224的图像为例,一张图片的形状为(224, 224, 3),然后我们把他切成若干大小为16 * 16的Patches,因此我们可以得到 (224/16)^2 = 196个Patches,即从(224, 224, 3)可以变换为(196, 16, 16, 3)。之后我们把后三个维度处理成一个维度,即可得到一个(196, 768)的二维向量(196个token,每个token的维度是768),就成功将图像转换成了一个序列。

最后我们进行一个投影操作,也就是ViT的第一个主要结构,将token的维度数量投影到某一个规定的D。

此外,将这个序列输入进Encoder之前,还需要两个操作。第一,是加一个用于分类的class token(参考自Bert)。这个token的尺寸和之前的token序列相同,然后进行一个concate操作,那么最终输入进Encoder的序列尺寸就是(197, 768)。

第二,由于图像本身固有位置信息,而Transformer关注全局信息,缺少之前CNN的偏执归纳,因此需要加上位置编码。由于是一个加操作,token序列的尺寸仍无变化。

Transformer Encoder

Transformer Encoder可以说是直接移植自NLP领域的Transformer。包括LayerNorm,Multi-Head Attention 和 MLP等操作,此外还有残差连接。

LayerNorm

层归一化,将数据分布拉到激活函数的非饱和区,类似于BatchNorm。至于这里为什么用LayerNorm,主要是因为在原生Transformer中,不同的mini-batch可能具有不同的输入长度(NLP问题),会导致BatchNorm出现问题,因此使用了LayerNorm。

 

  • BatchNorm:batch方向做归一化,计算N * H * W的均值

  • LayerNorm:channel方向做归一化,计算C * H * W的均值

在此过程中输出的维度不变。

Multi-Head Attention

多头注意力是Transformer的核心结构。注意力是指给定一个查询query,与所有的key-value对中的key进行注意力权重运算,最后通过该权重加权value运算。这里用到的注意力运算是点积运算,其中Q代表的是query,K代表的是key, V代表的是value,D代表特征长度。

 

 

对同一组查询、键和值,根据多头注意力头的个数n,将其拆分成n份,送入不同的点积注意力模块。然后将得到的多个结果concate起来,最终经过一个全连接层输出。这种设计让每个注意力头可以关注不同的部分,有点类似于卷积层有多个卷积核关注不同特征通道的信息。在ViT中,使用的是自注意力机制,一个注意力头中的Q,K,V是同一个token序列 (num_token, num_dim)。

在Encoder中,尺寸为(197, 768)的token序列,被拆分成12份,形成12个(197, 64)的子token组,然后送入多头注意力。在注意力运算中Q(197, 64) , K^T(64, 197), V(197,64), 得到的结果是(197, 64)。然后将12个头concate起来,再经过一个全连接层,结果仍是(127, 768)。

MLP

这里的MLP就是两个全连接层,将token序列维度升维(197, 3072),然后再降维(197, 768),使其输出仍保持在(num_token, num_dim)上。激活函数是GELU。

 

因此通过一个Transformer Encoder,token序列的尺寸不变,因此可以在ViT中堆叠多个Encoder。除此之外,在Encoder中还有残差连接。

MLP Head

分类头,用于最终的分类。我们提取来自Encoder的输出,在197个token中,我们只要与分类有关的class token。即(1, 768)。然后通过全连接层和tanh激活函数等结构,进行分类。这样整个流程就结束了。

 

ViT的Pytorch实现

在Pytorch中,torch.nn模块里已经集成了MultiHeadAttention和TransformerEncoder,因此可以简洁实现。不过出于学习的目的,还是参考网上的许多资料自己手动实现一下。最核心的部分还是TransformerEncoder,其他部分在ViT里能较为轻松实现。

Transformer Encoder

Encoder的LayerNorm层,就一个简单的LayerNorm,残差连接在后边的forward中实现。

​
class Norm(nn.Module):
    def __init__(self, num_dim):
        super(Norm, self).__init__()
        self.norm = nn.LayerNorm(num_dim)
​
    def forward(self, x):
        x = self.norm(x)
        return x

Encoder的MLP层,简单的Linear组合。

​
class MLP(nn.Module):
    def __init__(self, num_dim, hidden_num, dropout=0.):
        super(MLP, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(num_dim, hidden_num),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_num, num_dim),
            nn.Dropout(dropout)
        )
​
    def forward(self, x):
        return self.mlp(x)

多头注意力层,这个有些复杂,参考了动手学深度学习的实现方式。包括点积注意力,和两个用于转换token序列形状其实能够在多头注意力并行计算的辅助函数,最后就是多头注意力。在具体的实现中,给tensor的形状做了注释,方便理解。

class DotProductAttention(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
​
    def forward(self, q, k, v):
        # q(batch, q_size, d) k(batch, k_size, d) v(batch, v_size, d)
        d = q.shape[-1]
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d)
        # (batch, q_size, v_size)
        attention_weights = F.softmax(scores, dim=1)
        # (batch, q_size, v_size) * (batch, v_size, d)
        return torch.bmm(self.dropout(attention_weights), v)
        # (batch, q_size, d)
        
def transpose_qkv(x, num_heads):
    # batch_size, num_token, num_dim
    x = x.reshape(x.shape[0], x.shape[1], num_heads, -1)
    # batch_size, num_token, num_heads, num_dim / num_heads
    x = x.permute(0, 2, 1, 3)
    # batch_size, num_heads, num_token, num_dim / num_heads
    return x.reshape(-1, x.shape[2], x.shape[3])
    # batch_size * num_heads,  num_token, num_dim / num_heads
​
def transpose_output(x, num_heads):
    # (batch_size*num_heads, num_token, num_dim/num_heads)
    x = x.reshape(-1, num_heads, x.shape[1], x.shape[2])
    # (batch_size, num_heads, num_token, num_dim/num_heads)
    x = x.permute(0, 2, 1, 3)
    # (batch_size, num_token, num_heads, num_dim/num_heads)
    return x.reshape(x.shape[0], x.shape[1], -1)
    # (batch_size, num_token, num_dim)
​
class MultiHeadAttention(nn.Module):
    def __init__(self, q_size, k_size, v_size, num_dim, num_heads, dropout, bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        self.wq = nn.Linear(q_size, num_dim, bias)
        self.wk = nn.Linear(k_size, num_dim, bias)
        self.wv = nn.Linear(v_size, num_dim, bias)
        self.wo = nn.Linear(num_dim, num_dim, bias)
​
    def forward(self, q, k, v):  # num_dim = qkv_size
        # q, k, v (batch_size, num_token, num_dim)
        # q, k, v (batch_size, num_token, num_dim)
        q = transpose_qkv(self.wq(q), self.num_heads)
        k = transpose_qkv(self.wk(k), self.num_heads)
        v = transpose_qkv(self.wv(v), self.num_heads)
        # (batch_size*num_heads, num_token, num_dim/num_head)
        out = self.attention(q, k, v)
        # (batch_size*num_head, num_token, num_dim/num_head)
        out = transpose_output(out, self.num_heads)
        # (batch_size, num_token, num_dim)
        return self.wo(out)
        # (batch_size, num_token, num_dim)

然后我们把上边实现的组件合体起来,构成ViT的一个Encoder模块。

class TransFormerEncoder(nn.Module):
    def __init__(self, num_dim, num_heads, dropout, num_hidden):
        super(TransFormerEncoder, self).__init__()
        self.norm1 = Norm(num_dim)
        self.norm2 = Norm(num_dim)
        self.attention = MultiHeadAttention(q_size=num_dim, k_size=num_dim, v_size=num_dim,
                                            num_dim=num_dim, num_heads=num_heads, dropout=dropout)
        self.mlp = MLP(num_dim, num_hidden, dropout)
​
    def forward(self, x):
        y = self.norm1(x)
        y = self.attention(y, y, y)
        tmp = y + x            # shortcut
        out = self.norm2(tmp)
        out = self.mlp(out)
​
        return out + tmp      # shortcut

Vision Transformer

之后我们可以直接实现Vision Transformer,它的处理流程基本如下, 以batch_size=4, img_size=224, patch_size=16, num_dim=768, num_hidden=3072为例,也讲述一下tensor的变换形状。

  1. 将图片tensor切分成patches,可以使用eniops的rearrange,实现tensor的快速切分,然后图片就可以处理成序列。

(4, 3, 224, 224)----(4, 196, 768)

  1. 然后对切分的patches做一个embedding操作,使用Linear线性层即可。

(4, 196, 768)----(4, 196, 768)

  1. 给处理好的patches,concate一个class token。class token的形状和处理好的patches token一致,都是(196, 768)

(4, 196, 768)----(4, 197, 768)

  1. 对于位置编码,采用了比较简单的可学习的位置编码,通过nn.Parameter实现。

采用加操作,不改变形状

  1. 进入Transformer Encoder。Encoder输入输出形状一致,也因此可以堆叠多个。

形状不变

  1. 最后取出class token,接一个分类头

(4, 768) ---- (4, num_class)

class ViT(nn.Module):
    def __init__(self, num_dim, num_embedding, num_hidden, num_layer, num_heads, dropout,
                 num_class=102, img_size=224, patch_size=16):
        super(ViT, self).__init__()
        self.num_patch = (img_size // patch_size) * (img_size // patch_size)
        self.split_patch = Rearrange('b c (p1 h) (p2 w) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
        self.embedding = nn.Linear(num_dim, num_embedding)  # num_dim = num_embedding
        self.pos = nn.Parameter(torch.randn(1, self.num_patch + 1, num_dim))  # 可学习位置编码
        self.cls_token = nn.Parameter(torch.randn(1, 1, num_dim))  # class token
        self.encoder = nn.Sequential()
        for i in range(num_layer):
            self.encoder.add_module(f'encoder {i}', TransFormerEncoder(num_dim, num_heads, dropout, num_hidden))
​
        self.head = nn.Sequential(
            nn.LayerNorm(num_dim),
            nn.Linear(num_dim, num_class)
        )
​
    def forward(self, x):
        x = self.split_patch(x)
        batch, num_token, _ = x.shape
        x = self.embedding(x)    # (batch, num_token, num_dim)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=batch) # (batch, num)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos
        x = self.encoder(x)
        x = x[:, 0]       # 取出class token
        out = self.head(x)
        return out

然后我们随便用一个randn生成一个tensor,看看能否跑通

 

 

应该是没有问题的。

ps. 个人的实现可能在一些细节方面同原著有些不一样,以后可能会再改进,欢迎大家批评指正。

102种鲜花分类

准备通过Ai研习社的鲜花分类练习赛试一下ViT的效果。这一部分先咕咕掉了,等跑完了再贴出来..

标签:dim,Transformer,heads,self,学习,token,num,Vision,size
From: https://www.cnblogs.com/Brisling/p/16793093.html

相关文章

  • 《Unix/Linux系统编程》第四章学习笔记 20201209戴骏
    第四章并发编程知识点归纳1、并行计算导论在早期,大多数计算机只有一个处理组件,称为处理器或中央处理器(CPU)。受这种硬件条件的限制,计算机程序通常是为串行计算编写的。......
  • 深度学习算法基础
    1,基本概念1.1,余弦相似度1.2,欧式距离1.3,余弦相似度和欧氏距离的区别2,容量、欠拟合和过拟合3,正则化方法4,超参数和验证集5,估计、偏差和方差6,随机梯度下降算法......
  • UART学习笔记
    UART是一种通用串行数据总线,用于异步通信。该总线双向通信,可以实现全双工传输和接收。在嵌入式设计中,UART用于主机与辅助设备通信,如汽车音响与外接AP之间的通信,与PC机通信......
  • ffmpeg数据结构学习(AVpacket & AVframe)
     其中的AVBufferRef是一个AVbuffer的指针:图片来源于网络 关于AVframe:音频解码API avcodec_decode_audio4在新版中已废弃,替换为使用更为简单的avcodec_send_packet......
  • Python学习路程——Day15
    Python学习路程——Day15重要内置函数zip()'''zip()函数的作用 zip()函数可以将多个序列(列表、元组、字典、集合、字符串以及ranger()区间构成的列表压缩成一个zip对......
  • letcode-学习-数组去重
    数组去重问题描述:给你一个升序排列的数组nums,请你原地删除重复出现的元素,使每个元素只出现一次,返回删除后数组的新长度。元素的相对顺序应该保持一致。由于......
  • 学习历程
    我是刚进入大学的大一新生,专业是软件工程,今天正式学c语言。我的目标是能够自己做出能够运行的软件,并且深造c语言技术。我打算是先跟着视频学习,并且及时的练习、后期也要自己......
  • 20201318李兴昕第四章学习笔记
    第四章:并发编程知识点归纳总结:本章论述了并发编程,介绍了并行计算的概念,指岀了并行计算的重要性;比较了顺序算法与并行算法,以及并行性与并发性;解释了线程的原理及其相对......
  • 【博学谷学习记录】超强总结,用心分享|狂野架构师redis数据类型的不同使用场景
    目录redis数据类型的不同使用场景数据使用场景String类型存储商品数量。用户信息。分布式锁。hash类型存用户信息。存储对象信息。list类型秒杀set类型某日用户签到情况。......
  • C语言内嵌汇编学习笔记
    参考:gnugcc中关于ExtendedAsm的文档​​​https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html​​​BasicAsm文档​​https://gcc.gnu.org/onlinedocs/gcc/Basic-A......