首页 > 其他分享 >无问芯穹DiTFastAttn 中稿NeurIPS,减少76%注意力计算量,加速比最高可达180%

无问芯穹DiTFastAttn 中稿NeurIPS,减少76%注意力计算量,加速比最高可达180%

时间:2024-10-30 11:17:54浏览次数:5  
标签:输出 窗口 DiTFastAttn 无问 生成 180% 注意力 冗余

《DiTFastAttn: Attention Compression for Diffusion Transformer Models》一文由清华大学、无问芯穹和上海交通大学的研究团队联合发表,成功入选NeurIPS 会议。该研究针对图像生成模型中的计算效率问题,提出了一种新的后训练压缩方法DiTFastAttn。这种方法最多可减少 76% 的注意力计算量,在高分辨率生成方面最多可实现 1.8 倍的速度提升。 

图片

论文链接:https://arxiv.org/pdf/2406.08552

NeurIPS 会议的论文录取率为25.8%,略低于去年的26.1%,竞争激烈。无问芯穹另一篇入选论文《Can LLMs Learn by Teaching? A Preliminary Study》也将在后续与大家见面,敬请期待。

图片

引言

最近,扩散变换器(DiT)在图像和视频生成中越来越受欢迎。然而,DiTs 面临的一个主要挑战是其巨大的计算需求,在生成高分辨率内容时尤为明显。一方面,传统的转换器架构具有自注意力机制,对输入标记长度 L 的复杂度为 O(L2)。如图 1 所示,随着图像分辨率的提高,注意力计算成为推理过程中的主要计算瓶颈。具体来说,如果将 2K × 2K 的图像标记化为 16K 标记,即使在 Nvidia A100 等高端 GPU 上,注意力计算也需要数秒时间。另一方面,由于采用了多个去噪步骤和无分类器引导(CFG)技术,扩散推理过程需要大量的神经网络推理。以往加速注意机制的努力,如局部注意、Swin Transformer和群组查询注意(GQA),主要集中在设计注意机制或网络架构上。这些方法虽然能有效降低计算成本,但需要大量的再训练成本。由于训练 DiT 需要大量的数据和计算,因此有必要采用训练后压缩方法。在这项工作中,我们确定了 DiT 推理中注意力计算的三种冗余类型,并提出了一种训练后模型压缩方法 DiTFastAttn 来解决这些冗余问题:

图片

图 1:左图:在 PixArt-Sigma上应用 DiTFastAttn 时,生成不同分辨率图像的效率优势。Y 轴显示了以原始模型的 #FLOPs 为标准的 #FLOPs 分数。右图 在 1024×1024 PixArt-Sigma 上应用 DiTFastAttn 的定性结果。

(1) 空间维度的冗余。许多注意力头主要捕捉局部空间信息,对远处标记的注意力值接近零。为了减少冗余,我们选择在某些层使用窗口注意力而不是完全注意力。然而,直接丢弃窗口外的所有注意力计算会导致性能显著下降。为了在无需训练的情况下保持性能,我们建议在一个步骤中缓存全注意力和窗口注意力输出之间的残差,并在随后的几个步骤中重复使用该残差。我们将这种技术称为带残差共享的窗口注意力(WA-RS)。

(2) 相邻步骤之间的注意力输出相似。同一注意力头在相邻步骤中的注意力输出可能高度相似。我们提出了跨时间步注意力共享(AST)技术,利用这种步进相似性来加速注意力计算。

(3) 有条件推理与无条件推理在注意力输出上的相似性。我们观察到,在 CFG 中,条件推理和非条件推理的注意力输出在某些头部和时间步上表现出显著的相似性(SSIM ≥ 0.95)。因此,我们提出了跨 CFG 的注意力共享(ASC)技术,以跳过条件生成过程中的冗余计算。我们使用多个 DiT 模型进行了广泛的实验,以评估DiTFastAttn,包括用于图像生成的 DiT-XL 和 PixArt-Sigma,以及用于视频生成的 Open-Sora。我们的研究结果表明,DiTFastAttn 能持续降低计算成本。值得注意的是,分辨率越高,计算和延迟的节省就越大。例如,在使用 PixArt-Sigma 生成 2048×2048 图像时,DiTFastAttn 可将注意力计算量减少 20% 至 76%,加速比最高可达 1.8 倍。

