首页 > 其他分享 >FlashAttention全解

FlashAttention全解

时间:2024-07-24 21:29:55浏览次数:14  
标签:right mathbf FlashAttention ell 计算 全解 left

目录

LLM大模型训练加速利器FlashAttention详解

  • Attention层是扩展到更长序列的主要瓶颈,因为它的运行时间和内存占用是序列长度的二次方。使用近似计算的Attention方法,可以通过减少FLOP计算次数、甚至于牺牲模型质量来降低计算复杂性,但通常无法实现大比例的加速。

  • FlashAttention没有进行近似计算,所以也没有精度损失。然而,FlashAttention的实际速度仍然和理论上的运算速度差距较大,仅达到理论最大 FLOPs/s 的 25-40%。效率低下的原因主要是不同线程块和warp之间的工作分区不理想,导致低占用率或不必要的共享内存读/写。

    为此,2023年7月,论文作者进一步提出了FlashAttention-2,实现了Attention计算速度的大幅度提升。

一、FlashAttention

1.1 硬件基础

  • 我们常说的A100 80G,80G指的是GPU中的HBM存储,其上还有更为快速的SRAM,其大小约为20MGB。

  • 一次注意力计算分为多步计算过程,第一步 \(QK^T\),第二步\(softmax\),第三步\(\cdot V\),每一步计算产生的中间结果都需要存储到HBM中,需要时在进行读取,其复杂度为\(O(N^2)\)。

  • SRAM是非常有限的,无法把所有的数都加载的SRAM里,序列长度N(Token的数量)通常是以k来计算,4k和8k是比较常见的,某些应用(code)甚至希望能到64k和128k。因此\(N^2\)会增长的非常快。

  • 每次计算产生的中间结果都需要\(O(N^2)\)的开销将中间结果移动到HBM中,通信代价 > 计算代价

如何降低通信代价?让更多的操作发生在SRAM上?计算的中间结果,如何才能不传输到HBM?全部在SRAM中进行?

1.2 FlashAttention 核心思想

如何改造计算流程,让中间结果存储到HBM的过程不要发生?

Flash Attention 将计算模块化,将QKV分为若干个模块进行计算,在计算过程中不存储$ N \times N $的矩阵

image-20240715194054953

最终只有输出\(O_1\)涉及存储到HBM中。

1.3 计算前提

仅仅是将矩阵进行分块计算这样就可以了吗?No

  1. 数值稳定性:在计算注意力的过程中,涉及到\(softmax\)操作,\(softmax\)包含指数函数,所以为了避免数值溢出问题,可以将每个元素都减去最大值,对于一个向量来说,我们给每一个数减去相同的任一常量,其\(softmax\)​是不变的。

    \(m(x):=\max _{i} \quad x_{i}, \quad f(x):=\left[\begin{array}{lll} e^{x_{1}-m(x)} & \ldots & e^{x_{B}-m(x)} \end{array}\right], \quad \ell(x):=\sum_{i} f(x)_{i}, \quad \operatorname{softmax}(x):=\frac{f(x)}{\ell(x)}\)

  2. 分块计算softmax:同时对于\(softmax\)操作是按行计算的,如果对齐进行分块,那么每行的最大值如果只考虑分块后的最大值就产生了偏差,需要维护每行的最大值用于分块以及上述数值的稳定性计算。

    我们考虑将一行分为两部分的情况,即原本的一行数据 \(x \in \mathbb{R}^{2 B}=\left[x^{(1)}, x^{(2)}\right]\)

    考虑整行正确的 \(f(x^{(1)}) = e^{x^{(1)}-max(x)}\)

    有偏差的考虑分块的 \(f'(x^{(1)}) = e^{x^{(1)}-max(x^{(1)})}\)

    如何进行纠正?\(f(x^{(1)}) =f'(x^{(1)}) \times e^{max(f'(x^{(1)})) - max(x)}\)

    所以说,分块计算之后,要进行纠正,需要维护整行的\(max\)值。

1.4 FlashAttention 算法

image-20240715201049158

简要描述:外层循环遍历\(K^T\),内层循环遍历\(Q\)

\(Q\)的第一块(按行拆分)和\(K^T\)的第一块(按列拆分)计算得到中间结果\(S\)中\((0, 0)\)的结果

\(Q\)的第二块(按行拆分)和\(K^T\)的第一块(按列拆分)计算得到中间结果\(S\)中\((1, 0)\)的结果

\(Q\)的第三块(按行拆分)和\(K^T\)的第一块(按列拆分)计算得到中间结果\(S\)中\((2, 0)\)的结果

\(Q\)的第三块(按行拆分)和\(K^T\)的第一块(按列拆分)计算得到中间结果\(S\)中\((3, 0)\)的结果

