
MAGNet - MAFM: Multi-scale Awareness Fusion Module
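MAFM (Multi-scale Awareness Fusion Module) comes from MAGNet. The code below contains the module itself plus its two building blocks: COI, which sums an identity (BN) branch, a 3x3 depthwise branch and a 1x1 conv branch, and MHMC, a multi-head mixed convolution in which each channel group runs a depthwise conv of a different kernel size and the fused result modulates a linear value projection. MAFM concatenates two same-shaped feature maps x and d (presumably features from two modalities or branches), compresses them with a depthwise + pointwise stem, and refines the result with MHMC followed by COI.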


import math
import torch.nn as nn
import torch
from timm.models.layers import trunc_normal_

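# COI block: sums an identity (BN) branch, a 3x3 depthwise branch and a
# 1x1 conv branch of the same width; channels and spatial size are preserved.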
class COI(nn.Module):
    def __init__(self, inc, k=3, p=1):
        super().__init__()
        self.outc = inc
        self.dw = nn.Conv2d(inc, self.outc, kernel_size=k, padding=p, groups=inc)
        self.conv1_1 = nn.Conv2d(inc, self.outc, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(self.outc)
        self.bn2 = nn.BatchNorm2d(self.outc)
        self.bn3 = nn.BatchNorm2d(self.outc)
        self.act = nn.GELU()
        self.apply(self._init_weights)

    def forward(self, x):
        shortcut = self.bn1(x)  # identity branch (BN only)
        x_dw = self.bn2(self.dw(x))  # 3x3 depthwise branch
        x_conv1_1 = self.bn3(self.conv1_1(x))  # 1x1 conv branch
        return self.act(shortcut + x_dw + x_conv1_1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
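
# MHMC: multi-head mixed convolution. The (B, N, C) token input is split into
# ca_num_heads channel groups; each group passes through a depthwise conv with
# its own kernel size (3, 5, 7, 9), and the fused result modulates a linear
# value projection of the input.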
class MHMC(nn.Module):
    def __init__(self, dim, ca_num_heads=4, qkv_bias=True, proj_drop=0., ca_attention=1, expand_ratio=2):
        super().__init__()

        self.ca_attention = ca_attention
        self.dim = dim
        self.ca_num_heads = ca_num_heads

        assert dim % ca_num_heads == 0, f"dim {dim} should be divided by num_heads {ca_num_heads}."

        self.act = nn.GELU()
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.split_groups = self.dim // ca_num_heads

        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.s = nn.Linear(dim, dim, bias=qkv_bias)
        for i in range(self.ca_num_heads):
            # One depthwise conv per head with kernel sizes 3, 5, 7, 9
            # (paddings 1, 2, 3, 4), so each head covers a different scale.
            local_conv = nn.Conv2d(dim // self.ca_num_heads, dim // self.ca_num_heads,
                                   kernel_size=(3 + i * 2), padding=(1 + i), stride=1,
                                   groups=dim // self.ca_num_heads)
            setattr(self, f"local_conv_{i + 1}", local_conv)
        self.proj0 = nn.Conv2d(dim, dim * expand_ratio, kernel_size=1, padding=0, stride=1,
                               groups=self.split_groups)
        self.bn = nn.BatchNorm2d(dim * expand_ratio)
        self.proj1 = nn.Conv2d(dim * expand_ratio, dim, kernel_size=1, padding=0, stride=1)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        v = self.v(x)
        s = self.s(x).reshape(B, H, W, self.ca_num_heads, C // self.ca_num_heads)
        s = s.permute(3, 0, 4, 1, 2)  # (num_heads, B, C_head, H, W)
        # Each head applies its depthwise conv; the per-head outputs are
        # concatenated back along the channel dimension.
        for i in range(self.ca_num_heads):
            local_conv = getattr(self, f"local_conv_{i + 1}")
            s_i = local_conv(s[i])  # (B, C // num_heads, H, W)
            s_out = s_i if i == 0 else torch.cat([s_out, s_i], 1)

        # Grouped 1x1 expansion -> BN -> GELU -> 1x1 projection back to dim.
        s_out = self.proj1(self.act(self.bn(self.proj0(s_out))))
        self.modulator = s_out  # kept around for inspection/visualization
        s_out = s_out.reshape(B, C, N).permute(0, 2, 1)  # back to (B, N, C)
        x = s_out * v  # modulate the value projection

        x = self.proj(x)
        x = self.proj_drop(x)
        return x



# Multi-scale Awareness Fusion Module (MAFM): fuses two same-shaped feature
# maps with a depthwise + pointwise stem, then refines with MHMC and COI.
class MAFM(nn.Module):
    def __init__(self, inc):
        super().__init__()
        self.outc = inc
        self.attention = MHMC(dim=inc)
        self.coi = COI(inc)
        self.pw = nn.Sequential(
            nn.Conv2d(in_channels=inc, out_channels=inc, kernel_size=1, stride=1),
            nn.BatchNorm2d(inc),
            nn.GELU()
        )
        self.pre_att = nn.Sequential(
            nn.Conv2d(inc * 2, inc * 2, kernel_size=3, padding=1, groups=inc * 2),
            nn.BatchNorm2d(inc * 2),
            nn.GELU(),
            nn.Conv2d(inc * 2, inc, kernel_size=1),
            nn.BatchNorm2d(inc),
            nn.GELU()
        )

        self.apply(self._init_weights)

    def forward(self, x, d):
        # multi = x * d
        # B, C, H, W = x.shape
        # x_cat = torch.cat((x, d, multi), dim=1)

        B, C, H, W = x.shape
        x_cat = torch.cat((x, d), dim=1)  # concatenate the two inputs along channels
        x_pre = self.pre_att(x_cat)  # depthwise 3x3 + pointwise 1x1 fusion stem
        # Attention
        x_reshape = x_pre.flatten(2).permute(0, 2, 1)  # (B, C, H, W) -> (B, N, C)
        attention = self.attention(x_reshape, H, W)  # MHMC
        attention = attention.permute(0, 2, 1).reshape(B, C, H, W)  # (B, N, C) -> (B, C, H, W)

        # COI
        x_conv = self.coi(attention)  # identity + 3x3 depthwise + 1x1 branches
        x_conv = self.pw(x_conv)  # pointwise (1x1) conv + BN + GELU

        return x_conv

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(1, 4, 9, 9, device=device)
    d = torch.randn(1, 4, 9, 9, device=device)
    model = MAFM(inc=4).to(device)
    out = model(x, d)
    print(out.shape)  # torch.Size([1, 4, 9, 9])
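
For reference, the COI block can also be exercised on its own; it preserves both channel count and spatial size. A minimal sketch (the sizes here are illustrative, not from the original post):

coi = COI(inc=32)
y = coi(torch.randn(2, 32, 16, 16))
print(y.shape)  # torch.Size([2, 32, 16, 16])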

The MHMC module is worth a closer look; I had not seen this approach before.
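
In short, MHMC splits the (B, N, C) token input into ca_num_heads channel groups, runs a depthwise conv with kernel size 3/5/7/9 over each group, fuses the groups with a grouped 1x1 expansion and a 1x1 projection, and multiplies the result element-wise with a linear value projection. A minimal standalone shape check against the class above (sizes are illustrative):

mhmc = MHMC(dim=16, ca_num_heads=4)
tokens = torch.randn(2, 8 * 8, 16)  # (B, N, C) with N = H * W
out = mhmc(tokens, 8, 8)
print(out.shape)  # torch.Size([2, 64, 16])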

From: https://www.cnblogs.com/plumIce/p/18573242
