首页 > 编程语言 >【即插即用】ShuffleAttention注意力机制(附源码)

【即插即用】ShuffleAttention注意力机制(附源码)

时间:2024-04-09 18:59:57浏览次数:30  
标签:nn self torch ShuffleAttention init 源码 即插即用 SA channel

原文链接:

https://arxiv.org/pdf/2102.00240.pdf

源码地址:

https://github.com/wofmanaf/SA-Ne

摘要简介:

注意力机制让神经网络能够准确关注输入的所有相关元素,已成为提高深度神经网络性能的关键组件。在计算机视觉研究中,主要有两种广泛使用的注意力机制:空间注意力和通道注意力。它们分别旨在捕捉像素级的成对关系和通道依赖性。虽然将它们融合在一起可能比单独使用它们表现更好,但这会不可避免地增加计算开销。

在本文中,我们提出了一个高效的Shuffle Attention(SA)模块来解决这个问题。该模块采用Shuffle单元有效地结合了两种注意力机制。具体来说,SA首先将通道维度分组为多个子特征,然后并行处理它们。接着,对于每个子特征,SA使用Shuffle单元来描绘空间和通道维度上的特征依赖性。之后,所有子特征被聚合,并采用“通道Shuffle”操作符来使不同子特征之间的信息得以交流。

提出的SA模块既高效又有效。例如,与骨干网络ResNet50相比,SA的参数和计算量分别为300与25.56M,以及2.76e-3 GFLOPs与4.12 GFLOPs,但Top-1准确率提高了1.34%以上。在ImageNet-1k分类、MS COCO目标检测和实例分割等常用基准测试上的大量实验结果表明,所提出的SA在保持较低模型复杂度的同时,显著优于当前的SOTA方法,实现了更高的精度。

模型结构图:

Pytorch版源码:
import torch
from torch import nn
from torch.nn import init
from torch.nn.parameter import Parameter


