首页 > 其他分享 >【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

时间:2024-11-03 09:18:38浏览次数:5  
标签:Multiscale 分割 Network nn dim self drop num attn

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!

【语义分割|代码解析】CMTFNet-4: CNN and Multiscale Transformer Fusion Network 用于遥感图像分割!


文章目录


欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

论文地址:https://ieeexplore.ieee.org/document/10247595

前言

在这里插入图片描述
该代码实现了一个多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块 Mutilscal_MHSA、一个块级模块 Block 以及一个融合模块 Fusion。此代码用于遥感图像语义分割模型 CMTFNet 中,主要通过多尺度卷积、MHSA 和融合机制增强图像特征提取。以下是逐行代码解析:

在这里插入图片描述

1. 多尺度多头自注意力(Multi-Head Self-Attention,MHSA)模块

class Mutilscal_MHSA(nn.Module):
    def __init__(self, dim, num_heads, atten_drop = 0., proj_drop = 0., dilation = [3, 5, 7], fc_ratio=4, pool_ratio=16):
        super(Mutilscal_MHSA, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.atten_drop = nn.Dropout(atten_drop)
        self.proj_drop = nn.Dropout(proj_drop)

        self.MSC = MutilScal(dim=dim, fc_ratio=fc_ratio, dilation=dilation, pool_ratio=pool_ratio)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels=dim, out_channels=dim//fc_ratio, kernel_size=1),
            nn.ReLU6(),
            nn.Conv2d(in_channels=dim//fc_ratio, out_channels=dim, kernel_size=1),
            nn.Sigmoid()
        )
        self.kv = Conv(dim, 2 * dim, 1)
  • __init__ 构造函数:
  • super(Mutilscal_MHSA, self).__init__(): 初始化父类 nn.Module
  • assert dim % num_heads == 0: 确保特征维度 dim 可被头数 num_heads 整除。
  • self.dim、self.num_heads: 初始化维度和多头数量。
  • head_dim = dim // num_heads: 每个头的维度大小。
  • self.scale = head_dim ** -0.5: 计算缩放因子,用于稳定点积结果。
  • self.atten_drop、self.proj_drop: 设置注意力和投影的 dropout 层。
  • self.MSC = MutilScal(...): 多尺度卷积模块,用于提取多尺度特征。
  • self.avgpool = nn.AdaptiveAvgPool2d(1): 全局平均池化,将特征图缩小至 (1,1)。
  • self.fc = nn.Sequential(...): 两层全连接网络,用于生成通道注意力权重。
  • self.kv = Conv(dim, 2 * dim, 1): 卷积层,将输入特征转换为键值对。

forward 前向传播函数:

    def forward(self, x):
        u = x.clone()
        B, C, H, W = x.shape
        kv = self.MSC(x)
        kv = self.kv(kv)

        B1, C1, H1, W1 = kv.shape

        q = rearrange(x, 'b (h d) (hh) (ww) -> (b) h (hh ww) d', h=self.num_heads,
                      d=C // self.num_heads, hh=H, ww=W)
        k, v = rearrange(kv, 'b (kv h d) (hh) (ww) -> kv (b) h (hh ww) d', h=self.num_heads,
                         d=C // self.num_heads, hh=H1, ww=W1, kv=2)

        dots = (q @ k.transpose(-2, -1)) * self.scale
        attn = dots.softmax(dim=-1)
        attn = self.atten_drop(attn)
        attn = attn @ v

        attn = rearrange(attn, '(b) h (hh ww) d -> b (h d) (hh) (ww)', h=self.num_heads,
                         d=C // self.num_heads, hh=H, ww=W)
        c_attn = self.avgpool(x)
        c_attn = self.fc(c_attn)
        c_attn = c_attn * u
        return attn + c_attn
  • u = x.clone(): 复制输入 x,用于残差连接。
  • B, C, H, W = x.shape: 获取输入张量的维度信息。
  • kv = self.MSC(x): 将输入 x 传入多尺度卷积模块以提取键值特征。
  • kv = self.kv(kv): 使用 kv 卷积层进一步处理特征。
  • B1, C1, H1, W1 = kv.shape: 获取键值特征的维度信息。
  • q = rearrange(...): 重排 xquery 形式,适用于多头自注意力。
  • k, v = rearrange(...): 重排 kv 为键和值形式,适用于多头自注意力。
  • dots = (q @ k.transpose(-2, -1)) * self.scale: 计算缩放的查询键点积。
  • attn = dots.softmax(dim=-1): 计算点积的 softmax,生成注意力权重。
  • attn = self.atten_drop(attn): 应用注意力 dropout。
  • attn = attn @ v: 将注意力权重和值相乘,得到新的特征表示。
  • attn = rearrange(...): 重排 attn 为原始特征形状。
  • c_attn = self.avgpool(x): 对 x 进行全局平均池化。
  • c_attn = self.fc(c_attn): 通过全连接层生成通道注意力权重。
  • c_attn = c_attn * u: 将通道注意力权重与输入 u 相乘。
  • return attn + c_attn: 返回多头自注意力特征和通道注意力特征的和。

2. 块级模块 Block

class Block(nn.Module):
    def __init__(self, dim=512, num_heads=16,  mlp_ratio=4, pool_ratio=16, drop=0., dilation=[3, 5, 7],
                 drop_path=0., act_layer=nn.ReLU6, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Mutilscal_MHSA(dim, num_heads=num_heads, atten_drop=drop, proj_drop=drop, dilation=dilation,
                                   pool_ratio=pool_ratio, fc_ratio=mlp_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim // mlp_ratio)

        self.mlp = E_FFN(in_features=dim, hidden_features=mlp_hidden_dim, out_features=dim, act_layer=act_layer,
                         drop=drop)
  • super().__init__(): 初始化父类。
  • self.norm1 = norm_layer(dim): 归一化层。
  • self.attn = Mutilscal_MHSA(...): 多尺度多头自注意力模块。
  • self.drop_path = DropPath(...): 随机丢弃路径,用于防止过拟合。
  • mlp_hidden_dim = int(dim // mlp_ratio): 计算多层感知机的隐藏层维度。
  • self.mlp = E_FFN(...): 全连接前馈网络。

forward 前向传播函数:

    def forward(self, x):

        x = x + self.drop_path(self.norm1(self.attn(x)))
        x = x + self.drop_path(self.mlp(x))

        return x
  • x = x + self.drop_path(self.norm1(self.attn(x))): 对注意力模块进行归一化、添加残差连接。
  • x = x + self.drop_path(self.mlp(x)): 对全连接层输出添加残差连接。
  • return x: 返回块的输出。

3. 融合模块 Fusion

class Fusion(nn.Module):
    def __init__(self, dim, eps=1e-8):
        super(Fusion, self).__init__()

        self.weights = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.eps = eps
        self.post_conv = SeparableConvBNReLU(dim, dim, 5)
  • super(Fusion, self).__init__(): 初始化父类。
  • self.weights = nn.Parameter(...): 创建两个可训练的权重参数。
  • self.eps = eps: 用于避免除零的 epsilon。
  • self.post_conv = SeparableConvBNReLU(...): 可分离卷积层,融合后的卷积处理。

forward 前向传播函数:

    def forward(self, x, res):
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        weights = nn.ReLU6()(self.weights)
        fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps)
        x = fuse_weights[0] * res + fuse_weights[1] * x
        x = self.post_conv(x)
        return x
  • x = F.interpolate(...): 上采样 x
  • weights = nn.ReLU6()(self.weights): 对权重参数应用 ReLU6 激活。
  • fuse_weights = weights / (torch.sum(weights, dim=0) + self.eps): 归一化权重。
  • x = fuse_weights[0] * res + fuse_weights[1] * x: 加权融合 xres
  • x = self.post_conv(x): 通过可分离卷积进一步处理。
  • return x: 返回融合后的特征。

这些模块配合在一起实现了多尺度、多头自注意力机制以及融合处理,有效提升遥感图像语义分割性能。

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!

大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz

标签:Multiscale,分割,Network,nn,dim,self,drop,num,attn
From: https://blog.csdn.net/gaoxiaoxiao1209/article/details/143393009

相关文章

  • 【bat脚本】批处理如何把文本文件分割成N个文本文件?
    原创bat学习Bat批处理学习站需求比如我现在有一文本文件,我要上传,但是太大(文件8-12M之间),不允许,我想把它分割成N个小文件(按行分割,行数可以自行设定),文件内容全部是数字和部分符号,比如内容:123456+7234567+8345678+9456789+10567891......
  • systemctl restart NetworkManager 重启后,文件/etc/resolv.conf修改失败
    如果你在重启NetworkManager之后发现无法修改/etc/resolv.conf文件,这是因为NetworkManager会自动管理这个文件为了解决这个问题,你可以采取以下两种方法之一:方法一:禁用NetworkManager服务使用以下命令停止NetworkManager服务:sudosystemctlstopNetworkMana......
  • 动态规划-回文串问题——132.分割回文串II
    1.题目解析题目来源:132.分割回文串II——力扣测试用例2.算法原理首先回文串问题一定首先需要保存每个回文子串出现的位置,即二维dp表来存储所有子字符串中符合回文子串的位置,如图1.状态表示创建一个一维dp表来存储第i个位置之前的字符串数组全部划分为回文子......
  • shell中的IFS变量与词分割
    引入在bash、zsh、csh等等各种shell实现中,都有一个特殊的内置变量IFS(InternalFieldSeparator),意为内部字段分隔符。IFS变量值是一个字符序列,shell会将IFS字符序列中的各个字符视为词分割(wordsplitting)过程中分隔不同token的边界。正文1.什么是词分割以及什么情......
  • Virtual Private Network (VPN) Lab
    Task1:VMSetup使用上一个VPN的Labsetup包所构建的实验环境,所以这个任务就相当于是解决了。Task2:CreatingaVPNTunnelusingTUN/TAPStep1:自己构造tun_server.py,加权限并且在server上运行Step2:在HostU上构建tun_client.py,并运行tun_client.py文件:Step3......
  • PCL 法线微分(DoN)分割(C++详细过程版)
    目录一、概述二、代码实现三、结果展示本文由CSDN点云侠原创,原文链接,首发于:2024年11月1日。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的抄袭狗。一、概述  法线微分(DoN)分割在PCL里有现成的调用函数,具体算法原理和实现代码见:PCL基于法线微分(D......
  • COMP3331/9331 Computer Networks and Applications
    COMP3331/9331ComputerNetworksandApplicationsAssignmentforTerm3,2024Version1.1Due:11:59am(noon)Friday,8November2024(Week9)TableofContentsGOALANDLEARNINGOBJECTIVES....................................................................
  • 华为OD机试 E卷|字符串分割转换
    华为OD机试E卷|字符串分割转换0、关于本专栏&刷题交流群本文收录于专栏【2024华为OD机试真题】,专栏共有上千道OD机试真题,包含详细解答思路、与四种代码实现(Python、Java、C++、JavaScript)。点击文末链接加入【华为OD机试交流群】,和群友一起刷题备考。刷的越多,考试中遇到原题的......
  • yolov8+多算法多目标追踪+实例分割+目标检测+姿态估计(代码+教程)
    #多目标追踪+实例分割+目标检测YOLO(YouOnlyLookOnce)是一个流行的目标检测算法,它能够在图像中准确地定位和识别多个物体。在这里插入图片描述本项目是基于YOLO算法的目标跟踪系统,它将YOLO的目标检测功能与目标跟踪技术相结合,实现了实时的多目标跟踪。在目标......
  • CSCI 201 Networked Crossword Puzzle
    Assignment#2CSCI201Fall2024Page1of11Assignment#2CSCI201Fall20246%ofcoursegradeTitleNetworkedCrosswordPuzzleTopicsCoveredNetworkingMulti-ThreadingConcurrencyIssuesIntroductionThisassignmentwillrequireyoutocreatetwodiffe......