摘要
视觉Transformer在许多视觉任务上展示了卓越的性能。然而,它在浅层捕获局部特征时可能会面临高度冗余的问题。因此,使用了局部自注意力或早期阶段的卷积来减少这种冗余,但这牺牲了捕获长距离依赖的能力。一个挑战随之而来:在神经网络的早期阶段,我们是否能高效且有效地进行全局上下文建模?为解决这一问题,我们从超像素的设计中获得启示,这种设计通过减少图像基元的数量来简化后续处理,并在视觉Transformer中引入了超级令牌。超级令牌旨在为视觉内容提供有意义的语义分割,这样既减少了自注意力中的令牌数量,也保留了全局建模能力。具体而言,我们提出了一种简单而有效的超级令牌注意力(STA)机制,它包括三个步骤:首先通过稀疏关联学习从视觉令牌中抽取超级令牌,接着对这些超级令牌进行自注意力处理,最后将它们映射回原始的令牌空间。STA通过将普通的全局注意力分解为稀疏关联图与低维度注意力的乘积,极大地提高了捕获全局依赖的效率。基于STA,我们开发了一个层次化的视觉Transformer。广泛的实验显示了它在各种视觉任务上的强大性能。特别是,在没有任何额外训练数据或标签的情况下,它在ImageNet-1K上实现了86.4%的顶级准确率,以及在COCO检测任务上达到53.9的盒AP和46.8的掩码AP,在ADE20K语义分割任务上实现了51.9的mIOU。
YOLO目标检测创新改进与实战案例专栏
专栏目录: YOLO有效改进系列及项目实战目录 包含卷积,主干 注意力,检测头等创新机制 以及 各种目标检测分割项目实战案例
专栏链接: YOLO基础解析+创新改进+实战案例
创新点
-
引入超级标记(super tokens):通过引入超级标记的概念,实现了在视觉Transformer中的全局上下文建模。超级标记将原始标记聚合成具有语义意义的单元,从而减少了自注意力计算的复杂度,提高了全局信息的捕获效率。
-
Super Token Attention(STA)机制:提出了一种简单而强大的超级标记注意力机制,包括超级标记采样、多头自注意力和标记上采样等步骤。STA通过稀疏映射和自注意力计算,在全局和局部之间实现了高效的信息交互,有效地学习全局表示。
-
Hierarchical Vision Transformer:设计了一种层次化的视觉Transformer结构,结合了卷积层和超级标记Transformer块,以在不同层次上实现高效和有效的表示学习。这种结构在各种视觉任务上展现了出色的性能,包括图像分类、目标检测和语义分割等。
-
性能优越性:在多个视觉任务上进行了广泛的实验验证,包括ImageNet图像分类、COCO目标检测和ADE20K语义分割等。实验结果表明,基于超级标记的视觉Transformer在各项任务上均取得了优异的性能,超越了其他Transformer模型的表现。
yolov8 引入
class StokenAttention(nn.Module):
def __init__(self, dim, stoken_size=[8,8], n_iter=1, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.n_iter = n_iter # 迭代次数
self.stoken_size = stoken_size # 空间令牌的大小
self.scale = dim ** - 0.5 # 缩放因子
self.unfold = Unfold(3) # 定义Unfold实例
self.fold = Fold(3) # 定义Fold实例
# 定义空间注意力机制
self.stoken_refine = StAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=proj_drop)
def stoken_forward(self, x):
'''
x: (B, C, H, W)
'''
B, C, H0, W0 = x.shape
h, w = self.stoken_size
pad_l = pad_t = 0
pad_r = (w - W0 % w) % w
pad_b = (h - H0 % h) % h
if pad_r > 0 or pad_b > 0:
x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) # 对输入张量进行填充
_, _, H, W = x.shape
hh, ww = H // h, W // w
stoken_features = F.adaptive_avg_pool2d(x, (hh, ww)) # 自适应平均池化
pixel_features = x.reshape(B, C, hh, h, ww, w).permute(0, 2, 4, 3, 5, 1).reshape(B, hh * ww, h * w, C)
with torch.no_grad():
for idx in range(self.n_iter):
stoken_features = self.unfold(stoken_features) # 展开空间令牌特征
stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
affinity_matrix = pixel_features @ stoken_features * self.scale # 计算亲和矩阵
affinity_matrix = affinity_matrix.softmax(-1) # 对亲和矩阵进行softmax
affinity_matrix_sum = affinity_matrix.sum(2).transpose(1, 2).reshape(B, 9, hh, ww)
affinity_matrix_sum = self.fold(affinity_matrix_sum)
if idx < self.n_iter - 1:
stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
stoken_features = stoken_features / (affinity_matrix_sum + 1e-12) # 归一化
stoken_features = pixel_features.transpose(-1, -2) @ affinity_matrix
stoken_features = self.fold(stoken_features.permute(0, 2, 3, 1).reshape(B * C, 9, hh, ww)).reshape(B, C, hh, ww)
stoken_features = stoken_features / (affinity_matrix_sum.detach() + 1e-12) # 归一化
stoken_features = self.stoken_refine(stoken_features) # 细化空间令牌特征
stoken_features = self.unfold(stoken_features) # 展开细化后的特征
stoken_features = stoken_features.transpose(1, 2).reshape(B, hh * ww, C, 9)
pixel_features = stoken_features @ affinity_matrix.transpose(-1, -2) # 计算最终的像素特征
pixel_features = pixel_features.reshape(B, hh, ww, C, h, w).permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
if pad_r > 0 or pad_b > 0:
pixel_features = pixel_features[:, :, :H0, :W0] # 去除填充部分
return pixel_features # 返回最终的像素特征
def direct_forward(self, x):
B, C, H, W = x.shape
stoken_features = x
stoken_features = self.stoken_refine(stoken_features)
return stoken_features # 返回直接计算的空间令牌特征
def forward(self, x):
if self.stoken_size[0] > 1 or self.stoken_size[1] > 1:
return self.stoken_forward(x) # 使用空间令牌前向计算
else:
return self.direct_forward(x) # 直接前向计算
task与yaml配置
详见:https://blog.csdn.net/shangyanaf/article/details/139113660
标签:令牌,features,self,Attention,YOLOv8,ww,pad,stoken,STA From: https://www.cnblogs.com/banxia-frontend/p/18259535