首页 > 其他分享 >一文读懂SEnet:如何让机器学习模型学会“重点观察”

一文读懂SEnet:如何让机器学习模型学会“重点观察”

时间:2024-08-03 10:26:23浏览次数:16  
标签:一文 nn self SEnet 读懂 init se tensor

深入探讨一个在图像识别、自然语言处理等众多领域大放异彩的注意力模块——Squeeze-and-Excitation Networks(SEnet)。本文不仅会理论剖析SEnet的核心原理,还会手把手带你完成在TensorFlow和Pytorch这两个主流框架上的代码实现。准备好了吗?一起步入注意力机制的精妙世界。

一、SEnet:深度网络的敏锐洞察力

SEnet首次提出于2017年,它革新了传统神经网络对特征处理的方式。不同于常规的前馈结构,SEnet通过在每个卷积块后添加一个轻量级的注意力模块,动态地为每个通道分配权重,强调重要特征而抑制不相关的信号。这项创新不仅提高了模型效率,还显著提升了在多个基准测试任务上的性能。

论文与代码资源

在深入探讨之前,这里附上SEnet的原始论文链接[1],以及官方提供的代码仓库[2],供各位进一步探索。

二、SE模块的奥秘

1. 输入与结构概述

SE模块处理的输入通常是具有形态为D×H×W×C的数据,其中D、H、W分别代表深度、高度和宽度,C则代表通道数。特别地,在二维图像处理场景中,深度D一般为1。

2. “挤压”(Squeeze)操作

通道压缩

SE模块首先通过全局平均池化将三维空间特征(H×W×C)压扁至一维(1×1×C),这个过程被称为“挤压”。此步骤有效地提取了全局上下文信息,为接下来的特征重标定奠定基础。

3. “激励”(Excitation)操作

特征再赋权
  • 双全连接层:压缩后的特征经过两个全连接层处理。第一个全连接层大幅度降低维度(使用降维比ratio,如默认的2^n),然后通过ReLU激活引入非线性;第二个全连接层将维度恢复到原始通道数C,为每个通道产生一个标量权重。
  • Sigmoid激活:最终,利用Sigmoid函数将得到的权重映射至(0,1)区间,实现对每个通道的重要性判断。

4. “Scale”操作

得到的通道权重被用于逐元素地缩放原始特征图的通道值,强化重要特征,抑制无关特征。

三、代码实现:从理论到实践

1. TensorFlow实现

