首页 > 其他分享 >(即插即用模块-Attention部分) 三十三、(2021) SPA 显著位置注意力

(即插即用模块-Attention部分) 三十三、(2021) SPA 显著位置注意力

时间:2025-01-10 10:58:55浏览次数:3  
标签:__ 位置 torch self Attention 矩阵 2021 SPA size

在这里插入图片描述

文章目录

paper:Salient Positions based Attention Network for Image Classification

Code:https://github.com/likyoo/SPANet


1、Salient Positions Attention

在现有的自注意力机制中,其建模长距离依赖关系方面表现出色,但其计算复杂度和内存需求巨大,限制了其在实际应用中的使用。此外,论文还指出了并非所有从全局范围内收集的信息都对上下文建模有益。例如,背景中的纹理信息可能会干扰模型的判断。所以,这篇论文提出一种 显著位置注意力(Salient Positions Attention) 来解决前面提出的问题,并且降低计算复杂度和内存需求。

SPA 的基本思想是通过使用显著位置选择算法,仅选择有限的显著点进行注意力图计算。这种方法不仅可以节省大量内存和计算资源,还可以尝试从输入特征图的变换中提取积极信息。与非局部块方法不同,SPA 是使用选定的位置而不是所有位置来建模上下文信息,沿着通道维度而不是空间维度进行操作。

对于输入X,SPA 的实现过程:

  1. 特征图转换: 首先将输入特征图通过两个二维卷积层转换为查询矩阵 Q 和值矩阵 V。
  2. 显著位置选择: 使用显著位置选择 (SPS) 算法选择查询矩阵 Q 中 top-k 个显著位置。其中,SPS 算法用于从查询矩阵中选取最具代表性的位置进行注意力计算。其核心思想是利用查询矩阵的特征,选择出最具信息量的位置,从而减少计算量并提高模型性能。
  3. 注意力图计算: 使用选定的数据计算亲和矩阵 A。
  4. 上下文信息聚合: 将值矩阵 V 与亲和矩阵 A 相乘,并重塑为 c×h×w 的形状。
  5. 特征图融合: 最后使用 1×1 卷积层对结果进行变换,并将其添加到输入特征图中。

SPA 的实现过程中,SPS实现过程:

  1. 计算查询矩阵的转置矩阵的平方幂: 首先,将查询矩阵 Q 进行转置,然后计算每个位置上所有通道的平方和,得到一个 c x (hw) 的矩阵 Qpow。
  2. 对 Qpow 按照通道维度求和: 将 Qpow 按照通道维度求和,得到一个 1 x (h*w) 的矩阵,表示每个位置上所有通道的平方和的总和。
  3. 选择最大的 k 个位置: 在 1 x (h*w) 的矩阵中,选择最大的 k 个位置,得到它们的索引 indexk。
  4. 返回选取的 k 个位置: 将查询矩阵 Q 的第 c 行对应 indexk 的位置上的元素提取出来,组成一个新的 c x k 的矩阵 K,作为 SPS 算法的输出。

Salient Positions Attention 结构图:
在这里插入图片描述

2、代码实现

import torch
import torch.nn as nn


class SPABlock(nn.Module):
    def __init__(self, in_channels, k=8, adaptive = False, reduction=16, learning=False, mode='pow'):
        super(SPABlock, self).__init__()
        self.in_channels = in_channels
        self.reduction = reduction
        self.k = k
        self.adptive = adaptive
        self.reduction = reduction
        self.learing = learning
        if self.learing is True:
            self.k = nn.Parameter(torch.tensor(self.k))

        self.mode = mode
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def forward(self, x, return_info=False):
        input_shape = x.shape
        if len(input_shape)==4:
            x = x.view(x.size(0), self.in_channels, -1)
            x = x.permute(0, 2, 1)
        batch_size,N = x.size(0),x.size(1)

        #(B, H*W,C)
        if self.mode == 'pow':
            x_pow = torch.pow(x,2)# (batchsize,H*W,channel)
            x_powsum = torch.sum(x_pow,dim=2)# (batchsize,H*W)

        if self.adptive is True:
            self.k = N//self.reduction
            if self.k == 0:
                self.k = 1

        outvalue, outindices = x_powsum.topk(k=self.k, dim=-1, largest=True, sorted=True)

        outindices = outindices.unsqueeze(2).expand(batch_size, self.k, x.size(2))
        out = x.gather(dim=1, index=outindices).to(self.device)

        if return_info is True:
            return out, outindices, outvalue
        else:
            return out


