MOGANET-SA Module


Paper: MogaNet (Multi-order Gated Aggregation Network)
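
The code below implements MogaNet's SA (spatial aggregation) block. It first applies feature decomposition, x = act(x + sigma * (x - GAP(x))), where GAP is global average pooling over H and W and sigma is a learnable per-channel scale (ElementScale). It then aggregates context with a gated design, proj_2(SiLU(gate(x)) * SiLU(value(x))) plus a residual shortcut, where the value branch is a multi-order depthwise convolution that splits the channels 1/8 : 3/8 : 4/8 across three kernels with increasing dilation.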
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def build_act_layer(act_type):
    """Build an activation layer."""
    if act_type is None:
        return nn.Identity()
    assert act_type in ['GELU', 'ReLU', 'SiLU']
    if act_type == 'SiLU':
        return nn.SiLU()
    elif act_type == 'ReLU':
        return nn.ReLU()
    else:
        return nn.GELU()
class ElementScale(nn.Module):
    """A learnable element-wise scaler."""

    def __init__(self, embed_dims, init_value=0., requires_grad=True):
        super(ElementScale, self).__init__()
        # Per-channel scale of shape (1, C, 1, 1), broadcast over the batch
        # and spatial dimensions.
        self.scale = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)),
            requires_grad=requires_grad
        )

    def forward(self, x):
        return x * self.scale

class MultiOrderDWConv(nn.Module):
    """Multi-order features with dilated DWConv kernels.

    Args:
        embed_dims (int): Number of input channels.
        dw_dilation (list): Dilations of the three DWConv layers.
        channel_split (list): Relative ratio of the three channel splits.
    """

    def __init__(self,
                 embed_dims,
                 dw_dilation=[1, 2, 3,],
                 channel_split=[1, 3, 4,],
                ):
        super(MultiOrderDWConv, self).__init__()
        # Split the channels 1/8 : 3/8 : 4/8 (for the default [1, 3, 4]).
        self.split_ratio = [i / sum(channel_split) for i in channel_split]
        self.embed_dims_1 = int(self.split_ratio[1] * embed_dims)
        self.embed_dims_2 = int(self.split_ratio[2] * embed_dims)
        self.embed_dims_0 = embed_dims - self.embed_dims_1 - self.embed_dims_2
        self.embed_dims = embed_dims
        assert len(dw_dilation) == len(channel_split) == 3
        assert 1 <= min(dw_dilation) and max(dw_dilation) <= 3
        assert embed_dims % sum(channel_split) == 0
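        # For example, embed_dims=64 with channel_split=[1, 3, 4] gives
        #   embed_dims_1 = 3/8 * 64 = 24   (5x5 DWConv, dilation 2)
        #   embed_dims_2 = 4/8 * 64 = 32   (7x7 DWConv, dilation 3)
        #   embed_dims_0 = 64 - 24 - 32 = 8  (kept from the base 5x5 DWConv)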

        # basic 5x5 DWConv over all channels
        self.DW_conv0 = nn.Conv2d(
            in_channels=self.embed_dims,
            out_channels=self.embed_dims,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[0]) // 2,
            groups=self.embed_dims,
            stride=1, dilation=dw_dilation[0],
        )
        # dilated 5x5 DWConv over the middle channel split
        self.DW_conv1 = nn.Conv2d(
            in_channels=self.embed_dims_1,
            out_channels=self.embed_dims_1,
            kernel_size=5,
            padding=(1 + 4 * dw_dilation[1]) // 2,
            groups=self.embed_dims_1,
            stride=1, dilation=dw_dilation[1],
        )
        # dilated 7x7 DWConv over the last channel split
        self.DW_conv2 = nn.Conv2d(
            in_channels=self.embed_dims_2,
            out_channels=self.embed_dims_2,
            kernel_size=7,
            padding=(1 + 6 * dw_dilation[2]) // 2,
            groups=self.embed_dims_2,
            stride=1, dilation=dw_dilation[2],
        )
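        # Note: padding = ((kernel_size - 1) * dilation) // 2 keeps H and W
        # unchanged, e.g. kernel 7 with dilation 3 gives (1 + 6 * 3) // 2 = 9.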
        # point-wise convolution to mix channels across the splits
        self.PW_conv = nn.Conv2d(
            in_channels=embed_dims,
            out_channels=embed_dims,
            kernel_size=1)

    def forward(self, x):
        # Base 5x5 DWConv over all channels.
        x_0 = self.DW_conv0(x)
        # Dilated DWConvs over the middle and last channel splits of x_0.
        x_1 = self.DW_conv1(
            x_0[:, self.embed_dims_0: self.embed_dims_0 + self.embed_dims_1, ...])
        x_2 = self.DW_conv2(
            x_0[:, self.embed_dims - self.embed_dims_2:, ...])
        # Concatenate the three splits back to the full channel dimension.
        x = torch.cat([
            x_0[:, :self.embed_dims_0, ...], x_1, x_2], dim=1)
        x = self.PW_conv(x)
        return x
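
# Effective receptive fields of the three orders (k + (k - 1) * (d - 1)):
#   5x5, dilation 1 ->  5x5
#   5x5, dilation 2 ->  9x9
#   7x7, dilation 3 -> 19x19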

class MultiOrderGatedAggregation(nn.Module):
    """Spatial block with multi-order gated aggregation.

    Args:
        embed_dims (int): Number of input channels.
        attn_dw_dilation (list): Dilations of the three DWConv layers.
        attn_channel_split (list): Relative ratio of the split channels.
        attn_act_type (str): Activation type for the spatial block.
            Defaults to 'SiLU'.
        attn_force_fp32 (bool): Whether to force the gating computation to
            run in fp32. Defaults to False.
    """

    def __init__(self,
                 embed_dims,
                 attn_dw_dilation=[1, 2, 3],
                 attn_channel_split=[1, 3, 4],
                 attn_act_type='SiLU',
                 attn_force_fp32=False,
                ):
        super(MultiOrderGatedAggregation, self).__init__()

        self.embed_dims = embed_dims
        self.attn_force_fp32 = attn_force_fp32
        self.proj_1 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.gate = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)
        self.value = MultiOrderDWConv(
            embed_dims=embed_dims,
            dw_dilation=attn_dw_dilation,
            channel_split=attn_channel_split,
        )
        self.proj_2 = nn.Conv2d(
            in_channels=embed_dims, out_channels=embed_dims, kernel_size=1)

        # activation for the gating and value branches
        self.act_value = build_act_layer(attn_act_type)
        self.act_gate = build_act_layer(attn_act_type)

        # learnable per-channel scale for the feature decomposition
        self.sigma = ElementScale(
            embed_dims, init_value=1e-5, requires_grad=True)

    def feat_decompose(self, x):
        """Feature decomposition: x = x + sigma * (x - GAP(x))."""
        # 1x1 conv that keeps spatial size and channel count.
        x = self.proj_1(x)
        # x_d: [B, C, H, W] -> [B, C, 1, 1], the per-channel global average.
        x_d = F.adaptive_avg_pool2d(x, output_size=1)
        # Subtract the channel mean, rescale the residual with the learnable
        # sigma, and add it back to the features.
        x = x + self.sigma(x - x_d)
        x = self.act_value(x)
        return x

    def forward_gating(self, g, v):
        # Run the gating in fp32 even under autocast, for numerical stability.
        with torch.autocast(device_type='cuda', enabled=False):
            g = g.to(torch.float32)
            v = v.to(torch.float32)
            return self.proj_2(self.act_gate(g) * self.act_gate(v))

    def forward(self, x):
        shortcut = x.clone()
        # 1x1 projection + feature decomposition
        x = self.feat_decompose(x)
        # gating (left) and value (right) branches
        g = self.gate(x)
        v = self.value(x)
        # aggregation: proj_2(SiLU(g) * SiLU(v))
        if not self.attn_force_fp32:  # default path
            x = self.proj_2(self.act_gate(g) * self.act_gate(v))
        else:
            x = self.forward_gating(self.act_gate(g), self.act_gate(v))
        x = x + shortcut
        return x

if __name__ == '__main__':
    # Input: B, C, H, W
    input = torch.randn(1, 64, 32, 32).cuda()
    block = MultiOrderGatedAggregation(embed_dims=64).cuda()
    output = block(input)
    print(input.size())
    print(output.size())
```
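
The demo above requires a CUDA device. As a minimal CPU-only sketch, assuming the classes above are already defined in the same file, the following checks that the block preserves the input shape:

```python
import torch

# CPU smoke test for the SA block defined above.
block = MultiOrderGatedAggregation(embed_dims=64)
x = torch.randn(2, 64, 32, 32)  # B, C, H, W
y = block(x)
assert y.shape == x.shape  # every conv uses "same" padding, so H and W survive

# Note: embed_dims must be divisible by sum(channel_split) = 8 with the
# default split [1, 3, 4]; e.g. embed_dims=60 fails the assert in
# MultiOrderDWConv.__init__.
```

The shape is preserved because all three depthwise convolutions use "same"-style padding and the remaining layers are 1x1 convolutions.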

From: https://www.cnblogs.com/plumIce/p/18538123
