1. Dual Cross-Attention
Although U-Net and its variants perform well on medical image segmentation tasks, they still have two limitations. First, the locality of the convolution operation: it cannot capture long-range dependencies between different features. Second, the semantic gap across skip connections: simply concatenating encoder and decoder features loses semantic information and makes it hard to fuse low-level features effectively. To address these issues, the paper proposes Dual Cross-Attention (DCA). The DCA module uses a cross-attention mechanism to extract channel and spatial dependencies from the multi-scale encoder features, narrowing the semantic gap between the encoder and the decoder.
The basic idea of DCA consists of two parts. Channel Cross-Attention (CCA): a cross-attention mechanism that captures channel dependencies across the multi-scale encoder features and extracts global channel information. Spatial Cross-Attention (SCA): a cross-attention mechanism that captures spatial dependencies across the multi-scale encoder features and extracts global spatial information. The DCA module applies CCA and SCA in series: CCA first extracts global channel information, and its output is then fed to SCA, which further extracts global spatial information. This serial design fuses low-level features more effectively and yields finer feature representations.
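To make the serial arrangement concrete, here is a minimal sketch of one DCA stage; `cca` and `sca` are placeholders standing in for the two cross-attention operators, and the residual additions mirror the reference implementation in Section 2:

```python
# Conceptual flow of one DCA stage over a list of multi-scale token sequences.
# `cca` and `sca` are assumed to each return one refined sequence per scale.
def dca_stage(tokens, cca, sca):
    tokens = [t + c for t, c in zip(tokens, cca(tokens))]  # channel cross-attention + residual
    tokens = [t + s for t, s in zip(tokens, sca(tokens))]  # spatial cross-attention + residual
    return tokens
```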
For an input X, DCA proceeds as follows:
- Multi-scale feature extraction: extract multi-scale features from several stages of the encoder network.
- Patch embedding: convert each multi-scale feature map into tokens with 2D average pooling, then project the tokens with a depthwise separable convolution (a minimal shape walkthrough is sketched right after this list).
- CCA: layer-normalize each token sequence and concatenate the sequences along the channel dimension to form the keys and values. Apply depthwise separable convolutions as the linear projections, then run cross-attention to extract global channel information.
- SCA: layer-normalize the CCA outputs and concatenate them along the channel dimension to form the queries and keys. Apply depthwise separable convolutions as the linear projections, use each token sequence as the values, and run cross-attention to extract global spatial information.
- Upsampling and connection: apply layer normalization and a GeLU activation to the DCA outputs, upsample them, and connect them to the decoder network.
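Before the full implementation, here is a minimal shape walkthrough of the patch-embedding and concatenation steps. The channel counts and patch size are just example values matching the demo in Section 2; this is a sketch, not the module itself:

```python
import torch
import torch.nn as nn
import einops

# Hypothetical encoder outputs at four scales (batch 2, channels 32/64/128/256)
feats = [torch.randn(2, c, s, s) for c, s in zip([32, 64, 128, 256], [224, 112, 56, 28])]

patch = 28
pool = nn.AdaptiveAvgPool2d((patch, patch))

# Patch embedding: every scale becomes a (B, patch*patch, C) token sequence
tokens = [einops.rearrange(pool(f), 'B C H W -> B (H W) C') for f in feats]
print([t.shape for t in tokens])  # (2, 784, 32), (2, 784, 64), (2, 784, 128), (2, 784, 256)

# For CCA, the sequences are concatenated along the channel dimension to form
# shared keys/values, while each scale keeps its own queries
kv = torch.cat(tokens, dim=2)
print(kv.shape)  # (2, 784, 480), since 32 + 64 + 128 + 256 = 480
```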
Dual Cross-Attention architecture diagram:
DCA Block with U-Net architecture diagram:
2. Code Implementation
```python
import torch
import torch.nn as nn
import einops
class depthwise_conv_block(nn.Module):
    """Depthwise convolution (optionally followed by a 1x1 pointwise convolution),
    with optional normalization and ReLU."""
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 groups=None,  # must be supplied by the caller, e.g. groups=in_features for depthwise
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 pointwise=False,
                 ):
        super().__init__()
        # self.pointwise starts as a bool flag and is replaced by the 1x1 conv when enabled
        self.pointwise = pointwise
        self.norm_type = norm_type
        self.act = activation
        self.depthwise = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features if pointwise else out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=use_bias)
        if pointwise:
            self.pointwise = nn.Conv2d(in_features,
                                       out_features,
                                       kernel_size=(1, 1),
                                       stride=(1, 1),
                                       padding=(0, 0),
                                       dilation=(1, 1),
                                       bias=use_bias)
        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.depthwise(x)
        if self.pointwise:
            x = self.pointwise(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x
class conv_block(nn.Module):
    """Standard Conv2d with optional normalization and ReLU."""
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=use_bias)
        self.norm_type = norm_type
        self.act = activation
        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x
class ScaleDotProduct(nn.Module):
    """Scaled dot-product attention: softmax(x1 @ x2^T * scale) @ x3."""
    def __init__(self) -> None:
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1, x2, x3, scale):
        x2 = x2.transpose(-2, -1)
        x12 = torch.einsum('bhcw, bhwk -> bhck', x1, x2) * scale
        att = self.softmax(x12)
        x123 = torch.einsum('bhcw, bhwk -> bhck', att, x3)
        return x123
class PoolEmbedding(nn.Module):
    """Patch embedding: adaptive 2D pooling to a patch x patch grid,
    then flattening to a (B, patch*patch, C) token sequence."""
    def __init__(self,
                 pooling,
                 patch,
                 ) -> None:
        super().__init__()
        self.projection = pooling(output_size=(patch, patch))

    def forward(self, x):
        x = self.projection(x)
        x = einops.rearrange(x, 'B C H W -> B (H W) C')
        return x
class depthwise_projection(nn.Module):
    """Depthwise 1x1 projection applied to token sequences of shape (B, H*W, C)."""
    def __init__(self,
                 in_features,
                 out_features,
                 groups,
                 kernel_size=(1, 1),
                 padding=(0, 0),
                 norm_type=None,
                 activation=False,
                 pointwise=False) -> None:
        super().__init__()
        self.proj = depthwise_conv_block(in_features=in_features,
                                         out_features=out_features,
                                         kernel_size=kernel_size,
                                         padding=padding,
                                         groups=groups,
                                         pointwise=pointwise,
                                         norm_type=norm_type,
                                         activation=activation)

    def forward(self, x):
        # Tokens are assumed to come from a square patch grid
        P = int(x.shape[1] ** 0.5)
        x = einops.rearrange(x, 'B (H W) C -> B C H W', H=P)
        x = self.proj(x)
        x = einops.rearrange(x, 'B C H W -> B (H W) C')
        return x
class UpsampleConv(nn.Module):
    """Bilinear upsampling followed by a regular or depthwise convolution."""
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 padding=(1, 1),
                 norm_type=None,
                 activation=False,
                 scale=(2, 2),
                 conv='conv') -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale,
                              mode='bilinear',
                              align_corners=True)
        if conv == 'conv':
            self.conv = conv_block(in_features=in_features,
                                   out_features=out_features,
                                   kernel_size=(1, 1),
                                   padding=(0, 0),
                                   norm_type=norm_type,
                                   activation=activation)
        elif conv == 'depthwise':
            self.conv = depthwise_conv_block(in_features=in_features,
                                             out_features=out_features,
                                             kernel_size=kernel_size,
                                             padding=padding,
                                             norm_type=norm_type,
                                             activation=activation)

    def forward(self, x):
        x = self.up(x)
        x = self.conv(x)
        return x
class ChannelAttention(nn.Module):
    """Channel cross-attention: the attention map is computed between channels.
    Queries come from a single scale; keys/values come from the concatenated scales."""
    def __init__(self, in_features, out_features, n_heads=1) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.q_map = depthwise_projection(in_features=out_features,
                                          out_features=out_features,
                                          groups=out_features)
        self.k_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.v_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.projection = depthwise_projection(in_features=out_features,
                                               out_features=out_features,
                                               groups=out_features)
        self.sdp = ScaleDotProduct()

    def forward(self, x):
        q, k, v = x[0], x[1], x[2]
        q = self.q_map(q)
        k = self.k_map(k)
        v = self.v_map(v)
        b, hw, c_q = q.shape
        c = k.shape[2]
        scale = c ** -0.5
        # After permute + transpose the layout is (B, heads, C/heads, HW), so attention runs over channels
        q = q.reshape(b, hw, self.n_heads, c_q // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        k = k.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        v = v.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        att = self.sdp(q, k, v, scale).permute(0, 3, 1, 2).flatten(2)
        att = self.projection(att)
        return att
class SpatialAttention(nn.Module):
    """Spatial cross-attention: the attention map is computed between token positions.
    Queries/keys come from the concatenated scales; values come from a single scale."""
    def __init__(self, in_features, out_features, n_heads=4) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.q_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.k_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.v_map = depthwise_projection(in_features=out_features,
                                          out_features=out_features,
                                          groups=out_features)
        self.projection = depthwise_projection(in_features=out_features,
                                               out_features=out_features,
                                               groups=out_features)
        self.sdp = ScaleDotProduct()

    def forward(self, x):
        q, k, v = x[0], x[1], x[2]
        q = self.q_map(q)
        k = self.k_map(k)
        v = self.v_map(v)
        b, hw, c = q.shape
        c_v = v.shape[2]
        scale = (c // self.n_heads) ** -0.5
        # Layout (B, heads, HW, C/heads): attention runs over the HW token positions
        q = q.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3)
        k = k.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3)
        v = v.reshape(b, hw, self.n_heads, c_v // self.n_heads).permute(0, 2, 1, 3)
        att = self.sdp(q, k, v, scale).transpose(1, 2).flatten(2)
        x = self.projection(att)
        return x
class CCSABlock(nn.Module):
    """One DCA stage: channel cross-attention (CCA) followed by spatial cross-attention (SCA),
    each applied to the list of multi-scale token sequences with a residual connection."""
    def __init__(self,
                 features,
                 channel_head,
                 spatial_head,
                 spatial_att=True,
                 channel_att=True) -> None:
        super().__init__()
        self.channel_att = channel_att
        self.spatial_att = spatial_att
        if self.channel_att:
            self.channel_norm = nn.ModuleList([nn.LayerNorm(in_features, eps=1e-6)
                                               for in_features in features])
            self.c_attention = nn.ModuleList([ChannelAttention(
                in_features=sum(features),
                out_features=feature,
                n_heads=head,
            ) for feature, head in zip(features, channel_head)])
        if self.spatial_att:
            self.spatial_norm = nn.ModuleList([nn.LayerNorm(in_features, eps=1e-6)
                                               for in_features in features])
            self.s_attention = nn.ModuleList([SpatialAttention(
                in_features=sum(features),
                out_features=feature,
                n_heads=head,
            ) for feature, head in zip(features, spatial_head)])

    def forward(self, x):
        if self.channel_att:
            x_ca = self.channel_attention(x)
            x = self.m_sum(x, x_ca)
        if self.spatial_att:
            x_sa = self.spatial_attention(x)
            x = self.m_sum(x, x_sa)
        return x

    def channel_attention(self, x):
        x_c = self.m_apply(x, self.channel_norm)
        x_cin = self.cat(*x_c)
        # Each scale supplies the queries; the channel-wise concatenation supplies keys and values
        x_in = [[q, x_cin, x_cin] for q in x_c]
        x_att = self.m_apply(x_in, self.c_attention)
        return x_att

    def spatial_attention(self, x):
        x_c = self.m_apply(x, self.spatial_norm)
        x_cin = self.cat(*x_c)
        # The channel-wise concatenation supplies queries and keys; each scale supplies the values
        x_in = [[x_cin, x_cin, v] for v in x_c]
        x_att = self.m_apply(x_in, self.s_attention)
        return x_att

    def m_apply(self, x, module):
        return [module[i](j) for i, j in enumerate(x)]

    def m_sum(self, x, y):
        return [xi + xj for xi, xj in zip(x, y)]

    def cat(self, *args):
        return torch.cat(args, dim=2)
class DCA(nn.Module):
    """Dual Cross-Attention over multi-scale encoder features.
    Every scale is pooled to a patch x patch token grid, refined by n CCSA blocks,
    upsampled back to its original resolution, and added to the raw skip feature."""
    def __init__(self,
                 features,
                 strides=[8, 4, 2, 1],
                 patch=28,
                 channel_att=True,
                 spatial_att=True,
                 n=1,
                 channel_head=[1, 1, 1, 1],
                 spatial_head=[4, 4, 4, 4],
                 ):
        super().__init__()
        self.n = n
        self.features = features
        self.spatial_head = spatial_head
        self.channel_head = channel_head
        self.channel_att = channel_att
        self.spatial_att = spatial_att
        self.patch = patch
        # Patch embedding: adaptive average pooling to a fixed patch x patch grid
        self.patch_avg = nn.ModuleList([PoolEmbedding(
            pooling=nn.AdaptiveAvgPool2d,
            patch=patch,
        ) for _ in features])
        # Depthwise 1x1 projection of the pooled tokens
        self.avg_map = nn.ModuleList([depthwise_projection(in_features=feature,
                                                           out_features=feature,
                                                           kernel_size=(1, 1),
                                                           padding=(0, 0),
                                                           groups=feature)
                                      for feature in features])
        self.attention = nn.ModuleList([
            CCSABlock(features=features,
                      channel_head=channel_head,
                      spatial_head=spatial_head,
                      channel_att=channel_att,
                      spatial_att=spatial_att)
            for _ in range(n)])
        # Upsample each refined scale back to its original spatial resolution
        self.upconvs = nn.ModuleList([UpsampleConv(in_features=feature,
                                                   out_features=feature,
                                                   kernel_size=(1, 1),
                                                   padding=(0, 0),
                                                   norm_type=None,
                                                   activation=False,
                                                   scale=stride,
                                                   conv='conv')
                                      for feature, stride in zip(features, strides)])
        self.bn_relu = nn.ModuleList([nn.Sequential(
            nn.BatchNorm2d(feature),
            nn.ReLU())
            for feature in features])

    def forward(self, raw):
        x = self.m_apply(raw, self.patch_avg)
        x = self.m_apply(x, self.avg_map)
        for block in self.attention:
            x = block(x)
        # Back to (B, C, patch, patch), upsample, and add to the raw skip features
        x = [self.reshape(i) for i in x]
        x = self.m_apply(x, self.upconvs)
        x_out = self.m_sum(x, raw)
        x_out = self.m_apply(x_out, self.bn_relu)
        return (*x_out,)

    def m_apply(self, x, module):
        return [module[i](j) for i, j in enumerate(x)]

    def m_sum(self, x, y):
        return [xi + xj for xi, xj in zip(x, y)]

    def reshape(self, x):
        return einops.rearrange(x, 'B (H W) C -> B C H W', H=self.patch)
if __name__ == '__main__':
    # Four encoder stages of a U-Net-style backbone (channels 32/64/128/256)
    x = torch.randn(4, 32, 224, 224)
    y = torch.randn(4, 64, 112, 112)
    z = torch.randn(4, 128, 56, 56)
    v = torch.randn(4, 256, 28, 28)
    model = DCA([32, 64, 128, 256])
    output1, output2, output3, output4 = model((x, y, z, v))
    print(output1.shape)  # torch.Size([4, 32, 224, 224])
    print(output2.shape)  # torch.Size([4, 64, 112, 112])
    print(output3.shape)  # torch.Size([4, 128, 56, 56])
    print(output4.shape)  # torch.Size([4, 256, 28, 28])
```
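Because the returned tuple keeps the input shapes, the DCA outputs can directly replace the ordinary skip connections of a U-Net-style decoder. Below is a minimal wiring sketch using the DCA class defined above; `ToyDecoderBlock` is a hypothetical placeholder, not the decoder used in the paper:

```python
import torch
import torch.nn as nn

class ToyDecoderBlock(nn.Module):
    # Placeholder decoder block (assumption, not from the paper): upsample the deeper
    # feature, concatenate the DCA-refined skip, and fuse with a 3x3 convolution.
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.fuse = nn.Conv2d(in_ch + skip_ch, out_ch, kernel_size=3, padding=1)

    def forward(self, x, skip):
        x = self.up(x)
        return torch.relu(self.fuse(torch.cat([x, skip], dim=1)))


# The DCA outputs (same shapes as the inputs) act as refined skip connections
feats = (torch.randn(1, 32, 224, 224), torch.randn(1, 64, 112, 112),
         torch.randn(1, 128, 56, 56), torch.randn(1, 256, 28, 28))
s1, s2, s3, s4 = DCA([32, 64, 128, 256])(feats)

dec3 = ToyDecoderBlock(256, 128, 128)
dec2 = ToyDecoderBlock(128, 64, 64)
dec1 = ToyDecoderBlock(64, 32, 32)
d = dec3(s4, s3)  # (1, 128, 56, 56)
d = dec2(d, s2)   # (1, 64, 112, 112)
d = dec1(d, s1)   # (1, 32, 224, 224)
print(d.shape)
```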
From: https://blog.csdn.net/wei582636312/article/details/144892288