if __name__ == '__main__':
    """
    输入:[B, C, H, W] / [B, H*W, C]
    输出:[B, k, C]
    """
    x = torch.randn(4, 512, 7, 7)
    model = SPABlock(in_channels=512, k=7*7)
    output = model(x)
    print(output.shape)

标签:__,位置,torch,self,Attention,矩阵,2021,SPA,size
From: https://blog.csdn.net/wei582636312/article/details/144792438

相关文章

  • 【Spark SQL】Join连接条件使用or导致运行慢
    现象运行的SQL示例如下selectt1.*fromedw.at1leftjoinedw.bt2on(t1.id=t2.idor((t1.idisnullort2.idisnull)andt1.phone=t2.phone))andt1.province=t2.provinceandt1.city=t2.cityandt1.type=t2.typewheret2.typeisnull;提交运行......
  • 向量空间 Vector Spaces
    向量空间VectorSpaces​ 在GilbertStrang教授的书中,提到了导数的转置(TheTransposeofaDerivative)。在正式的向量空间内容之前,可以先了解一下导数与矩阵转置的联系。​ 考虑将矩阵看做一个运算符(或者说,算子),对于函数\(x(t)\)的线性代数。假设\(\symbfit{A}=\mathrm{d}/\mat......
  • day05_Spark SQL
    文章目录day05_SparkSQL课程笔记一、今日课程内容二、SparkSQL基本介绍(了解)1、什么是SparkSQL**为什么SparkSQL是“SQL与大数据之间的桥梁”?****实际意义**为什么要学习SparkSQL呢?**为什么SparkSQL像“瑞士军刀”?**2、SparkSQL与HIVE异同3、SparkSQL的数......
  • day06_Spark SQL
    文章目录day06_SparkSQL课程笔记一、今日课程内容二、DataFrame详解(掌握)5.清洗相关的API6.SparkSQL的Shuffle分区设置7.数据写出操作写出到文件写出到数据库三、SparkSQL的综合案例(掌握)1、常见DSL代码整理2、电影分析案例需求说明:需求分析:四、SparkSQL函数定义......
  • 对于open_space_roi_decider.cc的解析
    路径modules\planning\planning_open_space\utils\源码/*******************************************************************************Copyright2023TheApolloAuthors.AllRightsReserved.**LicensedundertheApacheLicense,Version2.0(the&quo......
  • 【Apache Paimon】-- 14 -- Spark 集成 Paimon 之 Filesystem Catalog 与 Hive Catalo
    目录1.背景介绍2.环境准备2.1、技术栈说明2.2、环境依赖2.3、硬件与软件环境2.4、主要工具清单2.5、Maven项目结构2.6、mavenpom.xml依赖3.Spark与Paimon FilesystemCatalog集成3.1、HDFSFileSystemcatalog3.1.1、代码内容3.1.2、运行输出结果3.1.2.......
  • Spark 源码分析(一) SparkRpc中序列化与反序列化Serializer的抽象类解读 (java序列化部
    目录(3)JavaSerializerInstance定义了一个Java序列化实例(1)构造方法参数(2)方法1:serializeStream(3)方法2:deserializeStreamdefaultClassLoader(4)方法3:deserializeStreamloader(5)方法4:serialize(6)方法5:deserializeloader(7)方法6:deserializedefaul......
  • 移民统计年鉴(1996-2021年)-社科数据
    移民统计年鉴(1996-2021年)-社科数据https://download.csdn.net/download/paofuluolijiang/90028564https://download.csdn.net/download/paofuluolijiang/90028564移民统计年鉴(1996-2021年)提供了一个全面的视角,以了解全球移民趋势和数据。这份年鉴详细记录了每年的全球移民......
  • P7603 [THUPC2021] 鬼街 题解
    P7603[THUPC2021]鬼街题解第一次见折半报警器的trick,记录一下首先观察到\(x\len\le10^5\),所以\(x\)最多有6个质因数,\(x=30030\)可以取到,这使得对于修改,我们可以暴力单点修改。接下来考虑询问,朴素的做法是:每一次灵异事件之后,都对所有监控器进行检验是否满足和......
  • Spark(一):初识Spark
    哈喽,大家好,我是Leven,今天我们花点时间初步了解大数据计算引擎Spark,也是我们从事数据工作中肯定会用到计算引擎。文章中有书写错误的内容,辛苦评论指正,感谢......