Vision Transformer (ViT) 分类标识符
Vision Transformer (ViT) 分类标识符
1. 初始化分类标识符
在ViT中,分类标识符是一个可学习的向量,通常在模型初始化时随机初始化。这个标识符的维度与图像块的嵌入向量维度相同,通常记作 zcls,其大小为 D(与每个图像块的嵌入向量维度一致)。
2. 与图像块嵌入一起作为输入
将这个分类标识符 zcls 附加在所有图像块的嵌入向量之前,形成一个扩展后的输入序列。
假设原始图像块嵌入的序列表示为 [z1, z2, …, zN],其中 N 是图像块的数量,那么完整的输入序列将是:
[zcls, z1, z2, …, zN]
这里,输入序列的维度为 (N+1) × D。
3. 在Transformer中处理
这个包含分类标识符的输入序列会传递给Transformer的多层编码器,经过多层自注意力机制和前馈神经网络的处理。分类标识符在每一层都会被更新,并最终聚合整个图像的信息。
4. 提取最终分类标识符
当输入序列经过所有Transformer层的处理后,提取出最终的分类标识符 zclsfinal。
这个分类标识符是一个综合了整个图像信息的嵌入向量。
5. 传递给分类头
最终的分类标识符 zclsfinal 会被传递给一个分类头(通常是一个全连接层)进行图像的分类任务。分类头输出的向量用于预测图像属于哪个类别。
6. 代码示例(假设使用Python和PyTorch)
import torch
import torch.nn as nn
class VisionTransformer(nn.Module):
def __init__(self, num_patches, embed_dim, num_classes):
super(VisionTransformer, self).__init__()
# 初始化分类标识符 (CLS token)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(embed_dim, nhead=8),
num_layers=12
)
self.classifier = nn.Linear(embed_dim, num_classes)
def forward(self, x):
batch_size = x.size(0)
# 复制分类标识符,使其适应批处理大小
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
# 将分类标识符添加到图像块的嵌入向量之前
x = torch.cat((cls_tokens, x), dim=1)
# 添加位置编码
x = x + self.position_embeddings
# 输入Transformer
x = self.transformer(x)
# 提取最终的分类标识符
cls_token_final = x[:, 0, :]
# 传递给分类头进行分类
out = self.classifier(cls_token_final)
return out