这里应该是可以并行计算的,因为这是中间结果的第一列,不涉及softmax操作,暂时还未涉及到分块计算softmax纠正偏差的问题。

当外层循环计算到\(K^T\)的第二块(按列拆分)时,就会开始填充中间结果\(S\)的第二列,这时有了第一列和第二列的结果,就需要进行\(softmax\)纠正偏差。

至此,后面计算到中间结果的每列时,每当有新的中间结果加入时,都需要对该中间结果所在的行进行纠正错误。

纠正方法:\(\times e^{max(x_{sec}) - max(x)}\)

image-20240715202212075

每次循环计算时,将\(K_j, V_j\)加载到SRAM,占据SRAM的50%的存储,将\(Q_i,O_i\)加载进SRAM,占据另一半的显存。\(l_i, m_i\)比较小,按作者的说法可以放进寄存器。

image-20240715202701498

  1. 初始化

    按列拆分,每个块的大小为 \(B_c = \left\lceil\frac{M}{4 d}\right\rceil\),将\(K, V\)拆分为\(T_{c}=\left\lceil\frac{N}{B_{c}}\right\rceil\)个列块,形状 \(B_c \times d\)。

    按行拆分,每个块的大小为 \(B_{r}=\min \left(\left\lceil\frac{M}{4 d}\right\rceil, d\right)\),将\(Q\)拆分为 \(T_{r}=\left\lceil\frac{N}{B_{r}}\right\rceil\) 个行块,形状 \(B_r \times d\)。

    将\(O\)按行拆分为 \(T_r\) 个块,形状 \(B_r \times d\)。

    将\(l,m\)均拆分为 \(T_r\) 个块,大小为 \(B_r\) 的一维向量。

  2. 将 \(Q_i, O_i, l_i, m_i\) 从HBM移动到SRAM

  3. 计算 \(S_{ij} = Q_iK_j^T\)

  4. 修正每个行块的 \(\tilde{m}_{i j}\),修正每个块的中间计算结果 \(\tilde{\mathbf{P}}_{i j}=\exp \left(\mathbf{S}_{i j}-\tilde{m}_{i j}\right)\),修正每行的累积和 \(\tilde{\ell}_{i j}\)

  5. 更新每块

二、FlashAttention-2

  • 去年7月,FlashAttention-2发布,相比第一代实现了2倍的速度提升,比PyTorch上的标准注意力操作快5~9倍。
  • 在H100上仅实现了理论最大FLOPS 35%的利用率。
  • FlashAttention(以及FlashAttention-2)通过减少内存读写次数,开创了一种在GPU上加速注意力机制的方法,现在大多数库都使用它来加速Transformer的训练和推理。这使得大语言模型的上下文长度在过去两年中大幅增加,从2-4K(如GPT-3、OPT)扩展到128K(如GPT-4),甚至达到1M(如Llama 3、Gemini 1.5 Pro)。

2.1 硬件特性

GPU存在大量的线程(被称为kernel)用于执行一个操作。线程被组织为线程块,线程块被调度在 streaming multiprocessors (SMs) 上运行。

在每个线程块内部,线程被分组为 warps (包含32个线程的线程组)。

warp 内的线程可以通过 fast shuffle instructions 进行通信或协同执行矩阵乘法。

线程块内的warps 可以通过对共享内存读写进行通信。

每个线程 (kernel) 从HBM 中加载输入到寄存器和SRAM中计算,然后将输出写回 HBM。

2.2 标准的注意力实现

给定输入序列 \(\mathbf{Q}, \mathbf{K}, \mathbf{V} \in \mathbb{R}^{N \times d}\),计算注意力输出 \(\mathbf{O} \in \mathbb{R}^{N \times d}\).

中间计算过程 \(\mathbf{S}=\mathbf{Q} \mathbf{K}^{\top} \in \mathbb{R}^{N \times N}, \quad \mathbf{P}=\operatorname{softmax}(\mathbf{S}) \in \mathbb{R}^{N \times N}, \quad \mathbf{O}=\mathbf{P V} \in \mathbb{R}^{N \times d}\)

softmax是基于每行进行操作的。

对于多头注意(MHA),这个相同的计算是在多个头部上并行执行的,并在批处理维度(批处理中输入序列的数量)上并行执行的。

注意力的反向传播过程如下。设 \(\mathrm{dO} \in \mathbb{R}^{N \times d}\)​ 是O相对于某些损失函数的梯度。然后根据链式规则(即反向传播):

\(\mathbf{d V}=\mathbf{P}^{\top} \mathbf{d} \mathbf{O} \in \mathbb{R}^{N \times d}\)