class ShuffleAttention(nn.Module):

    def __init__(self, channel=512, G=8):
        super().__init__()
        self.G = G
        self.channel = channel
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.gn = nn.GroupNorm(channel // (2 * G), channel // (2 * G))
        self.cweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.cbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sweight = Parameter(torch.zeros(1, channel // (2 * G), 1, 1))
        self.sbias = Parameter(torch.ones(1, channel // (2 * G), 1, 1))
        self.sigmoid = nn.Sigmoid()

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    @staticmethod
    def channel_shuffle(x, groups):
        b, c, h, w = x.shape
        x = x.reshape(b, groups, -1, h, w)
        x = x.permute(0, 2, 1, 3, 4)

        # 扁平化
        x = x.reshape(b, -1, h, w)

        return x

    def forward(self, x):
        b, c, h, w = x.size()
        # 将通道分成子特征
        x = x.view(b * self.G, -1, h, w)  # bs*G,c//G,h,w

        # 通道分割
        x_0, x_1 = x.chunk(2, dim=1)  # bs*G,c//(2*G),h,w

        # 通道注意力
        x_channel = self.avg_pool(x_0)  # bs*G,c//(2*G),1,1
        x_channel = self.cweight * x_channel + self.cbias  # bs*G,c//(2*G),1,1
        x_channel = x_0 * self.sigmoid(x_channel)

        # 空间注意力
        x_spatial = self.gn(x_1)  # bs*G,c//(2*G),h,w
        x_spatial = self.sweight * x_spatial + self.sbias  # bs*G,c//(2*G),h,w
        x_spatial = x_1 * self.sigmoid(x_spatial)  # bs*G,c//(2*G),h,w

        # 沿通道轴拼接
        out = torch.cat([x_channel, x_spatial], dim=1)  # bs*G,c//G,h,w
        out = out.contiguous().view(b, -1, h, w)

        # 通道混洗
        out = self.channel_shuffle(out, 2)
        return out


if __name__ == '__main__':
    input = torch.randn(2, 32, 512, 512)
    SA = ShuffleAttention(channel=input.size(1))
    output = SA(input)
    print(output.shape)

标签:nn,self,torch,ShuffleAttention,init,源码,即插即用,SA,channel
From: https://blog.csdn.net/weixin_45694817/article/details/137563356

相关文章

  • 【独立版】手边酒店多商户版V2小程序源码部署在线更新SAAS坑位账号
    手边酒店多商户版V2小程序源码:解锁便捷酒店预订新体验在快节奏的现代生活中,我们总是追求更高效、更便捷的生活方式。手边酒店多商户版V2小程序源码,正是为了满足这一需求而诞生的创新之作。这款小程序源码,汇聚了众多酒店商户资源,为用户提供了一个集中、便捷的酒店预订平台。无......
  • starganvc2变声器项目实战及其源码解读
    1.数据与项目文件解读        数据文件目录如下所示,需要注意的是,我们并不能直接对声音进行建模,而需要对声音数据进行预处理,从而得到一系列数值特征,然后对特征进行建模,特征数据存储到processed文件夹中         2.环境配置        pipinstall li......
  • 【全开源】JAVA红娘婚恋相亲交友系统源码支持微信小程序+微信公众号+H5+APP
    JAVA红娘婚恋相亲交友系统源码:跨平台交友新纪元,微信小程序、公众号、H5、APP全覆盖在数字化浪潮汹涌的今天,婚恋相亲已不再是传统的线下模式所能满足。JAVA红娘婚恋相亲交友系统源码,以其卓越的跨平台特性和强大的功能优势,为您打造了一个全新的相亲交友体验。无论是微信小程序、......
  • 【全开源】JAVA上门家政服务系统源码微信小程序+微信公众号+APP+H5
    JAVA上门家政服务系统源码:一站式家政服务,微信小程序、公众号、APP、H5全平台覆盖,便捷生活触手可及在现代生活的快节奏中,人们对家政服务的需求日益旺盛。JAVA上门家政服务系统源码,以其高效、便捷的特性,结合微信小程序、公众号、APP和H5平台,为您打造了一站式的家政服务体验,让您......
  • 【全开源】JAVA红娘婚恋相亲交友系统源码支持微信小程序+微信公众号+H5+APP
    JAVA红娘婚恋相亲交友系统源码:跨平台交友新纪元,微信小程序、公众号、H5、APP全覆盖在数字化浪潮汹涌的今天,婚恋相亲已不再是传统的线下模式所能满足。JAVA红娘婚恋相亲交友系统源码,以其卓越的跨平台特性和强大的功能优势,为您打造了一个全新的相亲交友体验。无论是微信小程序、......
  • 【全开源】JAVA上门家政服务系统源码微信小程序+微信公众号+APP+H5
    JAVA上门家政服务系统源码:一站式家政服务,微信小程序、公众号、APP、H5全平台覆盖,便捷生活触手可及在现代生活的快节奏中,人们对家政服务的需求日益旺盛。JAVA上门家政服务系统源码,以其高效、便捷的特性,结合微信小程序、公众号、APP和H5平台,为您打造了一站式的家政服务体验,让您......
  • 【全开源】JAVA红娘婚恋相亲交友系统源码支持微信小程序+微信公众号+H5+APP
    JAVA红娘婚恋相亲交友系统源码:跨平台交友新纪元,微信小程序、公众号、H5、APP全覆盖在数字化浪潮汹涌的今天,婚恋相亲已不再是传统的线下模式所能满足。JAVA红娘婚恋相亲交友系统源码,以其卓越的跨平台特性和强大的功能优势,为您打造了一个全新的相亲交友体验。无论是微信小程序、......
  • 【全开源】JAVA上门家政服务系统源码微信小程序+微信公众号+APP+H5
    JAVA上门家政服务系统源码:一站式家政服务,微信小程序、公众号、APP、H5全平台覆盖,便捷生活触手可及在现代生活的快节奏中,人们对家政服务的需求日益旺盛。JAVA上门家政服务系统源码,以其高效、便捷的特性,结合微信小程序、公众号、APP和H5平台,为您打造了一站式的家政服务体验,让您......
  • 【全开源】JAVA红娘婚恋相亲交友系统源码支持微信小程序+微信公众号+H5+APP
    JAVA红娘婚恋相亲交友系统源码:跨平台交友新纪元,微信小程序、公众号、H5、APP全覆盖在数字化浪潮汹涌的今天,婚恋相亲已不再是传统的线下模式所能满足。JAVA红娘婚恋相亲交友系统源码,以其卓越的跨平台特性和强大的功能优势,为您打造了一个全新的相亲交友体验。无论是微信小程序、......
  • 【全开源】JAVA上门家政服务系统源码微信小程序+微信公众号+APP+H5
    JAVA上门家政服务系统源码:一站式家政服务,微信小程序、公众号、APP、H5全平台覆盖,便捷生活触手可及在现代生活的快节奏中,人们对家政服务的需求日益旺盛。JAVA上门家政服务系统源码,以其高效、便捷的特性,结合微信小程序、公众号、APP和H5平台,为您打造了一站式的家政服务体验,让您......