首页 > 其他分享 >【Transformer 基础系列】手推显存占用

【Transformer 基础系列】手推显存占用

时间:2023-12-26 22:57:08浏览次数:73  
标签:显存 Transformer 模型 手推 中间 KV 激活 优化

https://zhuanlan.zhihu.com/p/648924115

 

本文试图以最清晰的方式手动推导 Transformers 每一步的参数量到显存、计算量问题。理解底层,才能更好的做训练和优化。可能是目前最全的大模型显存优化方案分析。

本文内容包括
(1)模型训练和推理过程中的显存占用
(2)KV cache、中间激活值等显存占用
(3)模型状态显存优化方案: Megatron(3D) + Deepspeed(ZeRO)(更新于2023-09-11)
(4)激活值显存优化方案:重计算 + 3D 并行(更新于2023-08-11)
(5)KV Cache 显存优化方案:MQA 和 GQA(更新于2023-09-11)

关于计算量、参数量的分析在本系列其他文章记录。

乞力马扎罗不说话:【Transformer 基础系列】手推计算量FLOPS和训练时间
乞力马扎罗不说话:【Transformer 基础系列】模型参数量

0 前置知识和标记

  1. 显存占用 = 参数数量 x 该参数精度占用的 bytes 数
  2. 换算关系:Int8 需1 bytes, fp16 / bf16 数需 2 bytes, fp32 需要 4 bytes
  • transformer 模型的层数为 �
  • 隐藏层维度为 ℎ
  • 注意力头数为 �
  • 词表大小为 �
  • 批次大小为 �
  • 序列长度为 �

1 训练过程

训练中的显存占用分两块,分别是:

  1. 模型状态,参数、梯度和优化器状态
  2. 剩余状态, 中间激活值、临时buffer、显存碎片等

1-1 模型状态显存

模型状态指的是和模型参数、梯度和优化器状态相关的显存占用。

设模型参数量为 Φ ,模型参数(fp16)、模型梯度(fp16)和优化器状态(fp32),总参数量 = 2Φ+2Φ+�Φ=(4+�)Φ 。参数量和模型配置之间的关系可以看另一篇文章推导,合计约 �ℎ+�(12ℎ2+13ℎ) 。

一般是混合精度训练,梯度/权重为 fp16,但所有涉及累加操作都需要 fp32 防止误差累计,同时优化器也要存 fp32 主权重。以 Adam 系列为例,总数为 2Φ+2Φ+(4+4+4)Φ=16Φ 。

  1. 这部分比较固定,主要和参数量有关,和输入大小无关。
  2. 在整个训练过程中都要存在显存中。 模型参数一般只能通过并行切分(Tensor Parallelism/Pipeline Parallism)能减少。优化器状态一般通过ZeRO 来减少。
  3. 不同优化器的 K 值不同,算法的中间变量、框架的实现都有可能有一定区别。复旦 LOMO 的方法也是基于类似的思路重新改进 SGD 来减少 K 值和梯度部分显存。

不同优化器的 K 值

优化器K值构成
adamw 12 fp32 主权重 4 + 动量 4 +方差 4
SGD 8 fp32 主权重 4 + 动量 4
bitsandbytes 6 fp32 主权重 + 动量 1 + 方差 1
LOMO 0  

1-2 中间激活值显存

激活(activations)指的是前向传递过程中计算得到的,并在后向传递过程中需要用到的所有张量。

中间激活值占用显存分两个部分分析:Attention 和 MLP,Embedding 没有中间值。最终合计 (34��ℎ+5��2�)∗�=(13��ℎ+5��2�+21��ℎ)∗� 。

  1. 这部分比较灵活,激活值与输入数据的大小(批次大小 b 和序列长度 )成正相关。
  2. 在训练过程中是变化值,特别是 batch size 大的时候成倍增长很容易导致 OOM。
  3. 可以通过重计算、并行切分策略减少。

直接看公式不太直观,下面是 GPT-3 和 LLaMA 为例计算的模型显存和中间激活值显存占用比例。

Attention 层中间显存表

self-attention 块的计算公式如下: �=���,�=���,�=��� ����=�������(���ℎ)⋅�⋅��+�

