首页 > 其他分享 >(即插即用模块-特征处理部分) 十二、(2023) SDM 语义差异引导模块

(即插即用模块-特征处理部分) 十二、(2023) SDM 语义差异引导模块

时间:2024-12-22 11:29:09浏览次数:6  
标签:kernel self channels 模块 2023 SDM diff guidance size

在这里插入图片描述

文章目录

paper:PnPNet: Pull-and-Push Networks for Volumetric Segmentation with Boundary Confusion

Code:https://github.com/AlexYouXin/PnPNet


1、Semantic Difference Guidance Module

为了解决以下几个问题:边界特征提取困难: 神经网络擅长处理大规模特征,而边界区域仅包含一个像素宽度,属于微小结构,难以准确提取其特征。边界形状约束缺失: U 形网络等传统网络缺乏对边界形状的约束,导致在处理边界模糊区域时容易产生错误预测。这篇论文提出一种 语义差异引导模块(Semantic Difference Module)用于增强边界特征,缩小边界不确定性。

SDM 的原理基于扩散理论,将边界特征视为需要平滑的函数,通过扩散过程使其更接近真实边界。其核心思想是将边界特征与语义信息相结合,利用扩散过程进行细化,从而更精确地定位类别之间的边界。

对于特征X,SDM 具体步骤如下:

  1. 计算语义指导图:利用深度特征 G 的梯度 ∇G 作为语义指导图,其值越大表示边界特征越显著。
  2. 构建 EID 核:使用 EID 核进行特征差分,该核包含显式和隐式差分信息,能够更好地提取边界特征。
  3. 计算特征差分:利用 EID 核对特征 F 进行差分,得到特征差分 ∇F。
  4. 扩散过程:利用扩散方程对特征进行迭代更新,其中扩散系数 D 由语义指导图控制,靠近边界区域的扩散速度较慢,远离边界区域的扩散速度较快。
  5. 特征融合:将原始特征与扩散后的增强特征进行融合,得到最终的特征。

Semantic Difference Guidance Module 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3dbn(nn.Sequential):
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )

        bn = nn.BatchNorm3d(out_channels)

        super(Conv3dbn, self).__init__(conv, bn)


class SDC(nn.Module):
    def __init__(self, in_channels, guidance_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.7):
        super(SDC, self).__init__()
        self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.conv1 = Conv3dbn(guidance_channels, in_channels, kernel_size=3, padding=1)
        # self.conv1 = Conv3dGN(guidance_channels, in_channels, kernel_size=3, padding=1)
        self.theta = theta
        self.guidance_channels = guidance_channels
        self.in_channels = in_channels
        self.kernel_size = kernel_size

        # initialize
        x_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size)
        x_initial = self.kernel_initialize(x_initial)

        self.x_kernel_diff = nn.Parameter(x_initial)
        self.x_kernel_diff[:, :, 0, 0, 0].detach()
        self.x_kernel_diff[:, :, 0, 0, 2].detach()
        self.x_kernel_diff[:, :, 0, 2, 0].detach()
        self.x_kernel_diff[:, :, 2, 0, 0].detach()
        self.x_kernel_diff[:, :, 0, 2, 2].detach()
        self.x_kernel_diff[:, :, 2, 0, 2].detach()
        self.x_kernel_diff[:, :, 2, 2, 0].detach()
        self.x_kernel_diff[:, :, 2, 2, 2].detach()

        guidance_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size)
        guidance_initial = self.kernel_initialize(guidance_initial)

        self.guidance_kernel_diff = nn.Parameter(guidance_initial)
        self.guidance_kernel_diff[:, :, 0, 0, 0].detach()
        self.guidance_kernel_diff[:, :, 0, 0, 2].detach()
        self.guidance_kernel_diff[:, :, 0, 2, 0].detach()
        self.guidance_kernel_diff[:, :, 2, 0, 0].detach()
        self.guidance_kernel_diff[:, :, 0, 2, 2].detach()
        self.guidance_kernel_diff[:, :, 2, 0, 2].detach()
        self.guidance_kernel_diff[:, :, 2, 2, 0].detach()
        self.guidance_kernel_diff[:, :, 2, 2, 2].detach()

    def kernel_initialize(self, kernel):
        kernel[:, :, 0, 0, 0] = -1

        kernel[:, :, 0, 0, 2] = 1
        kernel[:, :, 0, 2, 0] = 1
        kernel[:, :, 2, 0, 0] = 1

        kernel[:, :, 0, 2, 2] = -1
        kernel[:, :, 2, 0, 2] = -1
        kernel[:, :, 2, 2, 0] = -1

        kernel[:, :, 2, 2, 2] = 1

        return kernel

    def forward(self, x, guidance):
        guidance_channels = self.guidance_channels
        in_channels = self.in_channels
        kernel_size = self.kernel_size

        guidance = self.conv1(guidance)

        x_diff = F.conv3d(input=x, weight=self.x_kernel_diff, bias=self.conv.bias, stride=self.conv.stride, padding=1,
                          groups=in_channels)

        guidance_diff = F.conv3d(input=guidance, weight=self.guidance_kernel_diff, bias=self.conv.bias,
                                 stride=self.conv.stride, padding=1, groups=in_channels)
        out = self.conv(x_diff * guidance_diff * guidance_diff)
        return out


