文章目录
paper:PnPNet: Pull-and-Push Networks for Volumetric Segmentation with Boundary Confusion
1、Semantic Difference Guidance Module
该模块旨在解决以下两个问题:(1) 边界特征提取困难:神经网络擅长处理大尺度特征,而边界区域通常仅有一个像素宽,属于微小结构,难以准确提取其特征;(2) 缺乏边界形状约束:U 形网络等传统网络缺乏对边界形状的约束,在处理边界模糊区域时容易产生错误预测。为此,这篇论文提出了语义差异引导模块(Semantic Difference Guidance Module,SDM),用于增强边界特征、降低边界不确定性。
SDM 的原理基于扩散理论,将边界特征视为需要平滑的函数,通过扩散过程使其更接近真实边界。其核心思想是将边界特征与语义信息相结合,利用扩散过程进行细化,从而更精确地定位类别之间的边界。
对于输入特征 F,SDM 的具体步骤如下:
- 计算语义指导图:利用深度特征 G 的梯度 ∇G 作为语义指导图,其值越大表示边界特征越显著。
- 构建 EID 核:使用 EID 核进行特征差分,该核包含显式和隐式差分信息,能够更好地提取边界特征。
- 计算特征差分:利用 EID 核对特征 F 进行差分,得到特征差分 ∇F。
- 扩散过程:利用扩散方程对特征进行迭代更新,其中扩散系数 D 由语义指导图控制,靠近边界区域的扩散速度较慢,远离边界区域的扩散速度较快。
- 特征融合:将原始特征与扩散后的增强特征进行融合,得到最终的特征。
Semantic Difference Guidance Module 结构图:
2、代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv3dbn(nn.Sequential):
    """3-D convolution followed by an optional batch normalization.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size, forwarded to ``nn.Conv3d``.
        padding: zero-padding forwarded to ``nn.Conv3d``.
        stride: stride forwarded to ``nn.Conv3d``.
        use_batchnorm: when True (default) the conv is built without a bias
            and followed by ``nn.BatchNorm3d``; when False the conv keeps its
            bias and the second slot is an ``nn.Identity`` pass-through.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            # BatchNorm's learnable shift makes a conv bias redundant.
            bias=not use_batchnorm,
        )
        # Honour the flag: only normalize when asked to. Identity keeps the
        # two-slot layout (self[0] is the conv, self[1] the norm) either way.
        bn = nn.BatchNorm3d(out_channels) if use_batchnorm else nn.Identity()
        super().__init__(conv, bn)
class SDC(nn.Module):
    """Semantic difference convolution used inside SDM.

    Applies a depthwise "difference" kernel to the feature map ``x`` and to a
    projected guidance map, multiplies the feature difference by the squared
    guidance difference, then mixes the product with a regular 3-D conv.

    Args:
        in_channels: channels of ``x`` and of the output.
        guidance_channels: channels of the raw guidance tensor; projected to
            ``in_channels`` by ``conv1``.
        kernel_size: size of the difference kernels and of ``self.conv``.
            Should be odd so the difference convs preserve spatial size.
        stride, padding, dilation, groups, bias: forwarded to ``self.conv``;
            ``stride`` and ``bias`` are also reused by the difference convs.
        theta: stored as an attribute but not used by this implementation.
    """

    def __init__(self, in_channels, guidance_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, theta=0.7):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, dilation=dilation,
                              groups=groups, bias=bias)
        # Project guidance to in_channels so it matches the depthwise layout.
        self.conv1 = Conv3dbn(guidance_channels, in_channels, kernel_size=3, padding=1)
        self.theta = theta  # NOTE(review): unused in forward; kept for API compatibility.
        self.guidance_channels = guidance_channels
        self.in_channels = in_channels
        self.kernel_size = kernel_size

        # Difference kernels: N(0, 1) entries with the 8 cube corners overwritten
        # by a fixed +/-1 pattern. Every entry, corners included, is trainable:
        # calling ``.detach()`` on a slice of a Parameter returns a new tensor and
        # freezes nothing. If the corners must stay fixed, mask their gradients
        # (or rebuild the kernel from a constant buffer) explicitly.
        x_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size)
        self.x_kernel_diff = nn.Parameter(self.kernel_initialize(x_initial))
        guidance_initial = torch.randn(in_channels, 1, kernel_size, kernel_size, kernel_size)
        self.guidance_kernel_diff = nn.Parameter(self.kernel_initialize(guidance_initial))

    def kernel_initialize(self, kernel):
        """Write the +/-1 corner pattern into ``kernel`` in place and return it.

        ``kernel`` is 5-D (out, in, depth, height, width). Index 0 and -1 on
        each spatial axis select the 8 corners; a corner gets +1 when an odd
        number of its coordinates lie on the far (-1) face, otherwise -1.
        Using -1 rather than a literal index keeps this correct for any
        kernel size.
        """
        kernel[:, :, 0, 0, 0] = -1
        kernel[:, :, 0, 0, -1] = 1
        kernel[:, :, 0, -1, 0] = 1
        kernel[:, :, -1, 0, 0] = 1
        kernel[:, :, 0, -1, -1] = -1
        kernel[:, :, -1, 0, -1] = -1
        kernel[:, :, -1, -1, 0] = -1
        kernel[:, :, -1, -1, -1] = 1
        return kernel

    def forward(self, x, guidance):
        """Return ``conv(diff(x) * diff(conv1(guidance)) ** 2)``.

        Args:
            x: feature tensor with ``in_channels`` channels (5-D, as required
                by ``F.conv3d``).
            guidance: tensor with ``guidance_channels`` channels; assumed to
                share ``x``'s spatial size so the product broadcasts — TODO confirm.
        """
        in_channels = self.in_channels
        # "Same" padding for odd kernels (equals 1 for the default kernel_size=3).
        diff_padding = self.kernel_size // 2
        guidance = self.conv1(guidance)
        x_diff = F.conv3d(input=x, weight=self.x_kernel_diff, bias=self.conv.bias,
                          stride=self.conv.stride, padding=diff_padding,
                          groups=in_channels)
        guidance_diff = F.conv3d(input=guidance, weight=self.guidance_kernel_diff,
                                 bias=self.conv.bias, stride=self.conv.stride,
                                 padding=diff_padding, groups=in_channels)
        return self.conv(x_diff * guidance_diff * guidance_diff)
class SDM(nn.Module):
    """Semantic Difference Module: a boundary-enhancing residual block.

    Runs ``SDC`` on ``(feature, guidance)``, batch-normalizes and rectifies
    the result, and adds it back onto ``feature``.

    Args:
        in_channel: channels of ``feature`` and of the output.
        guidance_channels: channels of the ``guidance`` tensor.
    """

    def __init__(self, in_channel=3, guidance_channels=2):
        super().__init__()
        self.sdc1 = SDC(in_channel, guidance_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn = nn.BatchNorm3d(in_channel)

    def forward(self, feature, guidance):
        """Return ``relu(bn(sdc1(feature, guidance))) + feature``."""
        enhanced = self.sdc1(feature, guidance)
        residual = self.relu(self.bn(enhanced))
        return residual + feature
if __name__ == '__main__':
    # Smoke test. Inputs must be 5-D (N, C, D, H, W) as required by Conv3d;
    # runs on GPU when available and falls back to CPU otherwise.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(1, 3, 32, 32, 32, device=device)
    y = torch.randn(1, 2, 32, 32, 32, device=device)
    model = SDM(3, 2).to(device)
    out = model(x, y)
    print(out.shape)
标签:kernel,self,channels,模块,2023,SDM,diff,guidance,size
From: https://blog.csdn.net/wei582636312/article/details/144516196