
2023ICCV_Feature Modulation Transformer: Cross-Refinement of Global Representation via High-Frequenc

Date: 2023-12-04 15:00:44

I. Motivation

1. Existing Transformer work focuses mainly on designing Transformer blocks that capture global information, and overlooks the potential of incorporating a high-frequency prior.

2. There is limited detailed analysis of how frequency affects performance.

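Note:

(1) [Figure missing from the source: PSNR drop ratio of CNN and Transformer models against the high-frequency drop ratio]

Figure description: as more high-frequency information is removed (the high-frequency drop ratio increases), the CNNs (dashed lines) degrade sharply, while the Transformers (solid lines) degrade less than the CNNs. The Transformers therefore depend less on high-frequency content: they capture low-frequency information well and high-frequency information poorly.

PSNR Drop Ratio: [formula image missing from the source]

P(0) denotes the original PSNR (without dropping).

(2) How the high-frequency information is dropped for the PSNR measurement: [figure missing from the source]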
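Because the formula and the figure for this step did not survive in the post, the following is only a rough sketch of how such an experiment could be set up, not the paper's procedure. It assumes (a) that high frequencies are dropped with a radial low-pass mask in the FFT domain and (b) that the PSNR drop ratio is the relative loss (P(0) - P(r)) / P(0). The function names `drop_high_freq` and `psnr_drop_ratio` are hypothetical.

```python
import numpy as np

def drop_high_freq(img, drop_ratio):
    """Zero the outermost `drop_ratio` fraction of the radial frequency range.

    ASSUMPTION: a radial low-pass mask; the original post does not show
    how the paper actually removes high-frequency content.
    """
    h, w = img.shape
    spec = np.fft.fftshift(np.fft.fft2(img))      # zero frequency moved to the centre
    yy, xx = np.mgrid[0:h, 0:w]
    radius = np.hypot(yy - h // 2, xx - w // 2)   # distance from the zero frequency
    keep = radius <= (1.0 - drop_ratio) * radius.max()
    return np.real(np.fft.ifft2(np.fft.ifftshift(spec * keep)))

def psnr_drop_ratio(p0, pr):
    """ASSUMED definition: relative PSNR loss with respect to P(0)."""
    return (p0 - pr) / p0

rng = np.random.default_rng(0)
image = rng.random((32, 32))
mild = drop_high_freq(image, 0.2)     # removes little high-frequency content
strong = drop_high_freq(image, 0.6)   # removes much more
```

In an actual experiment, P(0) and P(r) would be the PSNR values of a model evaluated on the original input and on the input with drop ratio r, respectively.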

 


 

II. Contribution

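1. Studies how CNNs and Transformers affect performance from a frequency perspective, and finds that Transformers are good at capturing low-frequency information but poor at capturing high-frequency information.

2. Designs a parallel structure in which the HFERB branch captures high-frequency information and the SRWAB branch captures global information.

3. The HFERB output acts as the high-frequency prior and supplies the query Q, while the SRWAB output supplies the key K and value V; the two are fused through attention.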

III. Network

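1. Overall structure: a Conv 3×3 first extracts shallow features, which pass through several RCRFGs connected in series; a final Conv 3×3 and a skip connection then perform the reconstruction.

2. Each RCRFG consists of three CRFBs and a Conv 3×3 with a residual connection.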
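The two points above can be sketched as a skeleton. This is a minimal illustration under assumptions, not the authors' implementation: `PlaceholderCRFB` is a hypothetical stand-in (a single residual convolution) for the real CRFB, the residual placement inside `RCRFGSketch` is one plausible reading of the description, and the upsampling stage is omitted because the post does not describe it. The defaults `dim=48` and `num_groups=4` are arbitrary.

```python
import torch
import torch.nn as nn

class PlaceholderCRFB(nn.Module):
    """Hypothetical stand-in for a CRFB (the real block is not reproduced here)."""
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Conv2d(dim, dim, 3, 1, 1)

    def forward(self, x):
        return x + self.body(x)

class RCRFGSketch(nn.Module):
    """Three CRFBs followed by a Conv 3x3, wrapped in a residual connection."""
    def __init__(self, dim, num_blocks=3):
        super().__init__()
        self.blocks = nn.Sequential(*[PlaceholderCRFB(dim) for _ in range(num_blocks)])
        self.conv = nn.Conv2d(dim, dim, 3, 1, 1)

    def forward(self, x):
        return x + self.conv(self.blocks(x))

class BackboneSketch(nn.Module):
    """Shallow Conv 3x3 -> serial RCRFGs -> Conv 3x3 + global skip connection."""
    def __init__(self, in_ch=3, dim=48, num_groups=4):
        super().__init__()
        self.shallow = nn.Conv2d(in_ch, dim, 3, 1, 1)
        self.groups = nn.Sequential(*[RCRFGSketch(dim) for _ in range(num_groups)])
        self.conv_after = nn.Conv2d(dim, dim, 3, 1, 1)

    def forward(self, x):
        shallow = self.shallow(x)                       # shallow feature extraction
        deep = self.conv_after(self.groups(shallow))    # deep features
        return shallow + deep                           # skip connection

features = BackboneSketch()(torch.randn(1, 3, 16, 16))
```

The output keeps the input's spatial size because every convolution uses kernel 3, stride 1, padding 1.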

HFERB provides the high-frequency prior:

import torch
import torch.nn as nn

class HFERB(nn.Module):
    def __init__(self, dim) -> None:
        super().__init__()
        self.mid_dim = dim//2
        self.dim = dim
        self.act = nn.GELU()
        self.last_fc = nn.Conv2d(self.dim, self.dim, 1)

        # High-frequency enhancement branch
        self.fc = nn.Conv2d(self.mid_dim, self.mid_dim, 1)
        self.max_pool = nn.MaxPool2d(3, 1, 1)

        # Local feature extraction branch
        self.conv = nn.Conv2d(self.mid_dim, self.mid_dim, 3, 1, 1)

    def forward(self, x):
        self.h, self.w = x.shape[2:]
        short = x

        # Local feature extraction branch
        lfe = self.act(self.conv(x[:,:self.mid_dim,:,:]))

        # High-frequency enhancement branch
        hfe = self.act(self.fc(self.max_pool(x[:,self.mid_dim:,:,:])))

        x = torch.cat([lfe, hfe], dim=1)
        x = short + self.last_fc(x)
        return x

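The core of the HFERB module is the high-frequency enhancement branch, which uses a max-pooling layer to extract high-frequency information from the feature map. Max pooling selects the largest value within a local region, which emphasizes details such as edges and textures, i.e. the high-frequency information. Note that the layer used here is nn.MaxPool2d(3, 1, 1) (kernel 3, stride 1, padding 1), so the spatial resolution stays unchanged; max pooling reduces resolution, and with it computation and memory, only when the stride is greater than 1.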
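As a small illustration of the pooling layer configured in HFERB (kernel 3, stride 1, padding 1), the sketch below applies it to a toy step edge. The toy input is an assumption made for the demonstration; it only shows that the output keeps the input size and that the pooled map differs from the input right next to the edge.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(3, 1, 1)        # same settings as self.max_pool in HFERB

x = torch.zeros(1, 1, 6, 6)
x[..., :, 3:] = 1.0                 # vertical step edge between columns 2 and 3

y = pool(x)                         # stride 1 + padding 1: spatial size is preserved
diff = y - x                        # where pooling changed the value

edge_column = diff[0, 0, :, 2]      # column just left of the edge
changed = int(torch.count_nonzero(diff))
```

Flat regions are left untouched, while the column bordering the edge is raised to the neighbouring maximum, which is why the pooled response is sensitive to local intensity changes.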

SRWAB:

class SRWAB(nn.Module):
    r""" Shift Rectangle Window Attention Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        split_size (int): Define the window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self,
                 dim,
                 num_heads,
                 split_size=(2,2),
                 shift_size=(0,0),
                 mlp_ratio=2.,
                 qkv_bias=True,
                 qk_scale=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)
        self.branch_num = 2
        self.get_v = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1,groups=dim) # DW Conv

        self.attns = nn.ModuleList([
                Attention_regular(
                    dim//2, idx = i,
                    split_size=split_size, num_heads=num_heads//2, dim_out=dim//2,
                    qk_scale=qk_scale, position_bias=True)
                for i in range(self.branch_num)])

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer)

    def forward(self, x, x_size, params, attn_mask=None):
        # attn_mask must be provided (indexable as [0] and [1]) whenever shift_size > 0
        h, w = x_size
        self.h,self.w = x_size

        b, l, c = x.shape
        shortcut = x
        x = self.norm1(x)
        qkv = self.qkv(x).reshape(b, -1, 3, c).permute(2, 0, 1, 3) # 3, B, HW, C
        v = qkv[2].transpose(-2,-1).contiguous().view(b, c, h, w)

        # cyclic shift
        if self.shift_size[0] > 0 or self.shift_size[1] > 0:
            qkv = qkv.view(3, b, h, w, c)
            # H-Shift
            qkv_0 = torch.roll(qkv[:,:,:,:,:c//2], shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(2, 3))
            qkv_0 = qkv_0.view(3, b, h*w, c//2)
            # V-Shift
            qkv_1 = torch.roll(qkv[:,:,:,:,c//2:], shifts=(-self.shift_size[1], -self.shift_size[0]), dims=(2, 3))
            qkv_1 = qkv_1.view(3, b, h*w, c//2)

            # H-Rwin
            x1_shift = self.attns[0](qkv_0, h, w, mask=attn_mask[0], rpi=params['rpi_sa_h'], rpe_biases=params['biases_h'])
            # V-Rwin
            x2_shift = self.attns[1](qkv_1, h, w, mask=attn_mask[1], rpi=params['rpi_sa_v'], rpe_biases=params['biases_v'])

            x1 = torch.roll(x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2))
            x2 = torch.roll(x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2))
            # Concat
            attened_x = torch.cat([x1,x2], dim=-1)
        else:
            # H-Rwin
            x1 = self.attns[0](qkv[:,:,:,:c//2], h, w, rpi=params['rpi_sa_h'], rpe_biases=params['biases_h'])
            # V-Rwin
            x2 = self.attns[1](qkv[:,:,:,c//2:], h, w, rpi=params['rpi_sa_v'], rpe_biases=params['biases_v'])
            # Concat
            attened_x = torch.cat([x1,x2], dim=-1)

        attened_x = attened_x.view(b, -1, c).contiguous()

        # Locality Complementary Module
        lcm = self.get_v(v)
        lcm = lcm.permute(0, 2, 3, 1).contiguous().view(b, -1, c)

        attened_x = attened_x + lcm

        attened_x = self.proj(attened_x)

        # FFN
        x = shortcut + attened_x
        x = x + self.mlp(self.norm2(x))
        return x
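The cyclic-shift step in SRWAB can be checked in isolation. The sketch below uses a toy tensor in (b, h, w, c) layout, splits the channels into the two halves used by the H and V branches, rolls them with swapped shift amounts as in the code above, and rolls them back. The attention call between the two rolls is left out, and the sizes are arbitrary assumptions.

```python
import torch

b, h, w, c = 1, 4, 6, 8
shift = (1, 2)                                   # plays the role of self.shift_size
x = torch.arange(b * h * w * c, dtype=torch.float32).view(b, h, w, c)

# Forward shift: first half of the channels uses (s0, s1), second half uses (s1, s0)
x_h = torch.roll(x[..., :c // 2], shifts=(-shift[0], -shift[1]), dims=(1, 2))
x_v = torch.roll(x[..., c // 2:], shifts=(-shift[1], -shift[0]), dims=(1, 2))

# (window attention would run on x_h and x_v here)

# Reverse shift with the opposite signs, then concatenate the halves again
x_h_back = torch.roll(x_h, shifts=(shift[0], shift[1]), dims=(1, 2))
x_v_back = torch.roll(x_v, shifts=(shift[1], shift[0]), dims=(1, 2))
restored = torch.cat([x_h_back, x_v_back], dim=-1)
```

Rolling by a negative amount moves content towards index 0, so position (0, 0) of `x_h` holds what was at (1, 2) in the input; the reverse roll restores the original layout exactly.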

3. The HFERB output is used as the high-frequency feature Xh, and the SRWAB output as the low-frequency feature Xs.

class HFB(nn.Module):
    r""" Hybrid Fusion Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        ffn_expansion_factor (int | float): Expansion factor passed to FeedForward.
        bias (bool): Bias flag passed to Attention (used as the nn.Conv2d bias argument) and FeedForward.
        LayerNorm_type: Normalization variant passed to LayerNorm.
    """
    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
        super(HFB, self).__init__()

        self.norm1 = LayerNorm(dim, LayerNorm_type)
        self.attn = Attention(dim, num_heads, bias)
        self.norm2 = LayerNorm(dim, LayerNorm_type)
        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
        self.dim = dim

    def forward(self, low, high):
        self.h, self.w = low.shape[2:]
        # Cross-attention: `high` supplies Q, the normalized `low` supplies K and V
        x = low + self.attn(self.norm1(low), high)
        x = x + self.ffn(self.norm2(x))
        return x

import torch
import torch.nn as nn
from einops import rearrange  # third-party package: einops

## High-frequency prior query inter attention layer
class Attention(nn.Module):
    def __init__(self, dim, num_heads, bias, train_size=(1, 3, 48, 48), base_size=(int(48 * 1.5), int(48 * 1.5))):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        self.train_size = train_size
        self.base_size = base_size
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.dim = dim
        self.softmax = nn.Softmax(dim=-1)

        self.q = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
        self.q_dwconv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
        self.kv = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
        self.kv_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
        self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)

    def _forward(self, q, kv):
        k,v = kv.chunk(2, dim=1)
        q = rearrange(q, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        k = rearrange(k, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
        v = rearrange(v, 'b (head c) h w -> b head c (h w)', head=self.num_heads)

        q = torch.nn.functional.normalize(q, dim=-1)
        k = torch.nn.functional.normalize(k, dim=-1)

        attn = (q @ k.transpose(-2, -1)) * self.temperature
        attn = self.softmax(attn)
        out = (attn @ v)
        return out

    def forward(self, low, high):
        self.h, self.w = low.shape[2:]

        q = self.q_dwconv(self.q(high))
        kv = self.kv_dwconv(self.kv(low))
        out = self._forward(q, kv)
        out = rearrange(out, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=kv.shape[-2], w=kv.shape[-1])
        out = self.project_out(out)
        return out
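Note that `_forward` above attends over channels rather than spatial positions: after the rearrange, Q and K have shape (b, head, c/head, h*w), so `q @ k.transpose(-2, -1)` is a (c/head) × (c/head) map per head. The sketch below reproduces only that core with plain `reshape` calls (equivalent to the `rearrange` patterns for contiguous tensors). It is a simplified illustration: the 1×1 and depthwise convolutions and the learnable temperature are omitted, and `channel_cross_attention` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def channel_cross_attention(q, k, v, num_heads, temperature=1.0):
    """Channel-wise cross-attention core: q from the high-frequency branch,
    k and v from the global branch. Inputs have shape (b, c, h, w)."""
    b, c, h, w = q.shape
    # (b, c, h, w) -> (b, head, c/head, h*w)
    q = q.reshape(b, num_heads, c // num_heads, h * w)
    k = k.reshape(b, num_heads, c // num_heads, h * w)
    v = v.reshape(b, num_heads, c // num_heads, h * w)

    # L2-normalize along the spatial axis, as in the Attention class
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)

    attn = (q @ k.transpose(-2, -1)) * temperature   # (b, head, c/head, c/head)
    attn = attn.softmax(dim=-1)
    out = attn @ v                                   # (b, head, c/head, h*w)
    return out.reshape(b, c, h, w), attn

high = torch.randn(2, 8, 5, 7)   # stand-in for the HFERB output
low = torch.randn(2, 8, 5, 7)    # stand-in for the SRWAB output
out, attn = channel_cross_attention(high, low, low, num_heads=2)
```

Because the attention map is c/head × c/head, its size does not grow with the image resolution, unlike spatial self-attention.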

 

From: https://www.cnblogs.com/yyhappy/p/17874450.html
