首页 > 其他分享 >【每天一篇深度学习论文】即插即用频域增强通道注意力机制EFCAttention

【每天一篇深度学习论文】即插即用频域增强通道注意力机制EFCAttention

时间:2024-11-29 12:58:27浏览次数:10  
标签:nn EFCAttention torch 频域 FECAM DCT dct 即插即用

目录

论文介绍

题目:

FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting

论文地址:

https://arxiv.org/abs/2212.01209

创新点

  • 频域增强的通道注意机制 (FECAM):提出了一种新的频域增强通道注意力机制,通过离散余弦变换 (DCT) 代替传统的傅里叶变换 (FT) 提取频率信息。相比傅里叶变换,DCT天然避免了因周期性问题带来的吉布斯现象 (Gibbs Phenomenon),从而减少了高频噪声的引入。
  • 模块的通用性: FECAM不仅可以作为独立模型用于时间序列预测,还可以无缝嵌入主流的时间序列模型(如基于Transformer的模型和LSTM),提升这些模型在时间序列预测任务中的表现。这种通用性使得FECAM具有较高的实用价值和灵活性。
  • 理论证明与验证:论文中通过理论分析和实验证明了FECAM在频域建模的有效性,尤其是利用DCT进行频率信息提取,显著提升了模型的预测性能。
  • 实验结果表现优异:在六个真实世界的时间序列数据集上,FECAM在预测准确性方面达到了最新的最佳效果,并且相比其他方法具有更少的参数增量和计算开销。

方法

整体结构

论文中的模型结构主要是利用离散余弦变换(DCT)提取时间序列数据的频域信息,通过频域增强通道注意力机制(FECAM)在不同通道和频率分量之间自适应建模,从而捕捉到更多关键特征,最终结合全连接层或投影层生成增强的预测输出。这一结构既可独立用于预测,也能无缝集成到其他模型中,提升其预测性能。
在这里插入图片描述

  • 输入处理与通道划分:首先将多变量时间序列数据按通道维度拆分成多个子序列,每个子序列包含不同的变量特征。
  • 离散余弦变换 (DCT):对每个通道进行DCT变换,提取出对应的频率分量。这一步骤能够避免传统傅里叶变换带来的周期性问题(即吉布斯现象),从而更高效地捕捉低频信息,同时避免高频噪声的干扰。
  • 通道注意力机制:通过频域的特征图,FECAM可以在不同通道和频率分量之间自适应地建模。使用全连接层对频率增强后的特征图进行加权学习,从而获得每个通道和频率分量的重要性。
    在这里插入图片描述
    重建与输出: 最后,将学习到的频域信息和通道注意力机制的结果重新组合,生成增强后的时间序列预测。这一步通过全连接层或投影层来进行,确保频域和时间域信息能够在预测中得到充分利用。

即插即用模块作用

EFCAttention 作为一个即插即用模块,主要适用于:

  • 时间序列预测场景:用于电力负荷预测、气象数据预测、金融数据分析、交通流量预测等领域,帮助模型处理周期性和趋势性强的数据。
  • 频域信息的增强:FECAttention通过离散余弦变换(DCT)获取数据的频域特征,有效捕捉低频和高频信息,避免传统傅里叶变换带来的高频噪声问题。
  • 增强特征重要性:该模块在不同通道和频率分量之间自适应地建模,提升特征的表达能力,使模型能够更精准地学习到时间序列数据中的重要模式。
  • 提升模型鲁棒性与预测精度:通过引入频域信息,FECAttention可以显著提升各种时序模型(如LSTM、Transformer等)的预测性能,尤其在包含丰富低频信息的数据集上效果尤为显著。

消融实验结果

在这里插入图片描述

  • 该表显示了在不同数据集上,将FECAM模块嵌入到主流的Transformer和RNN模型(如LSTM、Reformer、Informer、Autoformer等)后所带来的性能提升情况。实验结果表明,FECAM模块显著提升了各个模型的预测精度,尤其是在Exchange、ETTm2和Weather等包含丰富低频信息的数据集上效果尤为显著,而在Traffic数据集上提升相对较小,表明FECAM对频率信息较为敏感的数据集更具优势。

即插即用模块代码


import torch.nn as nn
import numpy as np
import torch
#论文:FECAM: Frequency Enhanced Channel Attention Mechanism for Time Series Forecasting
#论文地址:https://arxiv.org/abs/2212.01209

try:
    from torch import irfft
    from torch import rfft
except ImportError:
    def rfft(x, d):
        t = torch.fft.fft(x, dim=(-d))
        r = torch.stack((t.real, t.imag), -1)
        return r


    def irfft(x, d):
        t = torch.fft.ifft(torch.complex(x[:, :, 0], x[:, :, 1]), dim=(-d))
        return t.real


def dct(x, norm=None):
    """
    Discrete Cosine Transform, Type II (a.k.a. the DCT)

    For the meaning of the parameter `norm`, see:
    https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html

    :param x: the input signal
    :param norm: the normalization, None or 'ortho'
    :return: the DCT-II of the signal over the last dimension
    """
    x_shape = x.shape
    N = x_shape[-1]
    x = x.contiguous().view(-1, N)

    v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)

    Vc = rfft(v, 1)

    k = - torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * np.pi / (2 * N)
    W_r = torch.cos(k)
    W_i = torch.sin(k)

    V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i

    if norm == 'ortho':
        V[:, 0] /= np.sqrt(N) * 2
        V[:, 1:] /= np.sqrt(N / 2) * 2

    V = 2 * V.view(*x_shape)

    return V