定义squeeze_excite_block
def squeeze_excite_block(input, ratio=16):
    init = tf.keras.initializers.he_normal()
    filters = input.shape[-1]

    # Squeeze
    se_shape = (1, 1, filters)
    se = tf.keras.layers.GlobalAveragePooling2D()(input)
    se = tf.keras.layers.Reshape(se_shape)(se)
    
    # Excitation
    se = tf.keras.layers.Dense(filters // ratio, activation='relu', kernel_initializer=init)(se)
    se = tf.keras.layers.Dense(filters, activation='sigmoid', kernel_initializer=init)(se)
    
    # Scale
    scaled = tf.keras.layers.multiply([input, se])
    return scaled

2. PyTorch实现

squeeze_excite_block
import torch.nn as nn
from torch import sigmoid

class SqueezeExcite(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(SqueezeExcite, self).__init__()
        reduced_dim = in_channels // ratio
        self.fc1 = nn.Linear(in_channels, reduced_dim)
        self.fc2 = nn.Linear(reduced_dim, in_channels)
        self.init_weights()
    
    def init_weights(self):
        nn.init.kaiming_uniform_(self.fc1.weight)
        nn.init.zeros_(self.fc1.bias)
        nn.init.uniform_(self.fc2.weight, a=-1e-3, b=1e-3)
        nn.init.zeros_(self.fc2.bias)

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        se_tensor = F.avg_pool2d(x, (height, width)).view(batch_size, channels)
        se_tensor = self.fc1(se_tensor)
        se_tensor = F.relu(se_tensor)
        se_tensor = self.fc2(se_tensor)
        se_tensor = sigmoid(se_tensor)
        se_tensor = se_tensor.view(batch_size, channels, 1, 1)
        return x * se_tensor

四、结语

SEnet通过其巧妙的注意力机制,让模型更加专注于关键信息,有效促进了深度学习在诸多领域的应用。本文从理论层面剖析了SE模块的内部机制,并通过示例代码展示了在TensorFlow和Pytorch中的具体实现路径,希望对你理解并运用这一先进工具带来启发。

目前PlugLink发布了开源版和应用版,开源版下载地址:
Github地址:https://github.com/zhengqia/PlugLink
Gitcode地址:https://gitcode.com/zhengiqa8/PlugLink/overview
Gitee地址:https://gitee.com/xinyizq/PlugLink

标签:一文,nn,self,SEnet,读懂,init,se,tensor
From: https://blog.csdn.net/zhengiqa8/article/details/140886429

相关文章

  • 一文搞定:Syncthing多平台文件同步工具安装全攻略
    简介Syncthing是一款开源的文件同步工具,可以通过本地网络或互联网实现多台设备之间的文件同步。与其他同步工具不同,Syncthing强调隐私和安全,确保用户的数据始终处于用户的控制之下。功能与特点开源软件:Syncthing是完全开源的,源代码托管在GitHub上,任何人都可以查看、审查和......
  • 一文掌握Python全部条件执行语句(基础篇)
    前言本文,小编将总结一个非常实用而且非常基础的Python知识点“条件语句”。熟练掌握python条件语句,让你的程序代码做出精准判断,实现智能决策。废话不多说,接下来在正文中,将结合实际代码案例进行详细说明。正文1.if基础语句我们直接看下面的代码示例,如下所示:#假设这是......
  • 一文读懂CST电磁仿软件的TLM算法原理和历史背景
    这期我们免公式地介绍一下TLM原理。TLM(TransmissionLineMethod)是传输线矩阵算法,基于Huygens的波传播模型的三维全波电磁算法,注意是fullwave哦!什么是Huygens原理?惠更斯原理能准确计算波的传播。简单讲就是波传播的最前沿(wavefront)上每个点都可以看作是下一时刻的波的点源。......
  • 一文读完CST软件的发展历程
    CST软件经过近三十年的发展形成了自己独特技术路线,相较于市面上其他的产品最大的特征就是完备的技术(Completetechnology)。本期我们借由对CST历史的介绍,逐渐展开对CST软件的各个算法的特点介绍。CST工作室套装本身就超过20个求解器,能解决从直流、低频到高频,光学,多物理场,PCB,线缆......
  • 【Linux应急响应—下 】一文解明Linux应急响应(hw蓝队兄弟看这里):主机资源异常如何排查?C
    Linux应急响应重要声明linux应急响应各项资源异常CPU排查内存网络带宽网络连接关闭进程Linux系统日志排查登入验证日志登入失败次数登入成功统计攻击者IP个数攻击次数排列,由高到低中间件日志nginxapachetomcat分析维度:上篇文章在此处:【Linux应急响应—上】一文......
  • 一文详解Denoising Diffusion Implicit Models(DDIM)
    目录0前言1DDIM2总结0前言  上一篇博文我们介绍了目前流行的扩散模型基石DDPM,并且给出了代码讲解,有不了解的小伙伴可以跳转到前面先学习一下。今天我们再来介绍下DDPM的改进版本。DDPM虽然对生成任务带来了新得启发,但是他有一个致命的缺点,就是推理速度比较慢,......
  • Java并发(十六)一文搞懂Java 线程池原理
    简介什么是线程池线程池是一种多线程处理形式,处理过程中将任务添加到队列,然后在创建线程后自动启动这些任务。为什么要用线程池如果并发请求数量很多,但每个线程执行的时间很短,就会出现频繁的创建和销毁线程。如此一来,会大大降低系统的效率,可能频繁创建和销毁线程的时间......
  • 【Python正则-驯化】一文学会通过Python中的正则表达式提取文本数据中的电话号码:re
    【Python正则-驯化】一文学会通过Python中的正则表达式提取文本数据中的电话号码:re 本次修炼方法请往下查看......
  • 【Python正则-驯化】一文学会通过Python中的正则表达式提取文本中的网址
    【Python正则-驯化】一文学会通过Python中的正则表达式提取文本中的网址 本次修炼方法请往下查看......
  • 一文带你了解CAP的全部特性,你学会了吗?
    目录前言消息发布携带消息头设置消息前缀原生支持的延迟消息并行发布消息事务消息事务消息发送事务消息消费事务补偿消息处理序列化过滤器消息重试多线程处理自动恢复/重连分布式存储锁消息版本隔离优化的雪花算法消息自动清理消费者特性Attribute订阅多Attribute订阅通配符订阅......