class SDM(nn.Module):
    def __init__(self, in_channel=3, guidance_channels=2):
        super(SDM, self).__init__()
        self.sdc1 = SDC(in_channel, guidance_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm3d(in_channel)

    def forward(self, feature, guidance):
        boundary_enhanced = self.sdc1(feature, guidance)
        boundary = self.relu(self.bn(boundary_enhanced))
        boundary_enhanced = boundary + feature

        return boundary_enhanced


if __name__ == '__main__':
    """
    输入维度需要是 5 维
    """
    x = torch.randn(1, 3, 32, 32, 32).cuda()
    y = torch.randn(1, 2, 32, 32, 32).cuda()
    model = SDM(3, 2).cuda()
    out = model(x, y)
    print(out.shape)

标签:kernel,self,channels,模块,2023,SDM,diff,guidance,size
From: https://blog.csdn.net/wei582636312/article/details/144516196

相关文章

  • 2000-2023年 上市公司-企业数字化转型(报告词频、文本统计)原始数据、参考文献、代码、
    一、数据介绍数据名称:企业数字化转型-年度报告词频、文本统计数据范围:1999-2023年5630家上市公司样本数量:63051条,345个变量数据来源:上市公司年度报告数据说明:内含数字化转型314个词频、各维度水平、文本统计面板二、整理说明爬取1999-2023年上市公司年报将原始报告文本......
  • Python模块之threading
    模块作用简介:Python模块之threadingthread模块基本被废弃了,现在多用threading模块来创建和管理子线程有两种方式来创建线程:第一种是:用class继承Thread类,并重写它的run()方法;第二种是:在实例化threading.Thread对象的时候,将线程要执行的任务函数作为参数传入线程。......
  • Python模块之thread
    模块作用简介:Python模块之thread,此模块基本废弃,建议使用threadingPython模块之threading:https://www.cnblogs.com/wutou/p/18621520官方英文帮助:https://docs.python.org/3/library/官方简体中文帮助:https://docs.python.org/zh-cn/3/library/必要操作:>>>......
  • 【工具变量】上市公司企业供应链成本分担数据(2010-2023年)
    一、测算方式:参考C刊《经济管理》刘红霞老师(2024)的做法,从绿色投资企业与供应链其他成员企业关系层面出发,使用两个指标测度供应链成本分担:一是单向供应链成本分担总额(CS_get),是绿色投资企业从供应链其他成员企业获取的成本分担,强调了链上企业对绿色投资企业单向的成本分担水平,使......
  • 32.Python基础篇-socketserver模块
    socketserver模块是什么?是Python中一个用于简化基于socket的网络服务实现的模块。它提供了一些高层次的类,帮助开发者更容易地实现网络服务。可以实现并发请求处理使用socketserver实现的server端,代码演示:importsocketserver#导入socketserver模块,用于简化基于sock......
  • 31.Python基础篇-hmac模块
    hmac与hashlib模块的区别hmac模块基于hashlib提供的哈希算法,在计算哈希时加入了一个“密钥”。主要用于生成“消息认证码”(MAC),通过一个密钥和数据共同生成哈希值,以此来验证数据的完整性和身份。需要密钥,它的目的是防止消息篡改并验证消息是否来源于可信的发送方。hashlib......
  • 2024-2025-1 20231420《计算机基础与程序设计》第十二周助教总结
    课程答疑C语言:指针、结构体、文件C语言中,指针、结构体十分重要,也较为困难;大家刚接触文件的操作,可能比较陌生,未能掌握文件的相关操作。大家要深入学习指针、结构体、文件的相关语法,这对之后的课程也很有帮助,可以借助网络资源学习,也要多敲代码,多多练习。课程作业中出现的问题格......
  • 30.Python基础篇-socket模块
    介绍socket模块是用于实现网络通信的模块。它提供了底层网络操作的接口,使得用户可以通过网络实现客户端和服务器之间的数据传输。通过socket模块,程序可以通过网络进行数据传输、连接和通信。使用socket模块创建一个TCP服务server端代码#server端代码importsocketsk......
  • 27.Python基础篇-configparse模块
    介绍用于处理配置文件的读取和写入。配置文件通常包含以键值对的形式存储的配置信息,常见的格式是.ini文件。该模块提供了对这些配置文件的解析功能,支持读取、写入、更新和删除配置。 配置文件的格式配置文件一般由多个部分(Section)组成,每个部分下面有多个键值对(Option)。配置......
  • 28.Python基础篇-logging模块
    介绍:logging模块是Python内置的强大日志记录工具,支持多种输出方式、格式化选项及多进程支持。日志的级别logging模块有五个内置的日志级别,从低到高:DEBUG:详细信息,用于诊断问题。INFO:常规信息,表示程序正常运行的状态。WARNING:警告信息,表示潜在问题或即将发生的错误。ERROR......