ViT 原理解析 (Transformers for Image Recognition at Scale)
原创 小白 小白研究室 2024年06月10日 21:09 北京
如何将 transformer 应用到图像领域
Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。
视觉领域处理的是图像数据,因此将Transformer模型应用到图像数据上面临着诸多挑战,理由如下:
-
与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。
-
如果按照处理文本的方式来处理图像,即逐像素处理的话,复杂度较高,硬件难以实现。
-
Transformer缺少CNNs的归纳偏差,比如平移不变性和局部受限感受野。
-
CNNs是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNNs更大。
-
Transformer无法直接用于处理基于网格的数据,比如图像数据。
为了解决上述问题,Google的研究团队提出了ViT模型,它的本质其实也很简单,既然Transformer只能处理序列数据,那么我们就把图像数据转换成序列数据就可以了呗。下面来看下ViT是如何做的。
基本结构
另外,从网上也看到有人绘制了比较详细的算法结构图,对于理解 ViT 也是有比较大的帮助,就复用粘贴在这里供大家学习:
(结构图来自https://blog.csdn.net/weixin_42118657/article/details/121789116)
模块细节
将图片转换成 patches 序列
对于图像数据而言,其数据格式为[H, W, C],是三维矩阵,明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。
首先将一张图片按给定大小分成一堆Patches。
以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]
在代码实现中,直接通过一个卷积层来实现。以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
# https://github.com/lucidrains/vit-pytorch/blob/90be7233a3f55c29692a72da6ee4dcb5aab267d4/vit_pytorch/twins_svt.py#L59
class PatchEmbedding(nn.Module):
def __init__(self, *, dim, dim_out, patch_size):
super().__init__()
self.dim = dim
self.dim_out = dim_out
self.patch_size = patch_size
self.proj = nn.Sequential(
LayerNorm(patch_size ** 2 * dim),
nn.Conv2d(patch_size ** 2 * dim, dim_out, 1),
LayerNorm(dim_out)
)
def forward(self, fmap):
p = self.patch_size
fmap = rearrange(fmap, 'b c (h p1) (w p2) -> b (c p1 p2) h w', p1 = p, p2 = p)
return self.proj(fmap)
这里增加了 class token,class token的维度是[1,768],然后将其与第1步得到的tokens进行拼接,即Cat([1, 768], [196, 768]) -> [197, 768]。
在传统CNN分类任务中,会对最后卷积输出的 feature map 进行一个 global average pooling 操作,用以进行最后的类别预测;在 vision transformer 里面能否进行相同的操作呢,即把 16 个 patch 的 token 进行一个 average pooling 来替代 class token。作者消融实验下来验证是可以的,但是要验证使用不同的学习率。论文中作者是为了尽可能的和 transformer 结构保持一致、所以才默认使用了 class token (In order to stay as close as possible to the original Transformer model)
添加 Position embedding
从公式可以看出,其实一个词语的位置编码是由不同频率的余弦函数函数组成的,从低位到高位,余弦函数对应的频率由 1 降低到了 110000 ,按照论文中的说法,也就是,波长从 2
标签:dim,Transformer,Scale,Transformers,Image,drop,ViT,path,self From: https://blog.csdn.net/sinat_37574187/article/details/141368574