文章目录
paper:Salient Positions based Attention Network for Image Classification
1、Salient Positions Attention
在现有的自注意力机制中,其建模长距离依赖关系方面表现出色,但其计算复杂度和内存需求巨大,限制了其在实际应用中的使用。此外,论文还指出了并非所有从全局范围内收集的信息都对上下文建模有益。例如,背景中的纹理信息可能会干扰模型的判断。所以,这篇论文提出一种 显著位置注意力(Salient Positions Attention) 来解决前面提出的问题,并且降低计算复杂度和内存需求。
SPA 的基本思想是通过使用显著位置选择算法,仅选择有限的显著点进行注意力图计算。这种方法不仅可以节省大量内存和计算资源,还可以尝试从输入特征图的变换中提取积极信息。与非局部块方法不同,SPA 是使用选定的位置而不是所有位置来建模上下文信息,沿着通道维度而不是空间维度进行操作。
对于输入X,SPA 的实现过程:
- 特征图转换: 首先将输入特征图通过两个二维卷积层转换为查询矩阵 Q 和值矩阵 V。
- 显著位置选择: 使用显著位置选择 (SPS) 算法选择查询矩阵 Q 中 top-k 个显著位置。其中,SPS 算法用于从查询矩阵中选取最具代表性的位置进行注意力计算。其核心思想是利用查询矩阵的特征,选择出最具信息量的位置,从而减少计算量并提高模型性能。
- 注意力图计算: 使用选定的数据计算亲和矩阵 A。
- 上下文信息聚合: 将值矩阵 V 与亲和矩阵 A 相乘,并重塑为 c×h×w 的形状。
- 特征图融合: 最后使用 1×1 卷积层对结果进行变换,并将其添加到输入特征图中。
SPA 的实现过程中,SPS实现过程:
- 计算查询矩阵的转置矩阵的平方幂: 首先,将查询矩阵 Q 进行转置,然后计算每个位置上所有通道的平方和,得到一个 c x (hw) 的矩阵 Qpow。
- 对 Qpow 按照通道维度求和: 将 Qpow 按照通道维度求和,得到一个 1 x (h*w) 的矩阵,表示每个位置上所有通道的平方和的总和。
- 选择最大的 k 个位置: 在 1 x (h*w) 的矩阵中,选择最大的 k 个位置,得到它们的索引 indexk。
- 返回选取的 k 个位置: 将查询矩阵 Q 的第 c 行对应 indexk 的位置上的元素提取出来,组成一个新的 c x k 的矩阵 K,作为 SPS 算法的输出。
Salient Positions Attention 结构图:
2、代码实现
import torch
import torch.nn as nn
class SPABlock(nn.Module):
def __init__(self, in_channels, k=8, adaptive = False, reduction=16, learning=False, mode='pow'):
super(SPABlock, self).__init__()
self.in_channels = in_channels
self.reduction = reduction
self.k = k
self.adptive = adaptive
self.reduction = reduction
self.learing = learning
if self.learing is True:
self.k = nn.Parameter(torch.tensor(self.k))
self.mode = mode
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
def forward(self, x, return_info=False):
input_shape = x.shape
if len(input_shape)==4:
x = x.view(x.size(0), self.in_channels, -1)
x = x.permute(0, 2, 1)
batch_size,N = x.size(0),x.size(1)
#(B, H*W,C)
if self.mode == 'pow':
x_pow = torch.pow(x,2)# (batchsize,H*W,channel)
x_powsum = torch.sum(x_pow,dim=2)# (batchsize,H*W)
if self.adptive is True:
self.k = N//self.reduction
if self.k == 0:
self.k = 1
outvalue, outindices = x_powsum.topk(k=self.k, dim=-1, largest=True, sorted=True)
outindices = outindices.unsqueeze(2).expand(batch_size, self.k, x.size(2))
out = x.gather(dim=1, index=outindices).to(self.device)
if return_info is True:
return out, outindices, outvalue
else:
return out
if __name__ == '__main__':
"""
输入:[B, C, H, W] / [B, H*W, C]
输出:[B, k, C]
"""
x = torch.randn(4, 512, 7, 7)
model = SPABlock(in_channels=512, k=7*7)
output = model(x)
print(output.shape)
标签:__,位置,torch,self,Attention,矩阵,2021,SPA,size
From: https://blog.csdn.net/wei582636312/article/details/144792438