题目:SAFMN:Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution
期刊:CVPR (Conference on Computer Vision and Pattern Recognition)
GitHub地址:https://github.com/sunny2109/SAFMN
年份:2023
作者单位:The Chinese University of Hong Kong (CUHK)
创新点
-
空间自适应特征调制机制:文中提出了一种新的特征调制方法,称为空间自适应特征调制(SAFMN),能够动态调整每个像素位置的特征,使得超分辨率重建更加准确。与传统方法不同,它通过对图像特征进行空间局部自适应调制,提升图像质量。
-
高效计算结构:SAFMN采用轻量化设计,能够在不增加计算复杂度的前提下,显著提升超分辨率模型的效果。它有效减少了冗余计算,保证了效率与性能的平衡。
-
优异的超分辨率效果:文献中的模型在多个超分辨率基准数据集上都表现出了优异的性能,尤其是在保持图像细节和纹理方面有显著优势。与现有方法相比,它在图像质量和推理速度之间实现了更好的折衷。
方法
整体结构
SAFMN模型由三个核心部分组成:通过将空间自适应特征调制(SAFM)、跨通道混合(CCM)和特征混合模块(Feature Mixing Module)结合在一起,作者提出了以SAFM模块和CCM模块作为基本构件的网络架构。SAFMN分为三部分,分别是特征提取部分(Encoder)、特征调制和混合部分(Feature Transformation)、以及上采样重建部分(Decoder)。
-
SAFM (Spatially-Adaptive Feature Modulation):SAFM模块是核心部分,它利用局部信息自适应地调制特征,从而使得模型可以为每个像素位置选择最适合的调制方式,增强对不同区域的适应性。
-
CCM (Cross-Channel Mixing):跨通道混合模块对不同通道的特征进行交互,进一步增强了图像细节恢复能力。它通过LayerNorm规范化特征,随后进行跨通道的特征融合。
-
特征混合模块 (Feature Mixing Module):特征提取后进入特征混合模块。该模块结合不同的特征信息,以增强对图像细节的捕捉和恢复能力。这些特征通过一系列混合操作整合信息。
消融实验
即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
#https://github.com/sunny2109/SAFMN
#论文:https://arxiv.org/pdf/2302.13800
class SAFM(nn.Module):
def __init__(self, dim, n_levels=4):
super().__init__()
self.n_levels = n_levels
chunk_dim = dim // n_levels
# Spatial Weighting
self.mfr = nn.ModuleList(
[nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])
# # Feature Aggregation
self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)
# Activation
self.act = nn.GELU()
def forward(self, x):
h, w = x.size()[-2:]
xc = x.chunk(self.n_levels, dim=1)
out = []
for i in range(self.n_levels):
if i > 0:
p_size = (h // 2 ** i, w // 2 ** i)
s = F.adaptive_max_pool2d(xc[i], p_size)
s = self.mfr[i](s)
s = F.interpolate(s, size=(h, w), mode='nearest')
else:
s = self.mfr[i](xc[i])
out.append(s)
out = self.aggr(torch.cat(out, dim=1))
out = self.act(out) * x
return out
if __name__ == '__main__':
input = torch.randn(3,36,64,64) #输入b c h w
block = SAFM(dim=36)
output =block(input)
print(output.size())
标签:__,dim,涨点,模块,nn,特征,self,尺度空间,特征提取
From: https://blog.csdn.net/Angelina_Jolie/article/details/143305945