首页 > 其他分享 >【Block总结】门控结构的MLP结构

【Block总结】门控结构的MLP结构

时间:2025-01-08 19:31:06浏览次数:3  
标签:__ features 特征 self MLP 维度 Gate 门控 Block

模块

记录一个具有门控模块的MLP,这个模块可以降低MLP的参数量,还可以提高模型的精度,很多模型都用到了这样的结构,代码如下:

class Gate(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) # DW Conv

    def forward(self, x, H, W):
        # Split
        x1, x2 = x.chunk(2, dim = -1)
        B, N, C = x.shape
        x2 = self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C//2, H, W)).flatten(2).transpose(-1, -2).contiguous()

        return x1 * x2


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.sg = Gate(hidden_features//2)
        self.fc2 = nn.Linear(hidden_features//2, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """
        Input: x: (B, H*W, C), H, W
        Output: x: (B, H*W, C)
        """
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)

        x = self.sg(x, H, W)
        x = self.drop(x)

        x = self.fc2(x)
        x = self.drop(x)
        return x

这个代码定义了两个类:GateMLP,它们都是基于PyTorch框架实现的神经网络模块。下面是对这两个类及其功能的讲解:

Gate 类

Gate类是一个自定义的神经网络层,主要用于对输入特征进行特定的变换。它的结构如下:

  • 初始化 (__init__ 方法):

    • dim: 输入特征的维度。
    • self.norm: 一个Layer Normalization层,用于对输入特征进行归一化处理,有助于加速训练过程并提高模型的稳定性。
    • self.conv: 一个深度可分离卷积层(Depthwise Convolution,通过设置groups=dim实现),卷积核大小为3x3,步长为1,填充为1。这意味着卷积操作在每个输入通道上独立进行,有助于捕捉局部特征。
  • 前向传播 (forward 方法):

    • 输入x的形状为(B, N, C),其中B是批次大小,N是特征的数量(可能是空间维度H*W的展平),C是特征维度。
    • x被沿着最后一个维度分成两部分x1x2
    • x2首先经过Layer Normalization,然后重塑并转置以适配卷积层的输入要求,接着进行深度可分离卷积,最后再次重塑和转置以恢复原始形状的一部分维度。
    • 输出是x1与变换后的x2的逐元素乘积,这种操作可能有助于特征之间的交互和信息流动。

MLP 类

MLP类是一个多层感知机(Multilayer Perceptron),其结构如下:

  • 初始化 (__init__ 方法):

    • in_features: 输入特征的维度。
    • hidden_features: 隐藏层的特征维度,默认为输入特征的维度。
    • out_features: 输出特征的维度,默认为输入特征的维度。
    • act_layer: 激活函数层,默认为GELU(Gaussian Error Linear Unit)。
    • drop: Dropout比率,用于减少过拟合。
    • self.fc1: 第一个全连接层,将输入特征映射到隐藏层特征空间。
    • self.act: 激活函数层。
    • self.sg: 一个Gate层,对隐藏层特征的一部分进行特定的变换。
    • self.fc2: 第二个全连接层,将变换后的隐藏层特征映射到输出特征空间。
    • self.drop: Dropout层。
  • 前向传播 (forward 方法):

    • 输入x的形状为(B, H*W, C),其中B是批次大小,H*W是空间维度的展平,C是特征维度。
    • 输入x首先经过第一个全连接层、激活函数层和Dropout层。
    • 然后,x被传递给Gate层进行处理,这一步可能涉及特征的重新组合和局部信息的捕捉。
    • 经过Gate层处理后,x再次经过Dropout层,然后传递给第二个全连接层。
    • 最后,输出经过另一个Dropout层处理,得到最终的输出。

总的来说,这个MLP类通过结合全连接层、激活函数、Dropout和自定义的Gate层,实现了一个具有复杂特征变换能力的多层感知机,适用于处理具有空间维度的特征数据。

在这里插入图片描述

标签:__,features,特征,self,MLP,维度,Gate,门控,Block
From: https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/144981245

相关文章

  • 【Block总结】SGE注意力机制
    一、论文介绍论文链接:https://arxiv.org/pdf/1905.09646研究背景:论文首先提及了在计算机视觉领域,特征分组的思想由来已久,并介绍了相关背景。研究目的:旨在通过引入SGE模块,改善特征图的空间分布,提升模型对特定语义特征的表示能力。实验平台:实验代码和预训练模型可在https://......
  • 【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方
    【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方式?附代码(二)【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速水平集演化结合的方式?附代码(二)文章目录【深度学习|变化检测】如何理解基于门控注意力的池化层及其与快速......
  • YOLO11改进:block优化 | PKIBlock多尺度卷积核,助力小目标涨点 | CVPR2024 PKINet 遥感
     ......
  • 【GreatSQL优化器-09】make_join_query_block
    【GreatSQL优化器-09】make_join_query_block一、make_join_query_block介绍GreatSQL优化器对于多张表join的连接顺序在前面的章节介绍过的best_access_path函数已经执行了,接着就是把where条件进行切割然后推给合适的表。这个过程就是由函数make_join_query_block来执行的。下......
  • 【Block总结】CrossFormerBlock
    论文介绍链接:https://arxiv.org/pdf/2108.00154CrossFormerBlock模块提出:论文提出了一种名为CrossFormer的视觉Transformer模型,其中重点介绍了CrossFormerBlock模块的设计。研究背景:针对视觉任务中自注意力模块计算成本高、难以处理跨尺度交互的问题,CrossFormerBlock模块......
  • 在 PowerShell 中实时监控与 SMB(Server Message Block)协议相关的所有活动和功能,通常可
    在PowerShell中实时监控与SMB(ServerMessageBlock)协议相关的所有活动和功能,通常可以通过以下几个方式来实现:1. 监控SMB共享的访问可以通过Get-SmbSession和Get-SmbShare等cmdlet来查看SMB共享的活动状态。这些cmdlet允许你获取有关当前SMB会话、共享、客户端......
  • CBAM (Convolutional Block Attention Module)注意力机制详解
    定义与起源CBAM(ConvolutionalBlockAttentionModule)是一种专为卷积神经网络(CNN)设计的注意力机制,旨在增强模型对关键特征的捕捉能力。这一创新概念首次出现在2018年的研究论文《CBAM:ConvolutionalBlockAttentionModule》中。CBAM的核心思想是在通道和空间两个维......
  • 【AI学习笔记5】用C语言实现一个最简单的MLP网络 A simple MLP Neural network in C
    用C语言实现一个最简单的MLP网络AsimpleMLPNeural NetworkinClanguage 从图像中识别英文字母【1】从图像中识别多个不同的数字,属于多分类问题;每个图像是5*5的像素矩阵,分别包含1-5五个字母数字; 网络结构:一个隐藏层的MLP网络;       每个图像是5x5个......
  • Python机器学习算法KNN、MLP、NB、LR助力油气钻井大数据提速参数优选及模型构建研究
    全文链接:https://tecdat.cn/?p=38601原文出处:拓端数据部落公众号分析师:HuayanMu随着机器学习和大数据分析技术的发展,帮助客户进行油气行业数字化转型势在必行,钻井提速参数优选呈现由经验驱动、逻辑驱动向数据驱动转变的趋势。机械钻速最大化、机械比能最小化是钻井过程中常考......
  • Oracle 20c Native Blockchain Table vs. Hyperledger
     一、OracleNativeBlockchain(甲骨文原生区块链)(一)特点紧密集成数据库OracleNativeBlockchain与Oracle数据库紧密集成。这意味着对于已经在使用Oracle数据库的企业来说,能够很方便地利用现有基础设施。例如,企业的ERP(企业资源规划)系统等依......