图片

相关工作

2.1 扩散模型

扩散模型因其优于 GAN 的生成性能而备受关注。早期的扩散模型是基于 U-Net 架构实现的。为了实现更好的可扩展性,DiT利用trnasformer架构取代了 U-Net。扩散变换器被应用于图像和视频生成领域。PixArt-Sigma展示了扩散变换器生成高达 4K 高分辨率图像的能力。OpenSora展示了扩散变换器生成视频的能力。

2.2 视觉transformer 压缩

注意力的计算开销引起了广泛关注。FlashAttention将输入标记划分成更小的块,以尽量减少冗余内存访问和优化延迟。一些研究强调了注意力计算的二次复杂性,并通过在网络的不同阶段过滤或合并标记来实现标记剪枝,从而提高效率。动态 ViT采用预测网络动态过滤标记。Adaptive Sparse ViT通过同时考虑注意力值和特征的 L2 准则来过滤 token。Lu 等人(2023 年)利用分段标签训练网络,以指导具有相似内容的区域中标记的合并操作。Huang 等人(2023 年)在对标记进行下采样后进行注意力计算,然后再进行上采样以恢复空间分辨率。Wu 等人(2023 年)证明,较深的层更适合过滤标记,而较浅的层更适合合并标记。

2.3 局部注意力

各种研究都对局部注意力模式的利用进行了深入探讨,在这种模式下,每个标记都会在一个固定的窗口大小内关注一组相邻的标记,目的是减轻处理长序列时的计算负担。Beltagy 等人最初在 Longformer 中引入了局部窗口注意的概念,提出了一种随序列长度线性扩展的注意机制。Bigbird扩展了这一概念,将窗口注意、随机注意和全局注意机制结合在一起,在降低计算成本的同时保留了长程依赖性。在计算机视觉领域,Swin Transformer采用了类似的方法,将注意力计算限制在非重叠的局部窗口,利用不同层之间的移动窗口来有效捕捉全局上下文。Twins Transformer和 FasterViT采用基于窗口的注意力来提高计算效率,利用全局子采样注意力和分层注意力等不同的模块设计来有效利用全局上下文。在我们的工作中,我们采用固定大小的窗口注意力来加速预训练的扩散变换器模型,并引入了一种名为 “残留共享窗口注意力 ”的新技术,以保留图像标记的长程依赖性。

2.4 注意力共享

GQA将查询头分为 G 组。每个查询保留自己的参数,而每组共享一个键和值,从而减少内存使用并提高效率。PSVIT表明,ViT 中不同层之间的注意力图具有显著的相似性,并建议跨层共享注意力图以减少冗余计算。Deep- cache证明了 U-Net 框架扩散模型中的高级特征在不同时间步之间具有相似性。Deepcache 建议重复使用 U-Net 的高级特征,跳过中间层的计算,以加速去噪过程。TGATE表明,文本条件扩散模型的交叉注意输出在经过几个去噪时间步后会收敛到一个固定点。一旦收敛,TGATE 就会缓存该输出,并在剩余的去噪步骤中保持固定,以降低计算成本。在 DiTFastAttn 中,我们展示了按 CFG 和按步骤的注意力输出的相似性。我们还考虑了不同层在不同步骤中的相似性差异,以共享 CFG 和分步注意力输出。

图片

方法

3.1 概述

在本节中,我们将展示带有变换器的扩散模型推理过程中的冗余。在去噪过程中,我们发现了三种类型的冗余,如图 2 所示:(1) 空间维度上的冗余。(2) 注意力输出中相邻步骤之间的相似性。(3) 注意力输出中条件推理和非条件推理之间的相似性。为了解决这些冗余问题,我们提出了三种压缩技术,如图 2 所示:(1) 在第 3.2 节中,我们引入了带有残差共享的窗口注意力,以减少空间冗余。(2) 在第 3.3 节中,我们引入跨时间步的注意力共享来利用步进相似性,从而提高模型效率。(3) 在第 3.4 节中,我们引入了跨 CFG 的注意力共享,通过利用有条件生成和无条件生成之间的相似性来减少冗余。在第 3.5 节中,我们引入了一种简单的贪婪方法来决定压缩方案,即为每层和每步选择合适的压缩技术。

