首页 > 其他分享 >MOGANET-CA模块

MOGANET-CA模块

时间:2024-11-09 16:56:48浏览次数:1  
标签:nn CA channels MOGANET 模块 act type self size

paper
`

import torch
import torch.nn as nn

def build_act_layer(act_type):
#Build activation layer
if act_type is None:
return nn.Identity()
assert act_type in ['GELU', 'ReLU', 'SiLU']
if act_type == 'SiLU':
return nn.SiLU()
elif act_type == 'ReLU':
return nn.ReLU()
else:
return nn.GELU()

class ElementScale(nn.Module):
#A learnable element-wise scaler.

def __init__(self, embed_dims, init_value=0., requires_grad=True):
    super(ElementScale, self).__init__()
    self.scale = nn.Parameter(
        init_value * torch.ones((1, embed_dims, 1, 1)),
        requires_grad=requires_grad
    )

def forward(self, x):
    return x * self.scale

class ChannelAggregationFFN(nn.Module):
"""An implementation of FFN with Channel Aggregation.

Args:
    embed_dims (int): The feature dimension. Same as
        `MultiheadAttention`.
    feedforward_channels (int): The hidden dimension of FFNs.
    kernel_size (int): The depth-wise conv kernel size as the
        depth-wise convolution. Defaults to 3.
    act_type (str): The type of activation. Defaults to 'GELU'.
    ffn_drop (float, optional): Probability of an element to be
        zeroed in FFN. Default 0.0.
"""

def __init__(self,
             embed_dims,
             kernel_size=3,
             act_type='GELU',
             ffn_drop=0.):
    super(ChannelAggregationFFN, self).__init__()

    self.embed_dims = embed_dims
    self.feedforward_channels = int(embed_dims * 4)

    self.fc1 = nn.Conv2d(
        in_channels=embed_dims,
        out_channels=self.feedforward_channels,
        kernel_size=1)
    self.dwconv = nn.Conv2d(
        in_channels=self.feedforward_channels,
        out_channels=self.feedforward_channels,
        kernel_size=kernel_size,
        stride=1,
        padding=kernel_size // 2,
        bias=True,
        groups=self.feedforward_channels)
    self.act = build_act_layer(act_type)
    self.fc2 = nn.Conv2d(
        in_channels=self.feedforward_channels,
        out_channels=embed_dims,
        kernel_size=1)
    self.drop = nn.Dropout(ffn_drop)

    self.decompose = nn.Conv2d(
        in_channels=self.feedforward_channels,  # C -> 1
        out_channels=1, kernel_size=1,
    )
    self.sigma = ElementScale(
        self.feedforward_channels, init_value=1e-5, requires_grad=True)
    self.decompose_act = build_act_layer(act_type)

def feat_decompose(self, x):
    # x_d: [B, C, H, W] -> [B, 1, H, W]
    t=self.decompose(x)  #  将多通道用一个通道来表示
    t=self.decompose_act(t) # 对单通道应用GELU激活函数 增加非线性  帮助模型学习更加复杂的模式
    t=x - t #原始特征图减去t  去除或削弱了x与temp相似的特征  如果t是全局或主要特征 那么x-t可以理解成局部或差异信息
    t=self.sigma(t) # 包含一个可学习的参数 用来调整每个通道的权重
    x = x + t  # 将原特征图和缩放之后的特定信息相加
    return x

def forward(self, x):
    # proj 1
    x = self.fc1(x)
    x = self.dwconv(x)
    x = self.act(x)
    x = self.drop(x)
    # proj 2
    x = self.feat_decompose(x)
    x = self.fc2(x)
    x = self.drop(x)
    return x

if name == 'main':
input = torch.randn(1, 64, 32, 32).cuda()# 输入 B C H W
block = ChannelAggregationFFN(embed_dims=64).cuda()
output = block(input)
print(input.size())
print(output.size())

`

标签:nn,CA,channels,MOGANET,模块,act,type,self,size
From: https://www.cnblogs.com/plumIce/p/18536973

