首页 > 编程语言 >【即插即用】GnConv递归门控卷积(附源码)

【即插即用】GnConv递归门控卷积(附源码)

时间:2024-03-29 19:01:01浏览次数:25  
标签:__ dim None nn self init 源码 即插即用 门控

论文地址:

https://arxiv.org/abs/2207.14284


源码地址:

https://github.com/raoyongming/HorNetGnConvGnConvHorNet

摘要简介:

最近,视觉Transformer在各种任务中取得了巨大成功,这主要得益于基于点积自注意力的新型空间建模机制。在本文中,我们发现视觉Transformer的关键要素,即输入自适应、长距离和高阶空间交互,也可以高效地在基于卷积的框架中实现。

我们提出了一种递归门控卷积(gnConv),它利用门控卷积和递归设计实现高阶空间交互。这种新操作高度灵活且可定制,能与各种卷积变体兼容,并将自注意力中的二阶交互扩展到任意阶,而不会引入大量额外计算。gnConv可以作为一个即插即用的模块,用于改进各种视觉Transformer和基于卷积的模型。

基于gnConv操作,我们构建了一个新的通用视觉骨干网络家族,名为HorNet。在ImageNet分类、COCO目标检测和ADE20K语义分割等多个实验中,HorNet展现出比Swin Transformer和ConvNeXt更优越的性能,而整体架构和训练配置相似。此外,HorNet还表现出对更多训练数据和更大模型尺寸的良好可扩展性。

除了在视觉编码器中的有效性,我们还展示了gnConv可以应用于特定任务的解码器,并在减少计算的同时持续提高密集预测性能。我们的结果表明,gnConv可以成为视觉建模的新基本模块,有效结合视觉Transformer和卷积神经网络(CNN)的优点。相关代码可以在此网址找到。

简单来说,我们发现了一种新的卷积方法,可以像视觉Transformer那样捕捉图像中的高阶空间交互。基于这种新方法,我们创建了一个新的图像处理模型HorNet,它在多个视觉任务中都表现得比现有方法更好。这个新方法不仅适用于图像的主要处理部分,还可以用在特定的任务中,帮助我们更准确地预测图像中的细节,同时减少了计算量。