图片

图 2:冗余类型和相应的压缩技术。左:空间维度的冗余、去噪步骤和 CFG。右图 DiTFastAttn 为减少每种类型的冗余而采用的技术。DiTFastAttn 采用窗口注意力来减少注意力冗余,同时利用残差保持性能。此外,按步骤和按 CFG 共享注意力输出,以减少冗余。

3.2 带有残差共享的窗口注意力(WA-RS)

我们可以在预训练 DiTs 的许多变换层中观察到注意力的空间位置性。如图 3(a)所示,注意力值集中在注意力矩阵对角线区域的一个窗口内。因此,在推理过程中,用固定大小的窗口注意力取代某些层的完全注意力,可以保留注意力矩阵中的大部分值。通过只计算指定窗口内的注意力值,可以大大降低注意力的计算成本。不过,有些标记仍然会关注一小部分空间距离较远的标记。放弃这些注意力会对模型性能产生负面影响。仅使用窗口注意力来缓解这一问题,就必须使用较大的窗口尺寸来捕捉这些依赖关系。因此,这种方法对计算成本的降低微乎其微,从而阻碍了加速工作。

缓存和重复使用窗口注意的残余信息。为了解决上述问题,我们研究了使用窗口注意力所造成的信息损失。如图 3(a)所示,与窗口注意的输出不同,完全注意和窗口注意的输出之间的残差在各步中的变化很小。这一观察结果促使我们在一个步骤中缓存窗口注意力和完全注意力的残差,并在后续步骤中重复使用。图 3(b) 展示了 WA-RS 的计算过程:在每一步,对于每个窗口注意力层,我们计算窗口注意力,并将上一步缓存的残差添加到输出中。我们将共享残差值 Rr 的步骤集合记为 K,步骤 r 的全注意力记为 Or,步骤 k 的窗口注意力记为 Wk。对于 r = min (K) 集合中的第一步,WA-RS 的计算过程如下:

图片

图片

3.3 跨时间步的注意力共享(AST)

扩散模型中去噪过程的顺序性是推理速度的主要瓶颈。在此,我们比较了去噪过程中不同步骤的注意力输出。我们发现,在某些层中,某些步骤的注意力输出与相邻步骤的注意力输出具有显著的相似性。图 4(a) 显示了不同步骤的注意力输出之间的余弦相似性。我们可以得出两个主要结论:(1) 注意力输出之间存在明显的时间相似性;(2) 这种相似性在不同步骤和不同层之间存在差异。

图片

图 3:残差共享的窗口注意力。(a) 左图:显示窗口模式的注意力图例。右图 上一步和当前一步的窗口注意力输出之间的 MSE 与上一步和当前一步的窗口注意力和完全注意力输出残差之间的 MSE 比较。输出残差在各步中的变化非常小。(b) 计算带有残差共享的窗口注意力。重新计算发生重大变化的窗口注意力。变化极小的残差被缓存并在后续步骤中重复使用。

为了利用这种相似性降低计算成本,我们提出了 AST 技术。具体来说,对于一组注意力输出相似的步骤,我们缓存最早步骤的注意力输出 O 并重复使用,从而跳过后续步骤的计算。(b) 有条件生成和无条件生成的注意力输出相似性 (a) 不同层中不同时间步骤的注意力输出相似性。

图片

图 4:DiT 中不同步长和 CFG 维度的注意力输出相似性。(a) 不同层中不同步长维度的注意力输出相似性。(b) 不同层中不同步长的有条件和无条件注意力输出的相似性

3.4 跨 CFG 的注意力共享(ASC)

无分类器引导(CFG)被广泛用于条件生成(Ho & Salimans, 2022; Ramesh et al.) 在条件生成推理过程的每一步中,CFG 都要执行两次神经网络推理:一次有条件输入,一次无条件输入。与无条件生成相比,计算成本增加了一倍。如图 4(b)所示,在许多层和步骤中,有条件和无条件神经网络评估的注意力输出相似度很高。