class dct_channel_block(nn.Module):
    def __init__(self, channel):
        super(dct_channel_block, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(channel, channel * 2, bias=False),
            nn.Dropout(p=0.1),
            nn.ReLU(inplace=True),
            nn.Linear(channel * 2, channel, bias=False),
            nn.Sigmoid()
        )

        self.dct_norm = nn.LayerNorm([96], eps=1e-6) # for lstm on length-wise

    def forward(self, x):
        b, c, l = x.size() # (B,C,L) (32,96,512)
        list = []
        for i in range(c):
            freq = dct(x[:, i, :])
            list.append(freq)

        stack_dct = torch.stack(list, dim=1)

        lr_weight = self.dct_norm(stack_dct)
        lr_weight = self.fc(lr_weight)
        lr_weight = self.dct_norm(lr_weight)

        return x * lr_weight # result


if __name__ == '__main__':
    input = torch.rand(8, 7, 96)
    block = dct_channel_block(96)
    result = block(input)
    print(input.size())    print(result.size())
    

标签:nn,EFCAttention,torch,频域,FECAM,DCT,dct,即插即用
From: https://blog.csdn.net/Magnolia_He/article/details/144117844

相关文章

  • (即插即用模块-Attention部分) 二十一、(2021) Polarized Self-Attention 极化自注意
    文章目录1、PolarizedSelf-Attention2、代码实现paper:PolarizedSelf-Attention:TowardsHigh-qualityPixel-wiseRegressionCode:https://github.com/DeLightCMU/PSA1、PolarizedSelf-Attention像素级回归是细粒度计算机视觉任务中的常见问题。回归问题往......
  • (即插即用模块-Attention部分) 二十、(2021) GAA 门控轴向注意力
    文章目录1、GatedAxial-Attention2、代码实现paper:MedicalTransformer:GatedAxial-AttentionforMedicalImageSegmentationCode:https://github.com/jeya-maria-jose/Medical-Transformer1、GatedAxial-Attention论文首先分析了ViTs在训练小规模数据......
  • Yolo11改进策略:Block改进|VOLO,视觉识别中的视觉展望器|即插即用|附代码+改进方法
    摘要论文介绍VOLO模型概述:本文提出了一种名为VOLO的视觉识别模型,该模型旨在通过创新的注意力机制——前景器(Outlooker)来提高视觉识别的性能。VOLO模型在ImageNet等基准测试上取得了优异的结果。研究背景:传统的视觉Transformer(ViT)模型在全局依赖性建模上表现出色,但在将精......
  • (即插即用模块-Attention部分) 十七、(CVPR 2022) HiLo Attention
    文章目录1、HiLoAttention2、LITv23、代码实现paper:FastVisionTransformerswithHiLoAttentionCode:https://github.com/ziplab/LITv21、HiLoAttention论文中指出多头自注意力(MSA)在高分辨率图像上存在巨大的计算开销。为解决这一问题,本文引入一种Hi......
  • 基于FFT + CNN - BiGRU-Attention 时域、频域特征注意力融合的电能质量扰动识别模型
    往期精彩内容:Python-电能质量扰动信号数据介绍与分类-CSDN博客Python电能质量扰动信号分类(一)基于LSTM模型的一维信号分类-CSDN博客Python电能质量扰动信号分类(二)基于CNN模型的一维信号分类-CSDN博客Python电能质量扰动信号分类(三)基于Transformer的一维信号分类模型-......
  • 我谈频域高斯滤波器
    目录写在前面的内容我谈频域高斯滤波器离谱的指数滤波器第一,截止频率。第二,低通与高通的截止频率。写在前面的内容冈萨雷斯给的频域高斯滤波器。111减去高斯......
  • Auto-Animate:是一款零配置、即插即用的动画工具,可以为您的 Web 应用添加流畅的过渡效
    嗨,大家好,我是小华同学,关注我们获得“最新、最全、最优质”开源项目和高效工作学习方法用户体验成为了检验产品成功与否的关键因素。而动画效果,作为提升用户体验的重要手段,在网页和应用开发中扮演着举足轻重的角色。今天,就让我们一起来探索一款名为Auto-Animate的动画工具,它......
  • FredNormer: 非平稳时间序列预测的频域正则化方法
    时间序列预测是一个具有挑战性的任务,尤其是在处理非平稳数据时。现有的基于正则化的方法虽然在解决分布偏移问题上取得了一定成功但仍存在局限性。这些方法主要在时间域进行操作,可能无法充分捕捉在频域中更明显的动态模式,从而导致次优的结果。FredNormer论文的研究目的主要包......
  • 即插即用篇 | YOLOv10 引入单头视觉Transformer模块 | CVPR 2024
    本改进已同步到YOLO-Magic框架!最近,高效的视觉Transformer在资源受限的设备上以低延迟表现出了出色的性能。传统上,它们在宏观层面上采用4×4的Patch嵌入和四阶段结构,而在微观层面上使用多头配置的复杂注意力机制。本文旨在通过内存高效的方式解决各个设计层面的计算冗余......
  • YOLOv9改进策略【卷积层】| SCConv:即插即用,减少冗余计算并提升特征学习
    一、本文介绍本文记录的是利用SCConv优化YOLOv9的目标检测网络模型。深度神经网络中存在大量冗余,不仅在密集模型参数中,而且在特征图的空间和通道维度中。SCConv模块通过联合减少卷积层中空间和通道的冗余,有效地限制了特征冗余,本文利用SCConv模块改进YOLOv9,提高了模型的性能......