首页 > 其他分享 >LLM并行训练7-混合并行总结

LLM并行训练7-混合并行总结

时间:2024-07-29 21:29:35浏览次数:14  
标签:总结 显存 frac 并行 batch TP LLM size

概述

根据前面的系列文章, 对预训练大模型里用到的主要并行加速技术做了一系列拆分分析. 但是在实际的训练里往往是多种并行混合训练. 我们要怎么配置这些并行策略才能让训练框架尽可能的减少通信瓶颈, 提升GPU计算利用率呢? 这里的变量太多了, 以最简单的3D并行为例:

  • 硬件层面有: 单台机器的卡数/卡间带宽/网卡带宽, 机器间通信时的网络拓扑构建.
  • 并行策略上有: 张量并行数/流水线并行数/数据并行数
  • 训练超参有: batch_size / AttnHeads / seq_len / hidden_size

如果靠脑补来调整这些参数, 会存在一个非常巨大的搜索空间, 很难找到最优于计算效率的方法, 所以需要先通过理论分析确定各个参数的大致范围. 最后再通过有限次尝试找到较优的方案. 本章参考nvidia的调参实践GTC演讲, 对GPT-3训练尝试进行理论调参分析.

并行方法适用场景分析

后文的标记备注:

  • \((p, t, d)\) : 3D 并行维度. \(p\) 代表流水并行数, \(t\) 代表张量并行数, \(d\) 代表数据并行数
  • \(n\): 总共的 GPU 数量. 要求 \(p\cdot t \cdot d = n\).
  • \(B\): Global batch size.
  • \(b\): Microbatch size.
  • \(b^{'}\): 一个流水线要处理的 batch size 大小, 等于 \(B/d\).
  • \(m = \frac{1}{b} \cdot \frac{B}{d}\): 一个 batch 在每个 pipeline 的 microbatch 的数量.
  • \(s\): seq_len.
  • \(h\): hidden_size = emb_size * attention_head_size
  • \(a\): attention_head_size

张量并行(TP)

TP开销

模式 Normal ColParallel ratio
flops (n次乘法 + n次加法)* n^2 = 2n^3 2n^3/t 1/t
Bandwidth (n^2)【n*n 矩阵的读或写】 * 2(fp16)) * 3(读 X、读A,写 Y) = 6n^2 2n^2 + 4n^2/t(A,Y切分) (1+2/t)/3
Intensity(flops/bandwidth) n/3 n/(2+p) 3/(2+t)

当并行度\(t\)增长的时候, 可以看到intensity也处于一个增长的趋势. 需要权衡通信和计算成本的平衡, 由于TP需要在结束时进行一次激活的AllReduce, 在多机通信上会导致较高的通信成本. 所以TP一般只考虑在单机卡间通信时使用. TP在LLM里主要有两个使用场景:

  1. MLP先列再行, 这块前后一般会和SP结合进行将AllReduce拆分为allGather和reduceScatter
image-20240727180417038
  1. attention处多头切分并行 每个头之间的计算各自独立, 所以可以进行切分计算.
image-20240727180501261

流水线并行(PP)

流水线主要是将一个batch的数据切分为多个mirco-batch, 在micro-batch之间做异步并行. 因为通信内容只包含切分stage的输出, 而且是点对点通信, 不需要多点集合通信. 通信数据量小, 因此比较适合在多台机器间通信的场景. LLM里一般把一个transformLayer作为一个stage, 在多个stage之间构建pipeline, 如下图:

image-20240727182819634

混合并行

当网络结构确定后, 一般TP和PP就能估算到比较合理的区间, 最后根据显存容量的计算来估计DP需设置的值.

image-20240727195029105

TP与PP的策略分析

数据并行度\(d=1\)时, \(p * t = n\), 会有以下计算公式:

  • 流水线bubble_time: \(\frac{(p-1)}{m}=\frac{n/t-1}{m}\), 提高TP并行度时会减少气泡占比, 但会增大单机内部的通信量, tp内部一个microbatch需要4个allReduce(fp/bp各两个)

  • 单机单次allReduce通信量: \(2bsh(\frac{t-1}{t})\), (layer激活为\(bsh\), allReduce通信量为数据量2倍)

  • 流水线并行时单个micro-batch机器间通信量为: \(C_{inter} = 2bsh\) (fp/bp各一次)

设一个pipeline内有\(l^{stage}\)个transformLayer, 则在1F1B非交错调度的情况下单个stage单机内部通信量为:

\[C_{inra} = l^{stage}\cdot4\cdot2bsh(\frac{t-1}{t}) = l^{stage}\cdot4\cdot2bsh(1-\frac{1}{t}) \]

所以机器间和机器内的通信量关系为:

\[C_{intra}= l^{stage}\cdot4\cdot(1-\frac{1}{t}) \cdot C_{inter} \]

因为机器间通信速率远小(IB 200GB/s)于卡间通信(NVLink 600GB/s), 所以我们如果希望优化吞吐, 那么应该尽量降低机器间通信比率.

[!TIP]

也就是在不会导致TP产生机器间通信的前提下让t尽可能的大. 如果这样还放不下模型,再使用流水线并行来切分模型

image-20240727212003374
micro-batch设置

在固定其他参数的前提下. 只调整micro_batch数, 单个batch的执行时间: \((\frac{b^{'}}{b}+(p-1))\cdot(t_{f} + t_{b})\) , 如果增大b, 单个pipeline内数量减少但执行时间会变长, 计算耗时和b是非线性的关系. 而且调整micro-batch后, 通信耗时也会变化, 所以mirco-batch调整需要实验尝试才能找到最优解. megatron在论文中尝试gpt训练的mirco-batch设置4比较合适

image-20240727220854686

DP的策略分析

便于分析设\(t=1, d * p = n\), 在这种情况下的流水线bubble占比为 \(\frac{p-1}{m} = \frac{n/d - 1}{B/b/d} = \frac{b(n - d)}{B}\)

PP和DP关系: 对于d单调递减, 也从下图可以看到, 当流水线并行的数量越小, 数据并行度越大的时候训练速度越快. 所以我们可以在PP满足显存占用的情况下尽可能的提升DP并行度.

和Batch_size关系: bubble和B成反比, B越大吞吐越高. 但是过大的B和数据并行度会导致模型不收敛. 需要在不影响效果的前提下调整B

image-20240727213002586

DP和TP关系: 在TP内每个batch 需要进行4次allReduce, 而DP只需要对梯度做一次allReduce, 另外在TP的时候如果W比较小也会影响矩阵乘法计算效率. 下图可以看到TP并行度越小, DP并行度越大吞吐越高. 调整策略是只要TP满足显存需求, 就尽可能的增大DP来提升吞吐.

image-20240727220023520

[!TIP]

如果模型比较大,需要先组合模型并行和流水并行,\(M=t \cdot p\) 的组合用来满足模型和模型相关的数据的显存需求,但是要让 M 尽量小。之后使用数据并行来扩展训练规模(扩大数据并行度、扩大 Global batch size)

GPT-3例子分析

以如下的超参GPT-3训练为例:

image-20240729183643039

显存分析

ModelMemory

单卡存储模型参数主要包含4个部分(由于流水线并行, 单卡一般只存储1-2个transformLayer): attention参数 / FC参数 / token_emb / Positional encoding

image-20240729184709396

以\(N_p\)代表一份全量参数, 则单卡上包含的参数量如下:

\[\frac{N_p}{n} = h * \frac{h}{t}*3(QKV参数)+h*\frac{h}{t}(multihead拼接后经过的fc) + h*\frac{h}{t}*4*2(fc1+fc2参数) \\ + \frac{v}{t} *h(token) + s*\frac{h}{t}(positional) \approx 1.73B\approx \frac{175B}{p*t} (单卡上单份DP参数量) \]

在混合精度训练中, 总共的数据量包扩1份fp16的w和grad, 1份fp32的optimizer_state(\(w+grad+momentum+variance\))

\[N_{storage} = 2Bytes * N_p + 2Bytes * N_p + (4+4+4+4)Bytes*N_p = 20N_P = 27.4GB \]

Activation

在nvidia分享里, 看着activation只存了过token前的emb激活和进fc前的激活, 剩下的全部都是bp时重计算的..因为也没使用SP, 这里每张卡的激活都存了TP并行数的冗余数据

\[M_{act}^{emb} = 2

标签:总结,显存,frac,并行,batch,TP,LLM,size
From: https://www.cnblogs.com/sunstrikes/p/18331127

相关文章

  • Pinely Round 4 (Div. 1 + Div. 2) 赛后总结
    PinelyRound4(Div.1+Div.2)赛时提交情况:CF1991A.MaximizetheLastElement赛时思路首先,CF判断了足足2min确定我是真人,看到题目时首先想到的是,最后保留的数字之前及之后必然有偶数个数字,且\(n\)为奇数,所以我们可以确定若\(a_i\)是最后保留的数字,\(i\)必然为奇......
  • STL用法总结(二)(deque,map,set)
    4.deque(双端队列)1.介绍首尾都可插入和删除的队列为双端队列#include<deque>//初始化定义deque<int>dq;2.方法函数代码含义q.push_back(x)/pusu_front(x)把x插入队尾/队首q.back()/front()返回队尾/队首元素q.pop_back()/pop_front()删除队尾/队首元素q.erase(ite......
  • 【前端 · 面试 】HTTP 总结(十一)—— HTTPS 概述
    最近我在做前端面试题总结系列,感兴趣的朋友可以添加关注,欢迎指正、交流。争取每个知识点能够多总结一些,至少要做到在面试时,针对每个知识点都可以侃起来,不至于哑火。HTTPS前言通过前面内容的学习,相信大家对HTTP的概念、特点、请求方法及缓存等的......
  • Audio and LLM
    语音交互的三驾马车:ASR、NLP、TTShttps://www.woshipm.com/ai/2620327.html01 语音交互的组成我们以一个智能音箱的例子来开始今天的讨论:假设我们对智能音箱天猫精灵说“放一首周杰伦的《晴天》”。天猫精灵就会说“好的,马上为你播放周杰伦的《晴天》”,并且开始播放音乐。......
  • 2024.7.25 模拟赛总结
    T1icanStatement:给定一个有\(n(1\len\le10^7)\)个元素的序列\(a(a_i\le10^9)\),求满足\(i<j\)且\(a_i<a_j\)的点对\((i,j)\)中\(j-i\)的最大值。Solution:考虑什么样的\(a_i\)可能作为点对中较靠左边的元素出现。显然对于一个\(k>i\)且\(a_k......
  • 前端实现图片或视频预览的三种方法总结
    在前端开发中,实现图片或视频预览是一个常见的需求。以下是三种常用的方法来实现图片或视频的预览功能:1.HTML标签直接预览1.1图片预览使用<img>标签可以直接显示图片,并且可以通过src属性动态加载图片。<imgsrc="image.jpg"alt="预览图"/>1.2视频预览使用<video>......
  • CompressGraph: 基于规则的高效并行图分析压缩方法
    背景随着数据爆炸式增长,图数据分析在社交网络、科学计算和数据挖掘等领域变得越来越重要。然而,处理大规模图数据面临着存储和计算资源的挑战。传统的图压缩方法可能会丢失重要信息,影响分析结果的准确性。CompressGraph框架旨在通过规则基压缩技术,在有效压缩图数据的同......
  • spellman电源维修XRM50P50X3839 NY11788
    电源维修的常见故障包括:无法开机、电源烧、短路、输出偏小、电源不通电、电源风扇不转,无输出,缺项,输出过高,电源烧毁,灯不亮,不动作等故障维修。Spellman的专有高压技术,再加上MT电路,导致了一个紧凑和轻量级的模块,是理想的OEM应用布置来获得的高压输出,而较低的电压单元则采用稳健......
  • Pinely Round 4 (Div. 1 + Div. 2) 复盘总结
    PinelyRound4(Div.1+Div.2)发挥到极致了,写出了两题A.MaximizetheLastElement对于每个满足他左边的数的个数和他后面的数的个数都是奇数的数,取最大值即可。#include<bits/stdc++.h>usingnamespacestd;typedeflonglongll;//#defineintlonglong#defi......
  • Python内置模块常用总结
    目录unittestunittest点击查看代码importunittestfromhtmltestreportimportHTMLTestReportclassTestOne(unittest.TestCase):@classmethoddefsetUpClass(cls)->None:print(f'setUpClass')@classmethoddeftearDownClass......