基于这一观点,我们提出了 ASC 技术,即在无条件神经网络评估中重复使用条件神经网络评估的注意力输出。

3.5 决定压缩计划的方法

上述技术,包括 WA-RS、AST 和 ASC,可以在保持性能的同时有效地降低计算成本。如图 3 和图 4 所示,不同层在不同的时间步长有不同的冗余。因此,正确决定压缩方案,即在每一步对每一层应用哪种技术至关重要。我们开发了一种简单的贪婪方法,从策略列表 S =[AST、WA-RS + ASC、WA-RS、ASC]中为每一步和每一层选择合适的策略(技术组合)。如 Alg. 1 所示,我们一步一步、一层一层地确定策略。对于每个步骤和转换层,我们分别采用四种压缩策略,并计算当前步骤有压缩和无压缩的模型输出之间的损耗 L(O,O′)。然后,我们选择计算减少率最高且损失低于阈值的策略。如果四种策略都没有达到阈值,我们就不在该步骤中对该层进行压缩。

图片

图片

实验

4.1 设置

我们在三种常用的扩散变换器上对 DiTFastAttn 进行了评估:DiT (Peebles & Xie, 2023) 和 Pixart-Sigma (Chen et al., 2024) 用于图像生成任务,Open-Sora (Open-Sora, 2024) 用于视频生成任务。为了证明与快速采样方法的兼容性,我们在 DiT 和 Pixart-Sigma 的 50 步 DPM-Solver 和 Open-Sora 的 100 步 IDDPM(Nichol & Dhariwal,2021 年)的基础上构建了我们的方法。我们使用 ImageNet 作为计算质量指标的评估数据集,并使用 MS-COCO 2017(Lin 等人,2014 年)标题作为 Pixart-Sigma 模型生成图像的文本提示。对于 ImageNet,我们生成 5k 幅图像来评估生成质量。根据之前的研究,我们采用 FID(Heusel 等人,2017 年)、IS(Salimans 等人,2016 年)和 CLIP 分数(Hessel 等人,2021 年)作为评估指标。我们在单个 Nvidia A100 GPU 上测量每个样本的延迟。我们使用平均相对绝对误差 L(O,O′),并以 0.025 的间隔试验不同的阈值 δ。我们将这些阈值设置分别记为 D1(δ=0.025)、D2(δ=0.05)、...、D6(δ=0.15)。我们将 WA-RS 的窗口大小设置为标记大小的 1/8。

4.2 图像生成结果

评估指标和 #FLOPs 的结果。DiTFastAttn 应用于预先训练的 DiT-2-XL- 512、PixArt-Sigma-1024 和 PixArt-Sigma-2K 模型。图 5 显示了这些模型在 ImageNet-5k 验证数据集上的表现。对于 DiT-2-XL-512 和 PixArt-Sigma-1024 模型,配置 D1、D2 和 D3 在 IS 和 FID 指标方面几乎与原始模型的性能相当。配置 D4、D5 和 D6 的 IS 和 CLIP 分数略有下降,作为实现更高压缩率的权衡。对三种模型的压缩效果和评估指标进行比较后发现,随着图像分辨率的提高,DiTFastAttn 不仅实现了更大的压缩,而且更好地保持了模型的生成性能。

图片

图 5:DiTFastAttn 在不同图像分辨率、不同压缩率下的图像生成性能。

表 1:扩散变换器中 DitFastAttn 的 FLOPs 分数和延迟分数与原始注意力的比较。延迟在 Nvidia A100 GPU 上进行评估。

图片

DiTFastAttn 生成结果的可视化。图 6 显示了 DiTFastAttn 的图像生成样本。对于 DiT-2-XL-512 和 PixArt-Sigma-1024 模型,D1、D2 和 D3 显示的视觉生成质量与原始模型相似。虽然 D4、D5 和 D6 实现了更大的压缩,生成图像的细节也略有不同,但它们仍能生成可用的高质量图像。PixArt-Sigma-2K 模型在 D4 之前的图像质量与原始模型相当,而配置 D5 和 D6 则继续生成高质量的图像。

4.3 视频生成结果

我们在 OpenSora 上应用 DitFastAttn 生成视频,我们分别将阈值设置为0.01到0.05。结果如图 10 所示,更多分析结果见附录。