Attention 层单步中间激活值显存表

MLP 层中间显存表

MLP 块的计算公式如下:

�=�����(�����1)�2+����

MLP 层单步中间激活值显存表

2 模型状态显存优化方案

如 1-1 所推,模型状态占用 2Φ+2Φ+�Φ=(4+�)Φ,其中一般只能通过各种各样的并行来解决。比如模型参数显存优化一般是 模型并行,包括张量并行 (tensor parallel, TP) 和流水线并行(pipeline parallel, PP),业内通用方案参考 Megatron。只做数据并行 (data parallelism, DP) 情况下,模型参数和优化器状态一般通过 Deepspeed ZeRO 来均摊到所有卡上。

总的来说,都是用通信时间换显存空间。业内很多框架也是基于 Megatron+Deepspeed 这一套比较成熟的底层上改的。

2-1 Megatron-LM 3D Parallel

Megatron-LM 里称之为 Model Parallel,也叫 Tensor Parallel。

Q / K / V 矩阵做列切分(纵刀流),对Dropout做行切分(横刀流),方便GPU 中间计算各算各的,减少额外通信

不切分的时候各层参数如下表

Model Parallel 需要切分所有参数 embedding / attention / mlp 为 � 份,其中 embedding 层 V 在 Megatron 中会补全到最小的 � 倍数以便于切分。因此,显存为 �′ℎ+(12ℎ2+13ℎ)�� 。

Pipeline Parallel 需要按层切分所有参数,一般是 � 层均分 � 份,embedding 在最前面一层或单独一层。不过针对一些奇特结构不能整除的(比如44层的 NeoX)可能需要设计特定切分策略。每层显存为 (12ℎ2+13ℎ)�/� 。

这里显存都没什么好说的,主要是通信量值得分析。

2-2 ZeRO Stage 1-3

Deepspeed ZeRO 本质上都是在数据并行层面对模型状态一步步做分片(partition),系统内只维护一份模型状态,需要全量状态时就执行通信。

ZeRO Stage: 不同 stage 区别主要是切什么,显存占用论文里这张图就很直观了。

  • Stage 1(P os): fp32 optimizer state
  • Stage 2(P os+g): fp32 gradient + Stage 1
  • Stage 3(P os+g+p): fp16 parameters + Stage 1 + Stage 2

  1. 总卡数越多越省。stage 1 下,按照一个节点估算,模型状态从 16Φ→5.5Φ ,如果是现在一般规模的预训练规模,卡数至少上百,优化器状态可以忽略不计,模型状态基本接近。
  2. Stage 1和2 不会额外增加通信量,Stage 3 会额外增加 50%(forward 和 backward 时分别一次 broadcast 参数以获得全量参数),因此后面 Deepspeed ZeRO++ 支持了 stage 3 量化和参数分层存储来降低通信量。
  3. ZeRO 除了分片还支持 offload,显存不够内存来凑,但是内存显存之间的 I/O 成本也不可忽视,因此实际训练中还是很少用。
  4. All-reduce 通信到底怎么充分利用设备和设备之间的带宽也很有趣,请参考袁老师文章 OneFlow:手把手推导Ring All-reduce的数学性质

3 中间激活值显存优化

1-2 中中间激活值式子可以看到,激活值与输入数据的大小成正相关,batch size 较大时远超过模型参数占用。因此主要显存优化是优化中间激活值,有重计算和并行两个思路

(34��ℎ+5��2�)∗�=(13��ℎ+5��2�+21��ℎ)∗�

3-1 重计算

  • activation checkpoint (recompute) :时间换空间,前向的时候重新计算一次来避免存储。计算量的增加参考另一篇博客。
    • 全部重计算可以减少到只有每个 attn 层输入的 2��ℎ�
    • 部分重计算可以减少 �(�2) 项相关的 QK 乘法中间结果,其他不变,减少到 34��ℎ�

3-2 TP 中间激活值