相关文章

  • [COCI2022-2023#5] Slastičarnica 题解
    前言题目链接:洛谷。题意简述一个长为\(n\)的序列\(\{a_n\}\)和\(q\)次操作,第\(i\)次操作中,你可以删除序列长为\(d_i\)的前缀或后缀,并需要保证删除的所有数\(\geqs_i\)。每次操作前,你可以选择任意长度的前缀或后缀,并将其删除,也可以不操作。请问,在你不能进行下一次操......
  • EMCAD: Efficient Multi-scale Convolutional Attention Decoding for Medical Image
    论文代码`importtorchimporttorch.nnasnnfromfunctoolsimportpartialfromtorch.nn.initimporttrunc_normal_importmathfromtimm.models.helpersimportnamed_applydefact_layer(act,inplace=False,neg_slope=0.2,n_prelu=1):#activationlayeract=......
  • (Lin的实施运维笔记06)解决Tomcat服务器在控制台窗口中的乱码问题
    产生乱码的根本原因就是编码和解码不一致,比较常见的编码格式有Unicode、ASCll码、GBK、UTF-8等,Tomcat控制台的乱码问题只需要把日志配置文件中的UTF-8格式改成GBK格式就行解决方法:1、找到Tomcat的安装目录下conf文件夹2、打开conf文件夹中的logging.properties文件,并搜索找......
  • 第一章:实现基础 HTTP 服务器-MiniTomcat系列
    上一章内容MiniTomcat项目大纲第一章:实现基础HTTP服务器-MiniTomcat系列在这一章中,我们将从零开始编写一个简单的HTTP服务器。这个服务器的基本功能是监听一个端口,接收来自客户端的HTTP请求,并返回一个HTTP响应。我们将使用Java的ServerSocket类来实现网络监......
  • 知识点:用例图(Use Case Diagram)
    知识点:该题目考查的是面向对象的分析与设计方法(Object-OrientedAnalysisandDesign,OOAD),特别是用例图(UseCaseDiagram)的相关知识点。用例图是UML(统一建模语言)中的一种图表,用于描述系统的功能需求,它展示了系统如何与外部用户或其他系统交互。知识点相关内容:用例(UseCase):用......
  • DDCA —— 缓存(Cache):缓存体系结构、缓存操作
    1.存储器层次(TheMemoryHierarchy)1.1现代系统中的存储器其中包括L1、L2、L3和DRAM1.2存储器的局限理想存储器的需求如下:零延迟容量无限零成本带宽无限零功耗但理想存储器的需求彼此冲突:容量更大的存储器意味着更大的延迟:需要花更长的时间来确定数据所在位置更......
  • bert-base-uncased处理文档
    1.安装必要的库确保安装transformers和torch库:pipinstalltransformerstorch2.加载本地BERT模型和分词器由于已将模型和分词器下载到本地,可以指定文件路径加载。确保路径与本地文件结构一致。fromtransformersimportBertTokenizer,BertModel#指定模型和分......
  • SCAU 高级程序设计语言 教材习题
    SCAU高级程序设计语言教材习题第三章18041分期还款(加强版)Description从银行贷款金额为\(d\),准备每月还款额为\(p\),月利率为\(r\)。请编写程序输入这三个数值,计算并输出多少个月能够还清贷款,输出时保留\(1\)位小数。如果无法还清,请输出“God”计算公式如下:\[m=......
  • Web缓存中毒(Web Cache Poisoning)是一种网络攻击技术,攻击者通过篡改或伪造Web服务器的
    Web缓存中毒(WebCachePoisoning)是一种网络攻击技术,攻击者通过篡改或伪造Web服务器的缓存内容,使得用户在访问网站时,获得恶意内容或错误内容的攻击方式。这种攻击通常依赖于利用Web缓存的设计缺陷或未充分验证的请求参数,从而让缓存服务器存储并返回恶意的、篡改过的响应。工作原理......
  • 【Tomcat】Tomat 处理请求的过程(图解)
    1 前言最近在复习Tomcat的请求处理过程,之前也看过一些局部的细节,【SpringBoot+Tomcat】【一】请求到达后端服务进程后的处理过程-连接器的创建和执行、【SpringBoot+Tomcat】【二】请求到达后端服务进程后的处理过程-连接的处理细节,但是没看完整,这节我们从整体看一下Tom......