图片

图 6:不同压缩率下不同图像分辨率的图像生成样本。

图片

图 7:使用 OpenSora V1.1 在 240p 分辨率下生成 16 帧视频的比较。

4.4 #FLOPs 减少和加速

DiTFastAttn 对不同序列长度的压缩结果。我们基于 FlashAttention-2 实现了 DiTFastAttn(Dao,2023 年)。表 1 显示了 DiTFastAttn 在扩散变换器中与原始注意力机制相比的 FLOPs 分数和延迟分数。

图片

图 8:不同压缩率下生成不同分辨率图像的延迟

图片

图 9:DiT-XL-2-512 的消融研究。方法影响(左)、时间步变异(中)和残差分担(右)检查。“WA "表示无残余份额 (RS) 的窗口关注。

DiTFastAttn 的总体延迟。图 8 显示了应用 DiTFastAttn 时,随着计算量减少,图像生成和注意力的延迟。在 DiT-XL-512 模型上,在低压缩比设置下,图像生成和注意力计算的总体延迟略有增加。DiTFastAttn 在推理过程中没有带来任何开销。延迟增加的原因是我们的内核没有很好地实现,导致性能略低于 FlashAttention-2。在高压缩比设置下,延迟比原始数据有所减少,其中 D6 的延迟最低。对于 PixArt-Sigma,延迟随着 FLOPs 的减少而继续降低。随着分辨率的提高,DiTFastAttn 在减少整体注意力和图像生成的延迟方面取得了更好的性能。值得注意的是,进一步优化我们的内核实现可以更好地降低延迟。

4.5 消融研究

DiTFastAttn 优于单一方法。如图 9 左侧所示,在计算预算相同的情况下,DiTFastAttn 与单项技术相比保持了更高的质量指标。在单一技术中,AST 的生成质量最好。然而,在 2.2 FLOPs 之后,使用 AST 进一步压缩会显著降低输出,导致搜索算法终止。DiTFastAttn 支持进一步压缩,同时保持更好的质量。

步长越大,DiTFastAttn 的性能越好。如图 9 中间所示,我们比较了 DiTFastAttn 在不同步长下的性能。很明显,随着步长的增加,DiTFastAttn 可以在保证质量的前提下压缩更多的计算量。

残差缓存技术对保持性能至关重要如图 9 右侧所示,在压缩比相同的情况下,带有残差共享的窗口注意力比窗口注意力能保持更好的生成性能。如果没有残差,窗口注意力会导致性能大幅下降。

图片

结论

本文介绍了一种新颖的训练后压缩方法 DiTFastAttention,以加速扩散模型。我们确定了三种类型的冗余:(1) 空间维度上的冗余。(2) 注意力输出中相邻步骤之间的相似性。(3) 注意力输出中条件推理和非条件推理之间的相似性。我们还提出了相应的压缩技术:(1) 带有剩余共享的窗口注意力,(2) 跨时间步的注意力共享,(3)跨 CFG 的注意力共享。实验表明,DiTFastAttention 能显著降低注意力成本,加快计算速度。

图片

未来工作

局限性和未来工作。首先,我们的方法是一种训练后压缩技术,因此无法利用训练来避免性能下降。其次,我们的贪婪压缩方法选择简单,但可能找不到最佳方法。第三,我们的方法只能降低注意力模块的成本。在未来的工作中,我们计划探索训练感知压缩方法。我们还打算将我们的方法扩展到其他模块。此外,进一步的内核级优化可能会为我们提出的压缩技术带来更快的速度。

标签:输出,窗口,DiTFastAttn,无问,生成,180%,注意力,冗余
From: https://blog.csdn.net/2401_87329534/article/details/143360104

