首页 > 其他分享 >(ICCV2023)多尺度空间特征提取模块,有效涨点,即插即用

(ICCV2023)多尺度空间特征提取模块,有效涨点,即插即用

时间:2024-10-31 08:51:02浏览次数:8  
标签:__ dim 涨点 模块 nn 特征 self 尺度空间 特征提取

题目:SAFMN:Spatially-Adaptive Feature Modulation for Efficient Image Super-Resolution

期刊:CVPR (Conference on Computer Vision and Pattern Recognition)

GitHub地址:https://github.com/sunny2109/SAFMN

年份:2023

作者单位:The Chinese University of Hong Kong (CUHK)

创新点

  • 空间自适应特征调制机制:文中提出了一种新的特征调制方法,称为空间自适应特征调制(SAFMN),能够动态调整每个像素位置的特征,使得超分辨率重建更加准确。与传统方法不同,它通过对图像特征进行空间局部自适应调制,提升图像质量。

  • 高效计算结构:SAFMN采用轻量化设计,能够在不增加计算复杂度的前提下,显著提升超分辨率模型的效果。它有效减少了冗余计算,保证了效率与性能的平衡。

  • 优异的超分辨率效果:文献中的模型在多个超分辨率基准数据集上都表现出了优异的性能,尤其是在保持图像细节和纹理方面有显著优势。与现有方法相比,它在图像质量和推理速度之间实现了更好的折衷。

方法

整体结构

SAFMN模型由三个核心部分组成:通过将空间自适应特征调制(SAFM)、跨通道混合(CCM)和特征混合模块(Feature Mixing Module)结合在一起,作者提出了以SAFM模块和CCM模块作为基本构件的网络架构。SAFMN分为三部分,分别是特征提取部分(Encoder)、特征调制和混合部分(Feature Transformation)、以及上采样重建部分(Decoder)。

  • SAFM (Spatially-Adaptive Feature Modulation):SAFM模块是核心部分,它利用局部信息自适应地调制特征,从而使得模型可以为每个像素位置选择最适合的调制方式,增强对不同区域的适应性。

  • CCM (Cross-Channel Mixing):跨通道混合模块对不同通道的特征进行交互,进一步增强了图像细节恢复能力。它通过LayerNorm规范化特征,随后进行跨通道的特征融合。

  • 特征混合模块 (Feature Mixing Module):特征提取后进入特征混合模块。该模块结合不同的特征信息,以增强对图像细节的捕捉和恢复能力。这些特征通过一系列混合操作整合信息。

消融实验

 

即插即用模块
import torch
import torch.nn as nn
import torch.nn.functional as F
#https://github.com/sunny2109/SAFMN
#论文:https://arxiv.org/pdf/2302.13800
class SAFM(nn.Module):
    def __init__(self, dim, n_levels=4):
        super().__init__()
        self.n_levels = n_levels
        chunk_dim = dim // n_levels

        # Spatial Weighting
        self.mfr = nn.ModuleList(
            [nn.Conv2d(chunk_dim, chunk_dim, 3, 1, 1, groups=chunk_dim) for i in range(self.n_levels)])

        # # Feature Aggregation
        self.aggr = nn.Conv2d(dim, dim, 1, 1, 0)

        # Activation
        self.act = nn.GELU()

    def forward(self, x):
        h, w = x.size()[-2:]

        xc = x.chunk(self.n_levels, dim=1)
        out = []
        for i in range(self.n_levels):
            if i > 0:
                p_size = (h // 2 ** i, w // 2 ** i)
                s = F.adaptive_max_pool2d(xc[i], p_size)
                s = self.mfr[i](s)
                s = F.interpolate(s, size=(h, w), mode='nearest')
            else:
                s = self.mfr[i](xc[i])
            out.append(s)

        out = self.aggr(torch.cat(out, dim=1))
        out = self.act(out) * x
        return out


if __name__ == '__main__':
    input = torch.randn(3,36,64,64) #输入b c h w

    block = SAFM(dim=36)
    output =block(input)
    print(output.size())

标签:__,dim,涨点,模块,nn,特征,self,尺度空间,特征提取
From: https://blog.csdn.net/Angelina_Jolie/article/details/143305945

相关文章