\(\mathbf{d} \mathbf{P}=\mathbf{d} \mathbf{O} \mathbf{V}^{\top} \in \mathbb{R}^{N \times N}\)

\(\mathbf{d S}=\operatorname{dsoftmax}(\mathbf{d} \mathbf{P}) \in \mathbb{R}^{N \times N}\)

\(\mathrm{dQ}=\mathrm{dSK} \in \mathbb{R}^{N \times d}\)

\(\mathbf{d K}=\mathbf{Q d} \mathbf{S}^{\top} \in \mathbb{R}^{N \times d}\)

dsoftmax是逐行应用的softmax的梯度。

对于向量\(s, p\),如何从输出梯度 \(d_p\)计算输入梯度 \(

标签:right,mathbf,FlashAttention,ell,计算,全解,left
From: https://www.cnblogs.com/mudou/p/18321760

相关文章

  • WebKit的文本装饰艺术:CSS Text Decoration全解析
    WebKit的文本装饰艺术:CSSTextDecoration全解析CSS文本装饰(TextDecoration)是一组用于美化和增强网页文本表现的属性,它们可以为文本添加下划线、上划线、线删除和强调标记等效果。WebKit作为许多现代浏览器的渲染引擎,对CSS文本装饰的支持非常全面。本文将深入探讨WebKit对......
  • IPython的Bash之舞:%%bash命令全解析
    IPython的Bash之舞:%%bash命令全解析IPython的%%bash魔术命令为JupyterNotebook用户提供了一种在单元格中直接执行Bash脚本的能力。这个特性特别适用于需要在Notebook中运行系统命令或Bash特定功能的场景。本文将详细介绍如何在IPython中使用%%bash命令,并提供实际的代码示......
  • Java内存模型全解析:解决共享变量可见性与指令重排难题
    本期说一下Java内存模型(JavaMemoryModel,JMM)及共享变量可见性问题。“以下内容出自本人整理的面试秘籍。点击此处,无套路免费获取面试秘籍JMM是什么?答:Java内存模型(JavaMemoryModel,JMM)抽象了线程和主内存之间的关系就比如说线程之间的共享变量必须存储在主内存......
  • 低代码开发知识全解:提升开发效率的利器
    一、低代码开发的基本概念低代码开发是一种利用可视化工具和简化的编程接口来创建应用程序的方法。通过拖放组件、配置参数和使用预设模板,开发者可以在无需编写大量代码的情况下完成应用程序的设计和实现。这种方法不仅提高了开发效率,还使得非技术人员能够参与应用程序的开发......
  • ava 集合框架全解析:Collection vs Collections,Comparable vs Comparator,HashSet 工作
    Java中的集合框架是开发过程中不可或缺的一部分,但初学者常常会混淆其中的术语,特别是Collection和Collections。这篇博客将详细介绍它们之间的区别,并进一步探讨Comparable和Comparator、HashSet的工作原理,以及HashMap和Hashtable的区别。Collection和Collecti......
  • 快手矩阵系统全解析:功能、优势与特点一网打尽
    在数字化时代,短视频已成为连接创作者与观众的重要媒介。快手矩阵系统以其独特的功能和优势,为短视频的创作、管理和发布提供了一站式解决方案,极大地提升了内容运营的效率和效果。功能概览智能创作:AI技术的应用使得快手矩阵系统能够自动生成与视频主题高度契合的文案,极大提升了......
  • JngLoad.dll 缺失报错问题全解:从原理到实践的修复流程
    jngload.dll是一个动态链接库(DynamicLinkLibrary)文件,通常与处理JPEG2000图像格式的软件相关联。JPEG2000是一种高效率的图像压缩标准,jngload.dll可能包含了处理这种图像格式所需的功能和算法。这个DLL文件常出现在处理图形和图像的软件中,比如某些图像编辑软件或游戏,用来读......
  • PD还是QC?快充协议全解析
    什么是快充协议快充协议是一种通过提高充电效率来缩短设备充电时间的电池充电技术。它是通过在充电器和设备之间建立一种沟通机制,充电器能够根据设备的需求和状态,调整输出的电压和电流。这种沟通机制由快充协议定义,它决定了设备和充电器如何互相识别和交流,以及如何调整电力输出。......
  • 力扣-动态规划全解
    目录动态规划斐波那契数列-EASY爬楼梯-EASY使用最小花费爬楼梯-EASY不同路径-Middle不同路径II-Middle不同路径III-HARD整数拆分-MID*不同的二叉搜索树-MID背包问题-理论基础分割等和子集-EASY最后一块石头的重量II-MID目标和-MID*一和零-MID*53-最大子数组和-中等918-环形子数......