相关文章

  • LeetCode|3180. 执行操作可获得的最大总奖励 I(day23)
    作者:MJ昊博客:掘金、CSDN等公众号:程序猿的编程之路今天是昊的算法之路第23天,今天分享的是LeetCode第3180题执行操作可获得的最大总奖励I的解题思路。这是一道中等难度的题目,要求我们在给定的奖励值数组中,通过某些操作尽可能获取最大总奖励。题目描述简要回顾题目要......
  • CF1800E2. Unforgivable Curse (hard version) 题解 并查集
    题目链接:https://codeforces.com/contest/1800/problem/E2视频讲解:https://www.bilibili.com/video/BV1tZ1FYPELp?p=2把下标\(i\)对应到图中编号为\(i\)的节点。节点\(i\)和\(i+k\)之间连一条边,节点\(i\)和\(i+k+1\)之间也连一条边。同一个连通块里的节点对应的字......
  • springboot考研交流平台-计算机毕业设计源码91806
    摘要基于SpringBoot的考研交流平台,精心打造了一个集考研资讯管理、历年真题管理和考研政策管理于一体的全方位服务平台。该平台凭借SpringBoot框架的卓越性能,确保了系统的稳定运行和高效响应,为考研学子提供了实时更新的考研资讯、详尽的历年真题资源和准确的考研政策解读。......
  • 大数据-180 Elasticsearch - 原理剖析 索引写入与近实时搜索
    点一下关注吧!!!非常感谢!!持续更新!!!目前已经更新到了:Hadoop(已更完)HDFS(已更完)MapReduce(已更完)Hive(已更完)Flume(已更完)Sqoop(已更完)Zookeeper(已更完)HBase(已更完)Redis(已更完)Kafka(已更完)Spark(已更完)Flink(已更完)ClickHouse(已更完)Kudu(已更完)Druid(已更完)Kylin(已更完)Elasticsearch(正在更......
  • XL6019芯龙180KHz 60V 5A开关电流升压/升降压型DC-DC转换器
    描述XL6019是一款专为升压、升降压设计的单片集成电路,可工作在DC5V到40V输入电压范围,低纹波,内置功率MOS。XL6019内置固定频率振荡器与频率补偿电路,简化了电路设计。PWM控制环路可以调节占空比从0~90%之间线性变化。内置过电流保护功能与EN脚逻辑电平关......
  • 基于51单片机的大气压强检测仪(BMP180)(程序+Proteus仿真)
    编号:60基于51单片机的大气压强检测仪(BMP180)功能描述:   本设计由51单片机+BMP180大气压强检测模块+1602液晶显示模块组成。1、主控制器是51单片机2、利用BMP180传感器读取大气压强、温度、海拔高度等信息3、1602液晶显示大气压强、温度、海拔高度等信息视频演示链......
  • 《纪元1800》遭遇dll丢失问题无法启动:msvcr71.dll丢失详解与定制化解决方案
    《纪元1800》是一款非常受欢迎的城市建设和经济策略游戏,但有时玩家可能会遇到msvcr71.dll丢失的问题,导致游戏无法启动。msvcr71.dll是MicrosoftVisualC++运行库的一部分,负责支持许多应用程序的运行。以下是对msvcr71.dll丢失问题的详细解释及定制化解决方案。问题原......
  • 180+ 优质YouTube频道推荐:数据科学、机器学习、人工智能等领域学习资源汇总
    yt-channels-DS-AI-ML-CS180+优质YouTube频道推荐:数据科学、机器学习、人工智能等领域学习资源汇总在这个信息爆炸的时代,YouTube已经成为许多人学习新知识的重要平台。特别是在数据科学、机器学习、人工智能等热门技术领域,有大量优质的教学内容。本文整理了180多个高质量的Y......
  • 【刷题笔记】[ABC180F] Unbranched
    【刷题笔记】Unbranched题意求\(N\)个点,\(M\)条边且满足以下条件的图的数量:1.图中无自环;2.每个点度数最多为2;3.连通块大小的最大值恰好为L。答案对\(10^9+7\)取模。\(1\leM,L\leN,2\leN\le300\)思路注意构造出来的图,不一定是联通的,所以容易联想到将一个联通分量......
  • GB 18030及生僻字治理
     名词解释:编码字符集codedcharacterset一组无歧义的规则,用以建立一个字符集和该字符集中的字符及其编码表示之间的对应关系,通常也指按照这种规则确定的文字的有序集合。示例:1.GB18030是我国制订的以汉字为主并包含多种我国少数民族文字(例如藏、蒙古、傣、彝、朝鲜、维......