首页 > 其他分享 >(即插即用模块-Attention部分) 三十四、(2022) FACMA 频率感知跨通道注意力

(即插即用模块-Attention部分) 三十四、(2022) FACMA 频率感知跨通道注意力

时间:2025-01-10 10:59:15浏览次数:3  
标签:__ fidx nn self Attention FACMA rgb 2022 channel

在这里插入图片描述

文章目录

paper:FCMNet: Frequency-aware cross-modality attention networks for RGB-D salient object detection

Code:https://github.com/XiaoJinNK/FCMNet


1、Frequency-Aware Cross-Modality Attention

现有的 RGB-D 显著目标检测方法通常将 RGB 图像和深度图视为两种模态,并平等地对待它们。然而,这两种模态在频域中存在差异,例如,RGB 图像包含更多高频成分(细节、纹理),而深度图包含更多低频成分(平坦区域)。而传统的注意力机制(如全局平均池化)则难以保留不同模态中互补的频率成分,从而导致信息丢失。

为此,这篇论文提出一种 频率感知跨通道注意力(Frequency-Aware Cross-Modality Attention)。FACMA 模块的基本思想是从频域的角度出发,自动提取和强化不同模态中互补的信息。

对于输入X,FACMA 的实现过程包含两部分:

SFCA 部分,该部分包含两个组件:空间注意力部分和频率通道注意力 (FCA) 部分

  1. 空间注意力模块:通过 1x1 卷积操作提取位置信息,并使用 sigmoid 函数生成权重图,从而突出显示重要的位置。
  2. FCA 模块:首先对输入特征图进行二维离散余弦变换 (DCT),然后进行全连接层和 ReLU 激活操作,最后使用 sigmoid 函数生成权重图,从而突出显示对显著区域的响应。
  3. 输出:将两个组件的输出进行元素相加,得到 SFCA 模块的最终输出。

FACMA 部分

  1. 将 RGB 分支和 Depth 分支的特征图分别输入两个对称的 SFCA 模块。
  2. 将 SFCA 模块的输出进行元素相乘,从而分别生成 RGB 和 Depth 层面的信息。

Frequency-Aware Cross-Modality Attention 结构图:
在这里插入图片描述

Spatial Frequency Channel attention 结构图:
在这里插入图片描述

2、Weighted Cross-Modality Fusion module

在现有的 RGB-D 显著目标检测方法通常采用简单的融合策略中,例如元素相加或拼接,现有方法忽略了不同模态之间的差异和内容依赖性。此外,这些方法也忽略了神经网络在融合过程中的非线性表示能力。所以,除 FACMA 外,这篇论文还设计了一种即插即用的特征融合模块:加权跨模态融合模块(Weighted Cross-Modality Fusion module)

WCMF 模块旨在自适应地融合多模态特征,并考虑内容依赖性和非线性表示能力。

对于输入X,WCMF 的实现过程包含两部分:

  1. 非线性特征增强 (NFE) 单元:对输入的特征图进行 1x1 卷积、批量归一化和 ReLU 激活操作,从而增强特征的非线性表示能力。对 RGB 分支和深度分支的特征图分别进行 NFE 操作,并将它们拼接在一起。
  2. 计算权重图:对拼接后的特征图进行两次 NFE 操作,得到两个权重图,分别对应 RGB 分支和深度分支的特征图。权重图的大小与输入特征图相同,并且每个像素的值表示对应分支特征图的重要性。
  3. 融合特征图:使用权重图对 RGB 分支和深度分支的特征图进行加权相乘,并使用 ReLU 激活函数进行非线性变换。将两个加权特征图进行元素相加,得到 WCMF 模块的最终输出。

Weighted Cross-Modality Fusion module 结构图:
在这里插入图片描述

3、代码实现

import torch
import torch.nn as nn
import math


def get_1d_dct(i, freq, L):
    result = math.cos(math.pi * freq * (i+0.5)/L) / math.sqrt(L)
    if freq == 0:
        return result
    else:
        return result * math.sqrt(2)

