首页 > 其他分享 >DRANet-RAB模块

DRANet-RAB模块

时间:2024-11-25 20:12:47浏览次数:4  
标签:__ kernel RAB self 模块 x2 DRANet out size

class Basic(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, padding=0, bias=False):
        super(Basic, self).__init__()
        self.out_channels = out_planes
        groups = 1
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, padding=padding, groups=groups, bias=bias)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

class ChannelPool(nn.Module):
    def __init__(self):
        super(ChannelPool, self).__init__()

    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)

class SAB(nn.Module):
    def __init__(self):
        super(SAB, self).__init__()
        kernel_size = 5
        self.compress = ChannelPool()
        self.spatial = Basic(2, 1, kernel_size, padding=(kernel_size - 1) // 2, bias=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)
        return x * scale
class RAB(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, bias=True):
        super(RAB, self).__init__()
        kernel_size = 3
        stride = 1
        padding = 1
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias))
        self.res = nn.Sequential(*layers)
        self.sab = SAB()

    def forward(self, x,path):
        # path1 和 path2 的计算结果完全相同 path2 没有保留 不需要保存的中间值
        if path==1:
            x1 = x + self.res(x)
            x2 = x1 + self.res(x1)
            x3 = x2 + self.res(x2)

            x3_1 = x1 + x3
            x4 = x3_1 + self.res(x3_1)
            x4_1 = x + x4
            # sab:在通道维度上求平均值和最大值 将得到的两个通道叠加在一起 然后再通过卷积变成1通道,再通过sigmoid得到权重,然后让每个通道都和这个1通道的权重相乘
            x5 = self.sab(x4_1)
            x5_1 = x + x5
            return x5_1
        else:
            x1 = x + self.res(x)

            x2 = x1 + self.res(x1)

            x2 = x2 + self.res(x2) + x1

            x2 = x2 + self.res(x2) + x

            x2 = self.sab(x2) + x
            return x2
if __name__ == '__main__':
    block1 = RAB(in_channels=3,out_channels=3,bias=True)
    input = torch.rand(1, 3, 9, 9)
    output1 = block1(input,1)
    output2 = block1(input,2)
    print("")

标签:__,kernel,RAB,self,模块,x2,DRANet,out,size
From: https://www.cnblogs.com/plumIce/p/18568524

相关文章

  • RT-DETR融合[CVPRW2024]MAN中的MLKA模块及相关改进思路
    RT-DETR使用教程: RT-DETR使用教程RT-DETR改进汇总贴:RT-DETR更新汇总贴《Multi-scaleAttentionNetworkforSingleImageSuper-Resolution》一、模块介绍    论文链接:https://arxiv.org/pdf/2209.14145v2    代码链接:https://github.com/icand......
  • Jupyter Notebook无法导入外部模块—引出对环境变量的思考
    JupyterNotebook简介JupyterNotebook是一种交互式的计算环境,允许用户通过Notebook形式创建和共享代码、可视化和文档的组合。它是一个非常流行的数据科学工具,广泛用于数据分析、机器学习。今天主要使用了NumPy——科学计算库;Matplotlib——数据绘图库下文中,为方便起......
  • python中的包和模块(非常详细),零基础入门到精通,看这一篇就够了
    文章目录一、包与模块二、第三方包的安装2.1pipinstall2.2使用curl+管道2.3其他安装方法三、导入单元的构成3.1pip的使用3.2模块的缓存3.3源码包与二进制包四、setup.py的编写零基础入门AI大模型1.学习路线图2.视频教程3.技术文档和电子书4.LLM面试题和面经合......
  • 【FAQ】Harmo【FAQ】HarmonyOS SDK 闭源开放能力 — 公共模块
    1.问题描述:文档哪里能找到所有的权限查看该权限是用户级的还是系统级的。解决方案:您好,可以看一下下方链接是否可以解决问题:https://developer.huawei.com/consumer/cn/doc/harmonyos-guides-V5/permissions-for-all-V5https://developer.huawei.com/consumer/cn/doc/harmonyo......
  • RabbitMQ---如何保证MQ幂等性?
    保证MQ幂等性通常是指保证消费者消费消息的幂等性。1、使用数据库的唯一约束去控制。添加唯一索引保证添加数据的幂等性。例如,对于订单处理场景,将订单号设置为唯一约束。当重复插入具有相同订单号的订单记录时,数据库会抛出异常,从而保证幂等性2、使用token机制总结:发送......
  • 【web】Gin+Go-Micro +Vue+Nodejs+jQuery+ElmentUI 用户模块之vue登录开发以及接口联
    在现代Web应用中,实现用户登录模块是一个关键功能。本文将分为初级、中级、高级阶段,详细说明如何使用Vue、ElementUI进行登录开发,并与Gin、Go-Micro、Node.js进行接口联调。初级用法介绍在初级阶段,主要关注于使用Vue和ElementUI创建一个简单的登录界面,并通过Node.js后端进......
  • springboot 整合 rabbitMQ (延迟队列)
    前言:延迟队列是一个内部有序的数据结构,其主要功能体现在其延时特性上。这种队列存储的元素都设定了特定的处理时间,意味着它们需要在规定的时间点或者延迟之后才能被取出并进行相应的处理。简而言之,延时队列被设计用于存放那些需要在特定时间到达时才处理的元素。使用场景:1、......
  • 【深入理解RabbitMQ】七大工作模式
    文章目录七种工作模式介绍简单模式基本概念代码实现工作队列模式基本概念代码实现发布订阅模式基本概念代码实现路由模式基本概念代码实现通配符模式基本概念代码实现`RPC`(远程过程调用模式)基本概念代码实现`PublisherConfirms`(发布确认模式)`MQ`是如何保证消息的......
  • RabbitMQ5:Fanout交换机、Direct交换机、Topic交换机
    欢迎来到“雪碧聊技术”CSDN博客!在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将不断探索Java的深邃世界,分享最新的技术动态、实战经验以及项目......
  • 如何通过命令行创建一个Maven多模块项目
    本教程将引导您使用命令行创建一个简单的Maven多模块项目,以一个博客应用为例,该应用包含一个父项目和三个子模块:blogger-core、blogger-common和blogger-web。我们将使用最新的Java版本和依赖项。准备工作确保您的系统已安装以下软件:JDK21Maven文本编辑器或IDE(如Eclip......