1. HiLo Attention
The paper points out that multi-head self-attention (MSA) carries a huge computational cost on high-resolution images. To address this, it introduces HiLo Attention to improve both speed and accuracy. HiLo Attention splits the attention layer into a high-frequency part and a low-frequency part, which capture the local details and the global structure of an image, respectively. The motivation is that natural images contain a rich mix of frequencies: when encoding image patterns, high frequencies correspond to fine local details and low frequencies to global structure. The core idea is therefore to disentangle the high- and low-frequency information in the feature map and process each with its own attention mechanism, improving the efficiency of Vision Transformers on high-resolution images.
HiLo Attention splits MSA into two paths: one path encodes high-frequency interactions via local self-attention on a relatively high-resolution feature map, while the other encodes low-frequency interactions via global attention on a downsampled feature map, which greatly improves efficiency. For a given input feature map:
- Head Splitting: the heads of the MSA layer are first split into two groups according to a preset ratio α, one group for Hi-Fi and the other for Lo-Fi.
- High-Frequency Attention (Hi-Fi): in the upper path, one group of heads is assigned to high-frequency attention (Hi-Fi), which captures fine-grained high-frequency information via local window self-attention (e.g., 2 × 2 windows). Hi-Fi focuses on the local details of the image and works on the high-resolution feature map.
- Low-Frequency Attention (Lo-Fi): the lower path implements low-frequency attention (Lo-Fi). Average pooling is first applied to each window to obtain the low-frequency signal, and the remaining heads are assigned to Lo-Fi to model the relationship between every query position of the input feature map and the average-pooled low-frequency keys and values from each window. Lo-Fi focuses on the global structure of the image and works on the downsampled, low-resolution feature map.
- Output: finally, the refined Hi-Fi and Lo-Fi outputs are concatenated. Beyond splitting the work across two paths, shortening the key and value sequences in Lo-Fi brings a significant reduction in complexity; a shape-level sketch of both paths follows this list.
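To make the two branches concrete, here is a minimal shape-level sketch in PyTorch. It is not the HiLo module itself (the full implementation is given in Section 3), and the sizes (512 channels, 8 heads, α = 0.5, window size s = 2 on a 16×16 map) are assumptions chosen to match the test code later in this post.

import torch
import torch.nn as nn

B, C, H, W = 4, 512, 16, 16           # assumed sizes, matching the test code in Section 3
num_heads, alpha, s = 8, 0.5, 2

# Head splitting: a fraction alpha of the heads go to Lo-Fi, the rest to Hi-Fi.
l_heads = int(num_heads * alpha)      # 4 heads for Lo-Fi
h_heads = num_heads - l_heads         # 4 heads for Hi-Fi

x = torch.randn(B, C, H, W)

# Hi-Fi: attention is restricted to non-overlapping s x s windows, so each
# query attends to only s*s = 4 tokens instead of all H*W = 256.
windows = x.reshape(B, C, H // s, s, W // s, s)
print("Hi-Fi window grid:", windows.shape)         # [4, 512, 8, 2, 8, 2]

# Lo-Fi: each window is average-pooled to build the low-frequency keys/values,
# so the K/V length shrinks from H*W to (H/s)*(W/s) while queries keep full length.
kv_src = nn.AvgPool2d(kernel_size=s, stride=s)(x)  # [4, 512, 8, 8]
print("Lo-Fi K/V length:", kv_src.shape[-2] * kv_src.shape[-1], "vs query length:", H * W)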
HiLo Attention architecture diagram:
2. LITv2
Building on HiLo Attention, the paper proposes a new ViT architecture, LITv2. Its overall structure is largely the same as LITv1; the difference is that LITv2 replaces the original relative positional encoding with a 3×3 depthwise convolution layer, so that positional information is learned implicitly from the zero padding. This improves speed, enlarges the receptive field of the early MLP blocks, and effectively boosts the efficiency and performance of the Vision Transformer.
LITv2 architecture diagram:
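As a rough illustration of the idea above (a toy sketch, not the official LITv2 code; the class name ConvMLP and its sizes are hypothetical), a 3×3 depthwise convolution with zero padding can sit between the two linear layers of an MLP/FFN block, so that position is encoded implicitly by the padding while the convolution enlarges the receptive field:

import torch
import torch.nn as nn

class ConvMLP(nn.Module):
    """Toy FFN block: a zero-padded 3x3 depthwise conv between the two linear
    layers stands in for explicit relative positional encoding. Illustrative
    sketch only, not the official LITv2 implementation."""
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.dwconv = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3,
                                padding=1, groups=hidden_dim)  # depthwise, zero-padded
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x, H, W):
        # x: [B, N, C] with N = H * W
        B, N, C = x.shape
        x = self.fc1(x)
        x = x.transpose(1, 2).reshape(B, -1, H, W)  # to [B, hidden, H, W] for the conv
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)            # back to [B, N, hidden]
        x = self.fc2(self.act(x))
        return x

# quick shape check
mlp = ConvMLP(dim=512, hidden_dim=2048)
tokens = torch.randn(2, 16 * 16, 512)
print(mlp(tokens, 16, 16).shape)  # torch.Size([2, 256, 512])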
3. Code Implementation
import math

import torch
import torch.nn as nn
from einops import rearrange
class HiLo(nn.Module):
"""
这个模块 要求输入 的维度是 [B, N, C] N=H*W
所以,对于 [B, C, H, W]的张量,需要先转换 维度 ,再进行处理
转换维度:
from einops.einops import rearrange
[B, C, H, W]->[B, H*W, C] : rearrange(x, 'b c h w -> b (h w) c')
[B, H*W, C]->[B, C, H, W] : rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., window_size=2,
alpha=0.5):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
head_dim = int(dim / num_heads)
self.dim = dim
# self-attention heads in Lo-Fi
self.l_heads = int(num_heads * alpha)
# token dimension in Lo-Fi
self.l_dim = self.l_heads * head_dim
# self-attention heads in Hi-Fi
self.h_heads = num_heads - self.l_heads
# token dimension in Hi-Fi
self.h_dim = self.h_heads * head_dim
# local window size. The `s` in our paper.
self.ws = window_size
if self.ws == 1:
# ws == 1 is equal to a standard multi-head self-attention
self.h_heads = 0
self.h_dim = 0
self.l_heads = num_heads
self.l_dim = dim
self.scale = qk_scale or head_dim ** -0.5
# Low frequence attention (Lo-Fi)
if self.l_heads > 0:
if self.ws != 1:
self.sr = nn.AvgPool2d(kernel_size=window_size, stride=window_size)
self.l_q = nn.Linear(self.dim, self.l_dim, bias=qkv_bias)
self.l_kv = nn.Linear(self.dim, self.l_dim * 2, bias=qkv_bias)
self.l_proj = nn.Linear(self.l_dim, self.l_dim)
# High frequence attention (Hi-Fi)
if self.h_heads > 0:
self.h_qkv = nn.Linear(self.dim, self.h_dim * 3, bias=qkv_bias)
self.h_proj = nn.Linear(self.h_dim, self.h_dim)
def hifi(self, x):
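        # Hi-Fi: partition the feature map into non-overlapping ws x ws windows and
        # run multi-head self-attention inside each window to capture local,
        # high-frequency detail. Assumes H and W are divisible by self.ws.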
B, H, W, C = x.shape
h_group, w_group = H // self.ws, W // self.ws
total_groups = h_group * w_group
x = x.reshape(B, h_group, self.ws, w_group, self.ws, C).transpose(2, 3)
qkv = self.h_qkv(x).reshape(B, total_groups, -1, 3, self.h_heads, self.h_dim // self.h_heads).permute(3, 0, 1,
4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2] # B, hw, n_head, ws*ws, head_dim
attn = (q @ k.transpose(-2, -1)) * self.scale # B, hw, n_head, ws*ws, ws*ws
attn = attn.softmax(dim=-1)
attn = (attn @ v).transpose(2, 3).reshape(B, h_group, w_group, self.ws, self.ws, self.h_dim)
x = attn.transpose(2, 3).reshape(B, h_group * self.ws, w_group * self.ws, self.h_dim)
x = self.h_proj(x)
return x
def lofi(self, x):
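        # Lo-Fi: queries come from every position, while keys/values are taken from
        # the average-pooled (ws x ws) feature map, shrinking the K/V length from
        # H*W to (H/ws)*(W/ws) and capturing global, low-frequency structure.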
B, H, W, C = x.shape
q = self.l_q(x).reshape(B, H * W, self.l_heads, self.l_dim // self.l_heads).permute(0, 2, 1, 3)
if self.ws > 1:
x_ = x.permute(0, 3, 1, 2)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
kv = self.l_kv(x_).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.l_kv(x).reshape(B, -1, 2, self.l_heads, self.l_dim // self.l_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, H, W, self.l_dim)
x = self.l_proj(x)
return x
def forward(self, x, H, W):
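        # x: [B, N, C] with N = H * W; run whichever branches have heads assigned
        # and concatenate their outputs along the channel dimension.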
B, N, C = x.shape
x = x.reshape(B, H, W, C)
if self.h_heads == 0:
x = self.lofi(x)
return x.reshape(B, N, C)
if self.l_heads == 0:
x = self.hifi(x)
return x.reshape(B, N, C)
hifi_out = self.hifi(x)
lofi_out = self.lofi(x)
x = torch.cat((hifi_out, lofi_out), dim=-1)
x = x.reshape(B, N, C)
return x
def flops(self, H, W):
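        # Counts only the attention-specific operations (qkv / kv projections, the
        # two attention matmuls, and the output projections); softmax is ignored.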
# pad the feature map when the height and width cannot be divided by window size
Hp = self.ws * math.ceil(H / self.ws)
Wp = self.ws * math.ceil(W / self.ws)
Np = Hp * Wp
# For Hi-Fi
# qkv
hifi_flops = Np * self.dim * self.h_dim * 3
nW = (Hp // self.ws) * (Wp // self.ws)
window_len = self.ws * self.ws
# q @ k and attn @ v
window_flops = window_len * window_len * self.h_dim * 2
hifi_flops += nW * window_flops
# projection
hifi_flops += Np * self.h_dim * self.h_dim
# for Lo-Fi
# q
lofi_flops = Np * self.dim * self.l_dim
kv_len = (Hp // self.ws) * (Wp // self.ws)
# k, v
lofi_flops += kv_len * self.dim * self.l_dim * 2
# q @ k and attn @ v
lofi_flops += Np * self.l_dim * kv_len * 2
# projection
lofi_flops += Np * self.l_dim * self.l_dim
return hifi_flops + lofi_flops
if __name__ == '__main__':
H, W = 16, 16
x = torch.randn(4, 512, 16, 16).cuda()
x = rearrange(x, 'b c h w -> b (h w) c')
model = HiLo(512).cuda()
out = model(x, 16, 16)
out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)
print(out.shape)
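    # expected output: torch.Size([4, 512, 16, 16])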
This post only pulls together the plug-and-play module from the paper, and some details are inevitably glossed over. To understand the module more thoroughly, it is well worth reading the original paper; you will certainly get more out of it.