Tensor Parallelism 通过切 attention/mlp 层减少中间值

  • attention (8��ℎ+5��2�)/�
  • mlp 16��ℎ/�
  • dropout/layernorm 6��ℎ (外层的不受影响,但 softmax dropout 也要切 t)
  • attention/mlp input 2��ℎ+2��ℎ (f' 表示需要在 forward/backward 中需要 all reduce因此attn, mlp 输入也是完整的)

显存合计 ��ℎ(10+24�+5��ℎ�)

3-3 SP+TP 中间激活值

Sequence Parallelism 输入沿着 seq 维度切,从而进一步减少两个输入和 layernorm,dropout 的中间激活值

  • attention (8��ℎ+5��2�)/� 不变
  • mlp 16��ℎ/� 不变
  • dropout/layernorm 6��ℎ/� 外层的 sequence parallel 也切 t 份
  • attention/mlp input (2��ℎ+2��ℎ)/� ,外层g, g' 是 all-gather 操作

显存合计 ��ℎ�(34+5��ℎ)

3-4 PP+SP+TP 中间激活值

Pipeline Parallelism 没有减少

  • 和 pp size 无关, 1F1B pp 同时有 L/p 个 microbatch,即便参数只有 L/p 这么多,但是激活状态需要整个 batch 全保留才能backward 时用
  • Megatron 里 interleaving 如果开了需要存 �(1+(�−1)/��) 层的,m 为 interleaving stage

显存合计 ��ℎ��(34+5��ℎ)

3-5 总结

上述优化方案和组合方案优化后的中间激活值如下表

以 LLaMA 和 GPT 预估部分重算情况下模型显存和中间激活值比例

感兴趣也可以根据公式算全部方案下中间激活值节省。以下是博客:

Qi's Blog​parallel-carol-55e.notion.site/Qi-s-Blog-dd2088884ac94586b096202e9b68e221?pvs=4

4 推理过程

推理显存没有梯度和优化器,主要是模型参数,一般总显存经验值估算为 1.2 倍参数量

  1. 模型参数 fp16 下推理参数占 2Φ bytes
  2. KV Cache (如有) 缓存 KV Cache 加速方法
  3. 中间结果和输入数据 比较少,一般 20% 内

4-1 KV Cache 显存分析

KV Cache 是典型的推理加速方法,推理时缓存第 n 个 token 及前计算结果,第 n+1 个 token 相当于增量计算从而加速。

  1. 预填充阶段:输入一个 prompt 序列,为每个 transformer 层生成 key cache 和 value cache(KV cache) ()��⋅���(�) ,其中 ,��∈[�,���],��∈[ℎ,���],���∈[�,���,ℎ] 。这里是简化后的单头,多头时 ��∈[�,���,�,ℎ/�] 。
  2. 解码阶段:拼接并 concat 更新 KV cache,一个接一个地生成词,当前生成的词依赖于之前已经生成的词。假设输入序列的长度为 s ,输出序列的长度为 n ,最后一个token 推理时长度为 (s+n), KV Cache 占用峰值。

所以每层 2个K/V 各(s+n)bh ,每个 fp16 占 2 个 bytes,KV cache 的峰值显存占用大小为 �(�+�)ℎ∗�∗2∗2=4��ℎ(�+�)

KV Cache 占模型显存比例

4-2 MQA & GQA

面向推理的显存(和速度)的优化主要是 Multi-Query Attention (MQA) 和 Group-Query Attention (GQA),本质上是通过多头共用 KV Cache 减少内存 I/O 时间占总时间比例。已经应用或支持的包括 ChatGLM2、LLaMA2、和 flash attention v2 解决方案。

MHA(n:n) vs GQA(n/t : n) vs MQA(1:n)

这里显存的节省比较简单,如 MQA KV Cache 为原来的 1/n 倍,GQA 为原来的 1/������ 倍。主要是为了加速而不是显存优化提出的方法,推内存时间减少的比较值得一看。

5 参考

[1] Reducing activation recomputation in large transformer models
[2] ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
[3] Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelis
[4] https://arxiv.org/pdf/1911.02150.pdf
[5] https://arxiv.org/pdf/2305.13245.pdf
[6] OneFlow:手把手推导Ring All-reduce的数学性质

标签:显存,Transformer,模型,手推,中间,KV,激活,优化
From: https://www.cnblogs.com/chinasoft/p/17929552.html

相关文章

  • 自然语言处理的模型:从 Bag of Words 到 Transformer
    1.背景介绍自然语言处理(NLP)是计算机科学与人工智能的一个分支,研究如何让计算机理解、生成和处理人类语言。自然语言处理的主要任务包括文本分类、情感分析、命名实体识别、语义角色标注、语义解析、机器翻译等。随着大数据时代的到来,自然语言处理技术的发展得到了巨大的推动。在过......
  • 人工智能大模型原理与应用实战:从Transformer到Vision Transformer
    1.背景介绍人工智能(ArtificialIntelligence,AI)是计算机科学的一个分支,研究如何让计算机模拟人类的智能。在过去的几年里,人工智能技术的发展取得了显著的进展,尤其是在自然语言处理(NaturalLanguageProcessing,NLP)和计算机视觉(ComputerVision)等领域。这些进展主要归功于深度学习......
  • transformer 预测 ENSO
    第一篇《Aself-attention–basedneuralnetworkforthreedimensionalmultivariatemodelinganditsskillfulENSOpredictions》发表在SciAdv. 张荣华起名3D-Geoformer摘要中说SSTanomalyprediction18个月,但文中又说是12个月预测未来20个月 由于耦合了海温(异......
  • transformer总体架构
    transformer总体架构目录transformer总体架构循环神经网络总体架构EncoderDecoder输入输出层模型输入位置编码模型输出自注意力机制关于QKV的理解Q,K,V及注意力计算多头注意力机制多头注意力机制作用FeedForward层参考资料论文地址:AttentionisAllYouNeedhttps://arxiv......
  • transformer补充细节
    transformer补充细节目录transformer补充细节注意力机制细节为什么对点积注意力进行缩放多头带来的好处数据流训练时数据流推理时数据流解码器中注意力的不同带掩码的注意力机制位置编码整型数值标记[0,1]范围标记位置二进制标记周期函数标识用sin和cos交替来表示位置训练测试细......
  • Sw-YoloX An anchor-free detector based transformer for sea surface object detect
    Sw-YoloXAnanchor-freedetectorbasedtransformerforseasurfaceobjectdetection基于Transformer用于海上目标检测的无锚检测器:Sw-YoloX1)由于不同海洋状态下的活体和漂浮物体数据稀缺且昂贵,我们基于2022年1月至3月在中国厦门的实际海面测量,构建了XM-10000基准数据集。......
  • Vision Transformer with Super Token Sampling
    VisionTransformerwithSuperTokenSampling*Authors:[[HuaiboHuang]],[[XiaoqiangZhou]],[[JieCao]],[[RanHe]],[[TieniuTan]]Locallibrary初读印象comment::ViT在捕捉浅层局部特征时可能会出现高冗余度的问题,使用strongsupertoken提供具有语义意义的视......
  • Bottleneck Transformers for Visual Recognition
    BottleneckTransformersforVisualRecognition*Authors:[[AravindSrinivas]],[[Tsung-YiLin]],[[NikiParmar]],[[JonathonShlens]],[[PieterAbbeel]],[[AshishVaswani]]DOI:10.1109/CVPR46437.2021.01625初读印象comment::(BoTNet)通过在ResNet的最后三个......
  • SeaFormer: Squeeze-enhanced Axial Transformer for Mobile Semantic Segmentation
    SeaFormer:Squeeze-enhancedAxialTransformerforMobileSemanticSegmentation*Authors:[[QiangWan]],[[ZilongHuang]],[[JiachenLu]],[[GangYu]],[[LiZhang]]初读印象comment::(SeaFormer)提出了一种适用于移动设备的轻量级网络,设计了一个通用的注意力块,特......
  • BiFormer: Vision Transformer with Bi-Level Routing Attention 使用超标记的轻量ViT
    alias:Zhu2023atags:超标记注意力rating:⭐share:falseptype:articleBiFormer:VisionTransformerwithBi-LevelRoutingAttention*Authors:[[LeiZhu]],[[XinjiangWang]],[[ZhanghanKe]],[[WayneZhang]],[[RynsonLau]]Locallibrary初读印象comm......