结构图:
Pytorch版源码:
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_dwconv(dim, kernel, bias):
    return nn.Conv2d(dim, dim, kernel_size=kernel, padding=(kernel-1)//2, bias=bias, groups=dim)

class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class GnConv(nn.Module):
    def __init__(self, dim, order=5, gflayer=None, h=14, w=8, s=1.0):
        super().__init__()
        self.order = order
        self.dims = [dim // 2 ** i for i in range(order)]
        self.dims.reverse()
        self.proj_in = nn.Conv2d(dim, 2 * dim, 1)

        if gflayer is None:
            self.dwconv = get_dwconv(sum(self.dims), 7, True)
        else:
            self.dwconv = gflayer(sum(self.dims), h=h, w=w)

        self.proj_out = nn.Conv2d(dim, dim, 1)

        self.pws = nn.ModuleList(
            [nn.Conv2d(self.dims[i], self.dims[i + 1], 1) for i in range(order - 1)]
        )

        self.scale = s

        print('[gconv]', order, '阶与维度=', self.dims, '尺度=%.4f' % self.scale)

    def forward(self, x, mask=None, dummy=False):
        B, C, H, W = x.shape

        fused_x = self.proj_in(x)
        pwa, abc = torch.split(fused_x, (self.dims[0], sum(self.dims)), dim=1)

        dw_abc = self.dwconv(abc) * self.scale

        dw_list = torch.split(dw_abc, self.dims, dim=1)
        x = pwa * dw_list[0]

        for i in range(self.order - 1):
            x = self.pws[i](x) * dw_list[i + 1]

        x = self.proj_out(x)

        return x

class GnBlock(nn.Module):
    def __init__(self, dim, shortcut=False, layer_scale_init_value=1e-6):
        super().__init__()
        self.shortcut = shortcut
        self.norm1 = LayerNorm(dim, eps=1e-6, data_format='channels_first')
        self.gnconv = GnConv(dim, order=5)
        self.norm2 = LayerNorm(dim, eps=1e-6, data_format='channels_last')
        self.pwconv1 = nn.Linear(dim, 2 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(2 * dim, dim)
        self.gamma1 = nn.Parameter(layer_scale_init_value * torch.ones(dim),
                                   requires_grad=True) if layer_scale_init_value > 0 else None
        self.gamma2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                   requires_grad=True) if layer_scale_init_value > 0 else None


    def forward(self, x):
        B, C, H, W = x.shape
        if self.gamma1 is not None:
            gamma1 = self.gamma1.view(1, C, 1, 1)
        else:
            gamma1 = 1
        x = (x + gamma1 * self.gnconv(self.norm1(x))) if self.shortcut else gamma1 * self.gnconv(self.norm1(x))
        input = x
        x = x.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
        x = self.norm2(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
        if self.gamma2 is not None:
            gamma2 = self.gamma2.view(1, C, 1, 1)
            x = gamma2 * x
        else:
            gamma2 = 1
        x = (input + x) if self.shortcut else x
        return x


# 使用一个示例输入进行代码测试
if __name__ == "__main__":
    # 创建一个示例输入张量
    input_tensor = torch.randn(1, 32, 128, 128)  # 批量大小,通道数,高度,宽度
    model = GnBlock(32)  # 使用与输入通道数相同的维度初始化 GnBlock
    output = model(input_tensor)  # 前向传播

    # 打印输出的尺寸
    print("输出尺寸:", output.size())

标签:__,dim,None,nn,self,init,源码,即插即用,门控
From: https://blog.csdn.net/weixin_45694817/article/details/137153863

相关文章

  • Node.js毕业设计合同管理(Express+附源码)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:在当今信息化快速发展的时代,合同管理系统作为企业日常运营中不可或缺的一部分,扮演着至关重要的角色。合同管理涉及合同的起草、审核、签订、执行以及存档等......
  • C#手术麻醉系统源码 可对接HIS LIS PACS 医疗系统各类设备 医院手麻系统源码
    C#手术麻醉系统源码可对接HIS LIS  PACS医疗系统各类设备手术麻醉信息管理系统主要还是为了手术室开发提供全面帮助的系统,其主要是由监护设备数据采集子系统和麻醉临床系统两个子部分组成。包括从手术申请到手术分配,再到术前访视、术中记录及术后恢复的全过程中都可以......
  • 2024年1000个计算机毕业设计项目推荐(源码+论文【万字】)
    2024年最新计算机毕业设计题目推荐,项目汇总!本科、专科。项目设计、项目定制、辅导、万字文档哈喽,大家好,大四的同学马上要开始做毕业设计了,大家做好准备了吗?博主给大家详细整理了计算机毕业设计最新项目,对项目有任何疑问,都可以问博主哦~技术栈包括但不限于:Java、JavaWeb......
  • 电子招标采购系统源码之从供应商管理到采购招投标、采购合同、采购执行的全过程数字化
    随着市场竞争的加剧和企业规模的扩大,招采管理逐渐成为企业核心竞争力的重要组成部分。为了提高招采工作的效率和质量,我们提出了一种基于电子化平台的解决方案。该方案旨在通过电子化招投标,使得招标采购的质量更高、速度更快,同时节约招标成本,提升企业的资金节约率。 项目说明......
  • 【附源码】JAVA计算机毕业设计在线考研刷题系统(springboot+mysql+开题+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着信息技术的飞速发展,计算机在教育领域的应用日益广泛。特别是在线教育平台,以其便捷性、灵活性和资源共享性受到了广大师生的青睐。近年来,考研热潮......
  • 【附源码】JAVA计算机毕业设计在线考试系统的设计与实现(springboot+mysql+开题+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着信息技术的迅猛发展和互联网的普及,传统的教育模式正面临着深刻的变革。在线考试系统作为教育信息化进程中的重要一环,正逐渐取代传统的纸质考试方......
  • 【附源码】JAVA计算机毕业设计在线考试答题系统(springboot+mysql+开题+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着互联网技术的迅猛发展,教育信息化已成为现代教育发展的重要趋势。在线考试答题系统作为教育信息化的重要组成部分,能够打破传统考试的时间和空间限......
  • Node.js毕业设计航空订票系统(Express+附源码)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:随着航空业的迅速发展,越来越多的人选择飞机作为出行的主要交通方式。航空订票系统作为航空公司与旅客之间的桥梁,其重要性不言而喻。一个好的航空订票系统能......
  • 【附源码】JAVA计算机毕业设计在线考试(springboot+mysql+开题+论文)
    本系统(程序+源码)带文档lw万字以上 文末可获取一份本项目的java源码和数据库参考。系统程序文件列表开题报告内容研究背景随着互联网技术的飞速发展,传统的教育模式正在经历深刻的变革。在线教育以其灵活、便捷的特性受到了广大师生的青睐。在线考试作为在线教育的重要组成......
  • Node.js毕业设计航空订票系统(Express+附源码)
    本系统(程序+源码)带文档lw万字以上  文末可获取本课题的源码和程序系统程序文件列表系统的选题背景和意义选题背景:随着互联网技术的不断发展,人们的生活越来越离不开网络。航空订票系统作为在线旅游行业的重要组成部分,为用户提供了便捷的航班查询、预订、支付等服务。传统......