def get_dct_weights(width,height,channel,fidx_u,fidx_v):
    dct_weights = torch.zeros(1, channel, width, height)
    c_part = channel // len(fidx_u)
    for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
        for t_x in range(width):
            for t_y in range(height):
                dct_weights[:, i*c_part: (i+1)*c_part, t_x, t_y] = get_1d_dct(t_x, u_x, width) * get_1d_dct(t_y, v_y, height)
    return dct_weights


class FCABlock(nn.Module):
    def __init__(self, channel,width,height,fidx_u, fidx_v, reduction=16):
        super(FCABlock, self).__init__()
        mid_channel = channel // reduction
        self.register_buffer('pre_computed_dct_weights', get_dct_weights(width,height,channel,fidx_u,fidx_v))
        self.excitation = nn.Sequential(
            nn.Linear(channel, mid_channel, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(mid_channel, channel, bias=False),
            nn.Sigmoid()
        )
    def forward(self, x):
        b, c, _, _ = x.size()
        y = torch.sum(x * self.pre_computed_dct_weights, dim=[2,3])
        z = self.excitation(y).view(b, c, 1, 1)
        return x * z.expand_as(x)


class SFCA(nn.Module):
    def __init__(self, in_channel,width,height,fidx_u,fidx_v):
        super(SFCA, self).__init__()

        fidx_u = [temp_u * (width // 8) for temp_u in fidx_u]
        fidx_v = [temp_v * (width // 8) for temp_v in fidx_v]
        self.FCA = FCABlock(in_channel, width, height, fidx_u, fidx_v)
        self.conv1 = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)
        self.norm = nn.Sigmoid()
    def forward(self, x):
        # FCA
        F_fca = self.FCA(x)
        #context attention
        con = self.conv1(x) # c,h,w -> 1,h,w
        con = self.norm(con)
        F_con = x * con
        return F_fca + F_con

class FACMA(nn.Module):
    def __init__(self,in_channel,width,height,fidx_u=[0,1],fidx_v=[0,1]):
        super(FACMA, self).__init__()
        self.sfca_depth = SFCA(in_channel, width, height, fidx_u, fidx_v)
        self.sfca_rgb = SFCA(in_channel, width, height, fidx_u, fidx_v)
    def forward(self, rgb, depth):
        out_d = self.sfca_depth(depth)
        out_d = rgb * out_d

        out_rgb = self.sfca_rgb(rgb)
        out_rgb = depth * out_rgb
        return out_rgb, out_d


class WCMF(nn.Module):
    def __init__(self,channel=256):
        super(WCMF, self).__init__()
        self.conv_r1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU())
        self.conv_d1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU())

        self.conv_c1 = nn.Sequential(nn.Conv2d(2*channel, channel, 3, 1, 1), nn.BatchNorm2d(channel), nn.ReLU())
        self.conv_c2 = nn.Sequential(nn.Conv2d(channel, 2, 3, 1, 1), nn.BatchNorm2d(2), nn.ReLU())
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))

    def fusion(self,f1,f2,f_vec):

        w1 = f_vec[:, 0, :, :].unsqueeze(1)
        w2 = f_vec[:, 1, :, :].unsqueeze(1)
        out1 = (w1 * f1) + (w2 * f2)
        out2 = (w1 * f1) * (w2 * f2)
        return out1 + out2

    def forward(self,rgb,depth):
        Fr = self.conv_r1(rgb)
        Fd = self.conv_d1(depth)
        f = torch.cat([Fr, Fd],dim=1)
        f = self.conv_c1(f)
        f = self.conv_c2(f)
        # f = self.avgpool(f)
        Fo = self.fusion(Fr, Fd, f)
        return Fo


if __name__ == '__main__':
    rgb_x = torch.randn(4, 512, 7, 7)
    depth_x = torch.randn(4, 512, 7, 7)
    model = FACMA(512, 7, 7)
    out_rgb, out_depth = model(rgb_x, depth_x)
    print('FACMA_RGB:' + str(out_rgb.shape))
    print('FACMA_DEPTH:' + str(out_depth.shape))

    model2 = WCMF(512)
    wcmf_output = model2(rgb_x, depth_x)
    print('WCMF:' + str(wcmf_output.shape))

标签:__,fidx,nn,self,Attention,FACMA,rgb,2022,channel
From: https://blog.csdn.net/wei582636312/article/details/144795240

