文章目录
paper:FCMNet: Frequency-aware cross-modality attention networks for RGB-D salient object detection
1、Frequency-Aware Cross-Modality Attention
现有的 RGB-D 显著目标检测方法通常将 RGB 图像和深度图视为两种模态,并平等地对待它们。然而,这两种模态在频域中存在差异,例如,RGB 图像包含更多高频成分(细节、纹理),而深度图包含更多低频成分(平坦区域)。而传统的注意力机制(如全局平均池化)则难以保留不同模态中互补的频率成分,从而导致信息丢失。
为此,这篇论文提出一种 频率感知跨通道注意力(Frequency-Aware Cross-Modality Attention)。FACMA 模块的基本思想是从频域的角度出发,自动提取和强化不同模态中互补的信息。
对于输入X,FACMA 的实现过程包含两部分:
SFCA 部分,该部分包含两个组件:空间注意力部分和频率通道注意力 (FCA) 部分:
- 空间注意力模块:通过 1x1 卷积操作提取位置信息,并使用 sigmoid 函数生成权重图,从而突出显示重要的位置。
- FCA 模块:首先对输入特征图进行二维离散余弦变换 (DCT),然后进行全连接层和 ReLU 激活操作,最后使用 sigmoid 函数生成权重图,从而突出显示对显著区域的响应。
- 输出:将两个组件的输出进行元素相加,得到 SFCA 模块的最终输出。
FACMA 部分:
- 将 RGB 分支和 Depth 分支的特征图分别输入两个对称的 SFCA 模块。
- 将 SFCA 模块的输出进行元素相乘,从而分别生成 RGB 和 Depth 层面的信息。
Frequency-Aware Cross-Modality Attention 结构图:
Spatial Frequency Channel attention 结构图:
2、Weighted Cross-Modality Fusion module
在现有的 RGB-D 显著目标检测方法通常采用简单的融合策略中,例如元素相加或拼接,现有方法忽略了不同模态之间的差异和内容依赖性。此外,这些方法也忽略了神经网络在融合过程中的非线性表示能力。所以,除 FACMA 外,这篇论文还设计了一种即插即用的特征融合模块:加权跨模态融合模块(Weighted Cross-Modality Fusion module)。
WCMF 模块旨在自适应地融合多模态特征,并考虑内容依赖性和非线性表示能力。
对于输入X,WCMF 的实现过程包含两部分:
- 非线性特征增强 (NFE) 单元:对输入的特征图进行 1x1 卷积、批量归一化和 ReLU 激活操作,从而增强特征的非线性表示能力。对 RGB 分支和深度分支的特征图分别进行 NFE 操作,并将它们拼接在一起。
- 计算权重图:对拼接后的特征图进行两次 NFE 操作,得到两个权重图,分别对应 RGB 分支和深度分支的特征图。权重图的大小与输入特征图相同,并且每个像素的值表示对应分支特征图的重要性。
- 融合特征图:使用权重图对 RGB 分支和深度分支的特征图进行加权相乘,并使用 ReLU 激活函数进行非线性变换。将两个加权特征图进行元素相加,得到 WCMF 模块的最终输出。
Weighted Cross-Modality Fusion module 结构图:
3、代码实现
import torch
import torch.nn as nn
import math
def get_1d_dct(i, freq, L):
result = math.cos(math.pi * freq * (i+0.5)/L) / math.sqrt(L)
if freq == 0:
return result
else:
return result * math.sqrt(2)
def get_dct_weights(width,height,channel,fidx_u,fidx_v):
dct_weights = torch.zeros(1, channel, width, height)
c_part = channel // len(fidx_u)
for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
for t_x in range(width):
for t_y in range(height):
dct_weights[:, i*c_part: (i+1)*c_part, t_x, t_y] = get_1d_dct(t_x, u_x, width) * get_1d_dct(t_y, v_y, height)
return dct_weights
class FCABlock(nn.Module):
def __init__(self, channel,width,height,fidx_u, fidx_v, reduction=16):
super(FCABlock, self).__init__()
mid_channel = channel // reduction
self.register_buffer('pre_computed_dct_weights', get_dct_weights(width,height,channel,fidx_u,fidx_v))
self.excitation = nn.Sequential(
nn.Linear(channel, mid_channel, bias=False),
nn.ReLU(inplace=True),
nn.Linear(mid_channel, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = torch.sum(x * self.pre_computed_dct_weights, dim=[2,3])
z = self.excitation(y).view(b, c, 1, 1)
return x * z.expand_as(x)
class SFCA(nn.Module):
def __init__(self, in_channel,width,height,fidx_u,fidx_v):
super(SFCA, self).__init__()
fidx_u = [temp_u * (width // 8) for temp_u in fidx_u]
fidx_v = [temp_v * (width // 8) for temp_v in fidx_v]
self.FCA = FCABlock(in_channel, width, height, fidx_u, fidx_v)
self.conv1 = nn.Conv2d(in_channel, 1, kernel_size=1, bias=False)
self.norm = nn.Sigmoid()
def forward(self, x):
# FCA
F_fca = self.FCA(x)
#context attention
con = self.conv1(x) # c,h,w -> 1,h,w
con = self.norm(con)
F_con = x * con
return F_fca + F_con
class FACMA(nn.Module):
def __init__(self,in_channel,width,height,fidx_u=[0,1],fidx_v=[0,1]):
super(FACMA, self).__init__()
self.sfca_depth = SFCA(in_channel, width, height, fidx_u, fidx_v)
self.sfca_rgb = SFCA(in_channel, width, height, fidx_u, fidx_v)
def forward(self, rgb, depth):
out_d = self.sfca_depth(depth)
out_d = rgb * out_d
out_rgb = self.sfca_rgb(rgb)
out_rgb = depth * out_rgb
return out_rgb, out_d
class WCMF(nn.Module):
def __init__(self,channel=256):
super(WCMF, self).__init__()
self.conv_r1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU())
self.conv_d1 = nn.Sequential(nn.Conv2d(channel, channel, 1, 1, 0), nn.BatchNorm2d(channel), nn.ReLU())
self.conv_c1 = nn.Sequential(nn.Conv2d(2*channel, channel, 3, 1, 1), nn.BatchNorm2d(channel), nn.ReLU())
self.conv_c2 = nn.Sequential(nn.Conv2d(channel, 2, 3, 1, 1), nn.BatchNorm2d(2), nn.ReLU())
self.avgpool = nn.AdaptiveAvgPool2d((1,1))
def fusion(self,f1,f2,f_vec):
w1 = f_vec[:, 0, :, :].unsqueeze(1)
w2 = f_vec[:, 1, :, :].unsqueeze(1)
out1 = (w1 * f1) + (w2 * f2)
out2 = (w1 * f1) * (w2 * f2)
return out1 + out2
def forward(self,rgb,depth):
Fr = self.conv_r1(rgb)
Fd = self.conv_d1(depth)
f = torch.cat([Fr, Fd],dim=1)
f = self.conv_c1(f)
f = self.conv_c2(f)
# f = self.avgpool(f)
Fo = self.fusion(Fr, Fd, f)
return Fo
if __name__ == '__main__':
rgb_x = torch.randn(4, 512, 7, 7)
depth_x = torch.randn(4, 512, 7, 7)
model = FACMA(512, 7, 7)
out_rgb, out_depth = model(rgb_x, depth_x)
print('FACMA_RGB:' + str(out_rgb.shape))
print('FACMA_DEPTH:' + str(out_depth.shape))
model2 = WCMF(512)
wcmf_output = model2(rgb_x, depth_x)
print('WCMF:' + str(wcmf_output.shape))
标签:__,fidx,nn,self,Attention,FACMA,rgb,2022,channel
From: https://blog.csdn.net/wei582636312/article/details/144795240