
(Plug-and-Play Module, Attention Series) 36. (2023) DCA: Dual Cross-Attention




Paper: Dual Cross-Attention for Medical Image Segmentation

Code: https://github.com/gorkemcanates/Dual-Cross-Attention


1. Dual Cross-Attention

Although U-Net and its variants achieve strong performance on medical image segmentation tasks, they still have limitations. Specifically, the locality of the convolution operation prevents them from capturing long-range dependencies between features, and the semantic gap across skip connections means that simply concatenating encoder and decoder features loses semantic information and makes low-level features hard to fuse effectively. To address these problems, this paper proposes Dual Cross-Attention (DCA). The DCA module uses a cross-attention mechanism to extract channel and spatial dependencies from multi-scale encoder features, narrowing the semantic gap between the encoder and the decoder.

The basic idea of DCA consists of two parts. Channel Cross-Attention (CCA) uses a cross-attention mechanism to capture channel dependencies across the multi-scale encoder features and extract global channel information. Spatial Cross-Attention (SCA) uses a cross-attention mechanism to capture spatial dependencies across the same features and extract global spatial information. The DCA module applies CCA and SCA in series: CCA first extracts global channel information, and its output is then fed into SCA to extract global spatial information. This serial arrangement fuses low-level features more effectively and produces finer feature representations.
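Schematically, one DCA stage applies the two attentions in series with residual connections; the pseudocode sketch below matches CCSABlock.forward in the code further down (CCA and SCA here stand for the attention sub-modules, with layer normalization applied inside each):

# Pseudocode for one DCA stage (per scale); see CCSABlock.forward below.
x = x + CCA(x)  # channel cross-attention, residual add
x = x + SCA(x)  # spatial cross-attention, residual add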

For an input X, DCA proceeds as follows:

  1. Multi-scale feature extraction: extract multi-scale features from several stages of the encoder network.
  2. Patch embedding: convert the multi-scale features into tokens with 2-D average pooling and project them with depthwise separable convolutions (a shape sketch follows this list).
  3. CCA: layer-normalize each token sequence and concatenate the sequences along the channel dimension to form the keys and values. Apply linear projections implemented as depthwise separable convolutions, then perform cross-attention to extract global channel information.
  4. SCA: layer-normalize the CCA outputs and concatenate them along the channel dimension to form the queries and keys. Apply depthwise separable convolution projections, take each token sequence as the values, and perform cross-attention to extract global spatial information.
  5. Upsampling and connection: apply layer normalization and GeLU activation to the DCA outputs, upsample them, and connect them to the decoder network.
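As a quick, runnable illustration of steps 1 and 2 (the sizes are examples only; this mirrors the PoolEmbedding class in the code below), a 64-channel encoder feature map is pooled into a 28×28 grid of tokens:

import torch
import torch.nn as nn
import einops

feat = torch.randn(1, 64, 112, 112)                        # one encoder stage
tokens = nn.AdaptiveAvgPool2d((28, 28))(feat)              # -> (1, 64, 28, 28)
tokens = einops.rearrange(tokens, 'B C H W -> B (H W) C')  # -> (1, 784, 64)
print(tokens.shape)                                        # torch.Size([1, 784, 64])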

Dual Cross-Attention structure diagram: [figure]


DCA Block with U-Net structure diagram: [figure]

2. Code Implementation

import torch
import torch.nn as nn
import einops


# Depthwise convolution block with an optional pointwise (1x1) projection,
# optional BatchNorm ('bn') / GroupNorm ('gn'), and optional ReLU.
class depthwise_conv_block(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 groups=None,
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 pointwise=False,
                 ):
        super().__init__()
        self.pointwise = pointwise
        self.depthwise = nn.Conv2d(
            in_channels=in_features,
            out_channels=in_features if pointwise else out_features,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            dilation=dilation,
            bias=use_bias)
        if pointwise:
            self.pointwise = nn.Conv2d(in_features,
                                       out_features,
                                       kernel_size=(1, 1),
                                       stride=(1, 1),
                                       padding=(0, 0),
                                       dilation=(1, 1),
                                       bias=use_bias)

        self.norm_type = norm_type
        self.act = activation

        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.depthwise(x)
        if self.pointwise:
            x = self.pointwise(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x


# Standard convolution block with optional norm and ReLU.
class conv_block(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 stride=(1, 1),
                 padding=(1, 1),
                 dilation=(1, 1),
                 norm_type='bn',
                 activation=True,
                 use_bias=True,
                 ):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_features,
                              out_channels=out_features,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding,
                              dilation=dilation,
                              bias=use_bias)

        self.norm_type = norm_type
        self.act = activation

        if self.norm_type == 'gn':
            self.norm = nn.GroupNorm(32 if out_features >= 32 else out_features, out_features)
        if self.norm_type == 'bn':
            self.norm = nn.BatchNorm2d(out_features)
        if self.act:
            self.relu = nn.ReLU(inplace=False)

    def forward(self, x):
        x = self.conv(x)
        if self.norm_type is not None:
            x = self.norm(x)
        if self.act:
            x = self.relu(x)
        return x


# Scaled dot-product attention core shared by CCA and SCA:
# softmax((x1 @ x2^T) * scale) @ x3, computed per head via einsum.
class ScaleDotProduct(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x1, x2, x3, scale):
        x2 = x2.transpose(-2, -1)
        x12 = torch.einsum('bhcw, bhwk -> bhck', x1, x2) * scale  # affinity map
        att = self.softmax(x12)
        x123 = torch.einsum('bhcw, bhwk -> bhck', att, x3)  # weighted sum of x3
        return x123


# Patch embedding: adaptive pooling to a (patch x patch) grid, then
# flattening the spatial dimensions into tokens of shape (B, patch*patch, C).
class PoolEmbedding(nn.Module):
    def __init__(self,
                 pooling,
                 patch,
                 ) -> None:
        super().__init__()
        self.projection = pooling(output_size=(patch, patch))

    def forward(self, x):
        x = self.projection(x)
        x = einops.rearrange(x, 'B C H W -> B (H W) C')
        return x


# Depthwise 1x1 projection for token sequences: tokens are reshaped into a
# square 2-D map, convolved depthwise, and flattened back to (B, H*W, C).
class depthwise_projection(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 groups,
                 kernel_size=(1, 1),
                 padding=(0, 0),
                 norm_type=None,
                 activation=False,
                 pointwise=False) -> None:
        super().__init__()

        self.proj = depthwise_conv_block(in_features=in_features,
                                         out_features=out_features,
                                         kernel_size=kernel_size,
                                         padding=padding,
                                         groups=groups,
                                         pointwise=pointwise,
                                         norm_type=norm_type,
                                         activation=activation)

    def forward(self, x):
        P = int(x.shape[1] ** 0.5)
        x = einops.rearrange(x, 'B (H W) C-> B C H W', H=P)
        x = self.proj(x)
        x = einops.rearrange(x, 'B C H W -> B (H W) C')
        return x


# Bilinear upsampling followed by a 1x1 (or depthwise) convolution.
class UpsampleConv(nn.Module):
    def __init__(self,
                 in_features,
                 out_features,
                 kernel_size=(3, 3),
                 padding=(1, 1),
                 norm_type=None,
                 activation=False,
                 scale=(2, 2),
                 conv='conv') -> None:
        super().__init__()
        self.up = nn.Upsample(scale_factor=scale,
                              mode='bilinear',
                              align_corners=True)
        if conv == 'conv':
            self.conv = conv_block(in_features=in_features,
                                   out_features=out_features,
                                   kernel_size=(1, 1),
                                   padding=(0, 0),
                                   norm_type=norm_type,
                                   activation=activation)
        elif conv == 'depthwise':
            self.conv = depthwise_conv_block(in_features=in_features,
                                             out_features=out_features,
                                             kernel_size=kernel_size,
                                             padding=padding,
                                             norm_type=norm_type,
                                             activation=activation)

    def forward(self, x):
        x = self.up(x)
        x = self.conv(x)
        return x


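# Channel Cross-Attention (CCA). The queries come from a single scale
# (out_features channels); the keys and values come from the channel-wise
# concatenation of all scales (in_features = sum of all channel counts).
# After the head reshape, attention is computed over the channel axis, so
# each output channel is a softmax-weighted mixture of the multi-scale
# value channels.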
class ChannelAttention(nn.Module):
    def __init__(self, in_features, out_features, n_heads=1) -> None:
        super().__init__()
        self.n_heads = n_heads
        self.q_map = depthwise_projection(in_features=out_features,
                                          out_features=out_features,
                                          groups=out_features)
        self.k_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.v_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)

        self.projection = depthwise_projection(in_features=out_features,
                                               out_features=out_features,
                                               groups=out_features)
        self.sdp = ScaleDotProduct()

    def forward(self, x):
        q, k, v = x[0], x[1], x[2]
        q = self.q_map(q)
        k = self.k_map(k)
        v = self.v_map(v)
        b, hw, c_q = q.shape
        c = k.shape[2]
        scale = c ** -0.5
        q = q.reshape(b, hw, self.n_heads, c_q // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        k = k.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        v = v.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3).transpose(2, 3)
        att = self.sdp(q, k, v, scale).permute(0, 3, 1, 2).flatten(2)
        att = self.projection(att)
        return att


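# Spatial Cross-Attention (SCA). The queries and keys come from the
# channel-wise concatenation of all scales; the values come from a single
# scale. Attention is computed over the token positions, producing an
# (HW x HW) spatial affinity map per head.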
class SpatialAttention(nn.Module):
    def __init__(self, in_features, out_features, n_heads=4) -> None:
        super().__init__()
        self.n_heads = n_heads

        self.q_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.k_map = depthwise_projection(in_features=in_features,
                                          out_features=in_features,
                                          groups=in_features)
        self.v_map = depthwise_projection(in_features=out_features,
                                          out_features=out_features,
                                          groups=out_features)

        self.projection = depthwise_projection(in_features=out_features,
                                               out_features=out_features,
                                               groups=out_features)
        self.sdp = ScaleDotProduct()

    def forward(self, x):
        q, k, v = x[0], x[1], x[2]
        q = self.q_map(q)
        k = self.k_map(k)
        v = self.v_map(v)
        b, hw, c = q.shape
        c_v = v.shape[2]
        scale = (c // self.n_heads) ** -0.5
        q = q.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3)
        k = k.reshape(b, hw, self.n_heads, c // self.n_heads).permute(0, 2, 1, 3)
        v = v.reshape(b, hw, self.n_heads, c_v // self.n_heads).permute(0, 2, 1, 3)
        att = self.sdp(q, k, v, scale).transpose(1, 2).flatten(2)
        x = self.projection(att)
        return x


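# One DCA stage: LayerNorm each scale's tokens and concatenate all scales
# along the channel axis; run CCA (per-scale queries against concatenated
# keys/values), then SCA (concatenated queries/keys against per-scale
# values), adding each result back residually.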
class CCSABlock(nn.Module):
    def __init__(self,
                 features,
                 channel_head,
                 spatial_head,
                 spatial_att=True,
                 channel_att=True) -> None:
        super().__init__()
        self.channel_att = channel_att
        self.spatial_att = spatial_att
        if self.channel_att:
            self.channel_norm = nn.ModuleList([nn.LayerNorm(in_features,
                                                            eps=1e-6)
                                               for in_features in features])

            self.c_attention = nn.ModuleList([ChannelAttention(
                in_features=sum(features),
                out_features=feature,
                n_heads=head,
            ) for feature, head in zip(features, channel_head)])
        if self.spatial_att:
            self.spatial_norm = nn.ModuleList([nn.LayerNorm(in_features,
                                                            eps=1e-6)
                                               for in_features in features])

            self.s_attention = nn.ModuleList([SpatialAttention(
                in_features=sum(features),
                out_features=feature,
                n_heads=head,
            )
                for feature, head in zip(features, spatial_head)])

    def forward(self, x):
        if self.channel_att:
            x_ca = self.channel_attention(x)
            x = self.m_sum(x, x_ca)
        if self.spatial_att:
            x_sa = self.spatial_attention(x)
            x = self.m_sum(x, x_sa)
        return x

    def channel_attention(self, x):
        x_c = self.m_apply(x, self.channel_norm)
        x_cin = self.cat(*x_c)
        x_in = [[q, x_cin, x_cin] for q in x_c]
        x_att = self.m_apply(x_in, self.c_attention)
        return x_att

    def spatial_attention(self, x):
        x_c = self.m_apply(x, self.spatial_norm)
        x_cin = self.cat(*x_c)
        x_in = [[x_cin, x_cin, v] for v in x_c]
        x_att = self.m_apply(x_in, self.s_attention)
        return x_att

    def m_apply(self, x, module):
        return [module[i](j) for i, j in enumerate(x)]

    def m_sum(self, x, y):
        return [xi + xj for xi, xj in zip(x, y)]

    def cat(self, *args):
        return torch.cat((args), dim=2)


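# DCA wrapper: pool each encoder stage into a patch x patch token grid,
# project the tokens, apply n CCSABlocks, reshape the tokens back into 2-D
# maps, upsample each scale to its original resolution, add the raw input
# (residual), and finish with BatchNorm + ReLU.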
class DCA(nn.Module):
    def __init__(self,
                 features,
                 strides=[8,4,2,1],
                 patch=28,
                 channel_att=True,
                 spatial_att=True,
                 n=1,
                 channel_head=[1, 1, 1, 1],
                 spatial_head=[4, 4, 4, 4],
                 ):
        super().__init__()
        self.n = n
        self.features = features
        self.spatial_head = spatial_head
        self.channel_head = channel_head
        self.channel_att = channel_att
        self.spatial_att = spatial_att
        self.patch = patch
        self.patch_avg = nn.ModuleList([PoolEmbedding(
            pooling=nn.AdaptiveAvgPool2d,
            patch=patch,
        )
            for _ in features])
        self.avg_map = nn.ModuleList([depthwise_projection(in_features=feature,
                                                           out_features=feature,
                                                           kernel_size=(1, 1),
                                                           padding=(0, 0),
                                                           groups=feature
                                                           )
                                      for feature in features])

        self.attention = nn.ModuleList([
            CCSABlock(features=features,
                      channel_head=channel_head,
                      spatial_head=spatial_head,
                      channel_att=channel_att,
                      spatial_att=spatial_att)
            for _ in range(n)])

        self.upconvs = nn.ModuleList([UpsampleConv(in_features=feature,
                                                   out_features=feature,
                                                   kernel_size=(1, 1),
                                                   padding=(0, 0),
                                                   norm_type=None,
                                                   activation=False,
                                                   scale=stride,
                                                   conv='conv')
                                      for feature, stride in zip(features, strides)])
        self.bn_relu = nn.ModuleList([nn.Sequential(
            nn.BatchNorm2d(feature),
            nn.ReLU()
        )
            for feature in features])

    def forward(self, raw):
        x = self.m_apply(raw, self.patch_avg)
        x = self.m_apply(x, self.avg_map)
        for block in self.attention:
            x = block(x)
        x = [self.reshape(i) for i in x]
        x = self.m_apply(x, self.upconvs)
        x_out = self.m_sum(x, raw)
        x_out = self.m_apply(x_out, self.bn_relu)
        return (*x_out,)

    def m_apply(self, x, module):
        return [module[i](j) for i, j in enumerate(x)]

    def m_sum(self, x, y):
        return [xi + xj for xi, xj in zip(x, y)]

    def reshape(self, x):
        return einops.rearrange(x, 'B (H W) C-> B C H W', H=self.patch)


if __name__ == '__main__':
    # Four encoder stages with channels [32, 64, 128, 256]. Each spatial
    # size must equal patch * stride, i.e. 28 * [8, 4, 2, 1] = [224, 112, 56, 28].
    x = torch.randn(4, 32, 224, 224)
    y = torch.randn(4, 64, 112, 112)
    z = torch.randn(4, 128, 56, 56)
    v = torch.randn(4, 256, 28, 28)
    model = DCA([32, 64, 128, 256])
    output1, output2, output3, output4 = model((x, y, z, v))
    print(output1.shape)  # torch.Size([4, 32, 224, 224])
    print(output2.shape)  # torch.Size([4, 64, 112, 112])
    print(output3.shape)  # torch.Size([4, 128, 56, 56])
    print(output4.shape)  # torch.Size([4, 256, 28, 28])
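Each output keeps the shape of its corresponding input, so DCA can be dropped directly onto the skip connections of an existing U-Net. Note that every input's spatial size must equal patch times its stride (here 28 x [8, 4, 2, 1]) so that the upsampled features align with the raw inputs for the residual sum.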

From: https://blog.csdn.net/wei582636312/article/details/144892288