相关文章

  • (即插即用模块-Attention部分) 三十三、(2021) SPA 显著位置注意力
    文章目录1、SalientPositionsAttention2、代码实现paper:SalientPositionsbasedAttentionNetworkforImageClassificationCode:https://github.com/likyoo/SPANet1、SalientPositionsAttention在现有的自注意力机制中,其建模长距离依赖关系方面表现出色......
  • Adobe Premiere Pro 2022 下载安装教程,亲测有效
    简介嗨咯,大家好,今天为大家带来的事AdobePremierePro2022下载安装教程,亲测有效。AdobePremierePro是一款领先的视频编辑软件,适用于电影、电视和网络内容创作。该软件结合强大的创意工具、Adobe应用程序和服务的深度集成以及AdobeSensei的AI技术,可帮助用户轻......
  • 剑指核心!注意力机制+时空特征融合!组合模型集成学习预测!GRU-Attention-Adaboost多变量
    剑指核心!注意力机制+时空特征融合!组合模型集成学习预测!GRU-Attention-Adaboost多变量时序预测目录剑指核心!注意力机制+时空特征融合!组合模型集成学习预测!GRU-Attention-Adaboost多变量时序预测效果一览基本介绍程序设计参考资料效果一览基本介绍1.Matlab......
  • 自注意力self-attention理解(qkv计算、代码)
    1.自注意力的个人理解   self-attention中的核心便是qkv的计算,首先是将输入向量分别乘上三个可学习的的矩阵得到Query(查询)、Key(键)、Value(值);再将q和k点乘达到全局建模的作用,将qk结果进行softmax得到Attention分数;最后将Attention和v相乘这个操作我的理解是:可以把Val......
  • vs2022遇到“停止生成”的问题
    关闭vs2012,提示必须停止生成,第一次遇见不知道怎么办,查了下,第一种解决了。在使用VisualStudio2022时,如果遇到“停止生成”的问题,可以尝试以下几种解决方案:取消生成:在VisualStudio的主界面,点击“生成”菜单,然后选择“取消”。使用快捷键 Ctrl+Break 来取消当前正......
  • Visual Studio 2022 上架腾讯云 AI 代码助手了
    近期在VisualStudio市场上上架了腾讯云AI代码助手。该插件可以在VisualStudio2022版本(含社区版,版本不低于17.6即可)使用智能辅助编码能力,助力VisualStudio的开发者提高效率。我们在该平台上支持技术对话、代码补全、单元测试生成、解释代码、修复代码等场景。如何安装......
  • (即插即用模块-Attention部分) 三十六、(2023) DCA 二重交叉注意力
    文章目录1、DualCross-Attention2、代码实现paper:DualCross-AttentionforMedicalImageSegmentationCode:https://github.com/gorkemcanates/Dual-Cross-Attention1、DualCross-AttentionU-Net及其变体尽管在医学图像分割任务中取得了良好的性能,但仍然存......
  • 用 2025 年的工具,秒杀了 2022 年的题目。
    你好呀,我是歪歪。前几天打开知乎的时候,在付费咨询模块,我看到了一个差不多两年半前没有回答的技术问题。其实这个问题问的很清晰了,但是当时我拒绝了:虽然过去快两年半的时间,但是我记得还是比较清楚,当时拒绝的理由是如果让我来回答这个问题,我肯定是首选基于Redis来做。大家想......
  • E94 Tarjan边双缩点+树形DP P8867 [NOIP2022] 建造军营
    视频链接:E94Tarjan边双缩点+树形DPP8867[NOIP2022]建造军营_哔哩哔哩_bilibili  P8867[NOIP2022]建造军营-洛谷|计算机科学教育新生态//Tarjan边双缩点+树形DPO(n)#include<bits/stdc++.h>usingnamespacestd;intread(){intx=0,f=1;charc=getchar......
  • Flash Attention V3使用
    FlashAttentionV3概述FlashAttention是一种针对Transformer模型中注意力机制的优化实现,旨在提高计算效率和内存利用率。随着大模型的普及,FlashAttentionV3在H100GPU上实现了显著的性能提升,相比于前一版本,V3通过异步化计算、优化数据传输和引入低精度计算等技术......