首页 > 编程语言 >【即插即用】RefConv-重聚焦卷积模块(附源码)

【即插即用】RefConv-重聚焦卷积模块(附源码)

时间:2024-03-18 12:30:45浏览次数:28  
标签:RefConv nn 卷积 模型 源码 参数 即插即用 聚焦

论文地址: http://arxiv.org/pdf/2310.10563.pdf
源码地址:GitHub - Aiolus-X/RefConv
概述:

作者提出了一种可重参数化的重新聚焦卷积(RefConv),作为常规卷积层的即插即用替代品,能够在不引入额外推理成本的情况下显著提高基于CNN的模型性能。RefConv利用预训练参数编码的表示作为先验,通过重新聚焦这些参数来学习新的表示,进一步增强了模型结构的先验,提升了预训练模型的表示能力。实验证明,RefConv在图像分类、目标检测和语义分割等任务中表现出色,并能够减少通道冗余、平滑损失景观,从而解释了其有效性。

作者首先将预训练好的卷积模型的卷积层替换为重参数化重聚焦卷积 (RefConv),如图1所示。

RefConv的核心理念:

通过重新参数化和重新聚焦来增强卷积神经网络的特征提取能力。简单来说,它让网络更聪明地学习如何处理复杂的数据,从而在各种计算机视觉任务中表现更出色。它通过在卷积核之间建立额外的联系来提升模型的先验知识。RefConv是一种方便的模块,可以直接插入到现有模型中,而无需改变模型的结构或增加额外的计算成本,就可以明显提高模型的性能。此外,论文还指出,RefConv有助于减少模型中不必要的参数和优化过程中的损失函数,这进一步证明了它的有效性。这些发现可能会引发更多关于训练模型时的动态理论方面的研究。

在传统的卷积神经网络中,卷积核的权重是固定的,只是在训练过程中微调。但RefConv却有所不同,它引入了额外的参数来调整这些权重,让每个卷积核可以学到更多不同的特征。这种重新参数化的方法让网络能够更灵活地对输入特征做出反应,而且不会增加太多计算成本。

Pytorch源码:
import torch
import torch.nn as nn
import torch.nn.functional as F

class RepConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=None, groups=1,
                 map_k=3):
        super(RepConv, self).__init__()
        assert map_k <= kernel_size
        # 记录原始卷积核形状
        self.origin_kernel_shape = (out_channels, in_channels // groups, kernel_size, kernel_size)
        self.register_buffer('weight', torch.zeros(*self.origin_kernel_shape))
        G = in_channels * out_channels // (groups ** 2)
        self.num_2d_kernels = out_channels * in_channels // groups
        self.kernel_size = kernel_size
        # 使用 2D 卷积生成映射
        self.convmap = nn.Conv2d(in_channels=self.num_2d_kernels,
                                 out_channels=self.num_2d_kernels, kernel_size=map_k, stride=1, padding=map_k // 2,
                                 groups=G, bias=False)
        self.bias = None
        self.stride = stride
        self.groups = groups
        if padding is None:
            padding = kernel_size // 2
        self.padding = padding

    def forward(self, inputs):
        # 生成权重矩阵
        origin_weight = self.weight.view(1, self.num_2d_kernels, self.kernel_size, self.kernel_size)
        # 使用卷积映射更新权重
        kernel = self.weight + self.convmap(origin_weight).view(*self.origin_kernel_shape)
        return F.conv2d(inputs, kernel, stride=self.stride, padding=self.padding, dilation=1, groups=self.groups, bias=self.bias)

class RepConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(RepConvBlock, self).__init__()
        # 定义 RepConv 模块
        self.conv = RepConv(in_channels, out_channels, kernel_size=3, stride=stride, padding=None, groups=1, map_k=3)
        # 批量归一化层
        self.bn = nn.BatchNorm2d(out_channels)
        # 激活函数
        self.act = Hswish()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

class Hswish(nn.Module):
    def __init__(self, inplace=True):
        super(Hswish, self).__init__()
        self.inplace = inplace

    def forward(self, x):
        # H-swish 激活函数
        return x * F.relu6(x + 3., inplace=self.inplace) / 6.

# 测试模块
if __name__ == "__main__":
    # 创建 RepConvBlock 实例并进行前向传播测试
    block = RepConvBlock(in_channels=3, out_channels=64, stride=1)
    x = torch.randn(1, 3, 224, 224)
    output = block(x)
    print("Output shape:", output.shape)

标签:RefConv,nn,卷积,模型,源码,参数,即插即用,聚焦
From: https://blog.csdn.net/weixin_45694817/article/details/136805323

相关文章

  • 【即插即用】ELA注意力机制(附源码)
    原文地址:[2403.01123]ELA:EfficientLocalAttentionforDeepConvolutionalNeuralNetworks(arxiv.org)与SE、CA注意力机制的区别:ELA通过在空间维度采用带状池化来提取水平和垂直方向的特征向量,维持细长的核形状以捕捉远距离的依赖关系,同时避免不相关区域对标签预测的......
  • 基于SpringBoot的“乐校园二手书交易管理系统”的设计与实现(源码+数据库+文档+PPT)
    基于SpringBoot的“乐校园二手书交易管理系统”的设计与实现(源码+数据库+文档+PPT)开发语言:Java数据库:MySQL技术:SpringBoot工具:IDEA/Ecilpse、Navicat、Maven系统展示系统首页界面图用户注册界面图二手图书界面图留言反馈界面图个人中心界面图管理员......
  • 基于SpringBoot的“书籍学习平台”的设计与实现(源码+数据库+文档+PPT)
    基于SpringBoot的“书籍学习平台”的设计与实现(源码+数据库+文档+PPT)开发语言:Java数据库:MySQL技术:SpringBoot工具:IDEA/Ecilpse、Navicat、Maven系统展示平台首页界面图用户注册界面图付费专区界面图个人中心界面图后台登录界面图管理员功能界面图......
  • 基于微信小程序的高校跑腿小程序,附源码
    博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝15w+、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、Python技术领域和毕业项目实战✌......
  • Swoole 源码分析之 WebSocket 模块
    首发原文链接:Swoole源码分析之WebSocket模块大家好,我是码农先森。Swoole源码分析之WebSocket模块引言WebSocket是一种在单个TCP连接上进行全双工通信的协议。它允许客户端和服务器之间进行实时数据传输。与传统的HTTP请求-响应模型不同,WebSocket可以保持双向通信......
  • 【前端素材】推荐优质电影票购票商城网站设计Ticket平台模板(附源码)
     一、需求分析1、功能分析在线电影票购票商城是指一个通过互联网提供电影票购买服务的平台。它通常包括以下功能:电影信息展示:商城会展示当前热映电影、即将上映电影和影片详情,包括电影名称、演员阵容、导演、剧情简介、上映时间等信息,帮助用户选择电影。影院选择和座位......
  • 【前端素材】推荐优质在线创意家居电商网站设计Umbra平台模板(附源码)
    一、需求分析1、功能分析在线家具装饰商城是指通过互联网平台提供家具和装饰产品购买服务的电子商务平台。以下是关于在线家具装饰商城的具体功能和特点的详细分析:产品展示和购买:在线家具装饰商城通过网站或应用程序展示各种家具和装饰产品的图片、描述、价格等信息,方便用......
  • LinkedList源码解析和设计思路
    一、继承体系LinkedList类位于java.util包中,它实现了List接口和Deque接口,LinkedList可以被当做链表、双端队列使用,并且继承自AbstractSequentialList类。在继承关系中,它的父类是AbstractSequentialList,而AbstractSequentialList又继承自AbstractList,AbstractList继承自Abs......
  • 基于springboot实现大学生租房平台项目设计与实现演示【附项目源码+论文说明】
    基于springboot实现大学生租房平台的设计与实现演示摘要互联网发展至今,无论是其理论还是技术都已经成熟,而且它广泛参与在社会中的方方面面。它让信息都可以通过网络传播,搭配信息管理工具可以很好地为人们提供服务。针对大学生租房信息管理混乱,出错率高,信息安全性差,劳动强......
  • 基于SpringBoot实现网上订餐系统项目演示【附项目源码+论文说明】
    基于SpringBoot的网上订餐系统演示摘要随着我国经济的飞速发展,人们的生活速度明显加快,在餐厅吃饭排队的情况到处可见,近年来由于新兴IT行业的空前发展,它与传统餐饮行业也进行了新旧的结合,很多餐饮商户开始通过网络建设订餐系统,通过专门的网上订餐系统,一方面节省了用户订餐......