首页 > 其他分享 >阅读笔记:Merak 大模型并行训练系统

阅读笔记:Merak 大模型并行训练系统

时间:2023-05-07 16:45:29浏览次数:45  
标签:训练 并行 模型 子图 笔记 计算 Merak 节点

论文简介

Merak: An Efficient Distributed DNN Training Framework With Automated 3D Parallelism for Giant Foundation Models 这篇论文发表在IEEE TPDS 2023上,主题是提出一种高效的具有三维并行性的自动化分布式训练系统-Merak

论文背景和Motivation

3维并行训练即将数据并行(DP)、张量模型并行(TMP)和流水线模型并行(PMP)集成到多节点分布式训练系统的训练方式。DP和PMP是常用的分布式训练方法,例如DP以参数服务器模式和AllReduce模式为两种主要范式,而PMP也以GPipe和PipeDream为代表性工作。TMP则是由Megatron LM引入基于transformer的模型,即沿着行或列维度划分权重矩阵到不同的训练节点,并添加AllReduce操作以确保操作逻辑正确性。它在(前向传播)FP和(反向传播)BP过程中都会带来大量通信阻塞,这大大减缓了训练阶段的速度。

激活的重新计算:在模型训练过程中,中间的激活变量(activation)需要在反向传播的时候被重新计算,这是一种在大模型训练的情况下内存和计算权衡的结果,这种方法要多花大约1/3的运算开销来降低内存成本。

最新(SOTA)模型训练系统的缺点:

  1. 缺乏通用性:需要有经验的开发人员,手动设置模块或其他代码修改
  2. 低效率:PMP给GPU资源带来了气泡时间(bubble),TMP给张量计算引入了通信时间,因此带宽和计算资源都是低效率的

Merak旨在建立一个系统达到以下目标:
1)通过改进训练方式(DP,PMP,TMP)的整合来加速训练过程。
2)无需修改原始模型即可实现三维并行训练。

下图也说明了Merak在自动训练上的优势:

论文方法

系统的整体架构如下:

Automatic Model Partitioner

Proxy Nodes Graph

Merak是一个基于Pyotch的系统,所以为了模型分割,首先需要做的就是获得完整的计算图,这里Merak使用了torch.tx。但是在大模型场景下,torch.tx直接运行一次训练来获取计算图的思路并不可靠,显存容易溢出。所以Merak提出了一个代理节点图(Proxy Nodes Graph)的概念。

Merak设计代理节点来替代密集计算的计算图节点。代理节点不与任何参数相关联,也不执行计算,但它们可以根据输入返回适当大小的结果,从而正常参与计算图的跟踪确定。代理节点不需要计算的特点允许大模型在CPU中执行一步的推理。

这些代理计算图节点都是对一些特定的运算操作预设的,例如矩阵计算,注意力计算等。

Graph-Sharding Algorithm and Automated PMP

获得了模型的结构(计算图)之后,接下来就是设计一种模型水平划分算法来实时流水线并行(PMP)的模型分割操作。Merak使用的是一种启发式算法,算法原则:子图之间连接的节点数量应该是最小的(为了减少子图间的传输数据量)。因此,将不同的图划分为具有有限连接的顺序子图是其自动PMP划分的基本问题。

接下来,Merak提出了一个引理:一个节点应该保持在具有最远依赖关系的同一个子图中。同时,如果一个节点没有依赖关系,或者其最远的依赖关系是最新创建的子图,则可以将其放置到新的子图中。这个引理保证了前面提到的子图间的连接节点最小。

但是这种启发式的原则太严格,有些例如Mask Attention计算操作就被所有的注意力层引用为输入。所以Merak将一种节点定义为公共节点(commen node),并允许它们成为依赖的例外,以避免所有子图使用。

这里Merak论文提供了两个启发式规则算法。第一个是给定一个计算节点,返回其最远的依赖节点(需要该节点计算完成才能开始计算)的子图索引。第二个是给定一个计算图,返回最佳的子图分割。后者的核心是保证所有的节点和它的依赖节点位于同一个子图中,且当节点依赖为1或者子图超过了显存限制时创建新的子图。

Hign-Performance Training

Shifted Critical Path Schedule

这里论文提出的关键路径下移方法比较暧昧,关键路径(Critical Path)就是整个工作流中时间最长的一条路径,也是最终直接影响流水线计算图的路径。下图给出了几种流水线并行的计算时间图:

上图中(a)即一般的训练方法,其中每个mini-batch都需要等待前一个mini-batch的训练反向传播完成,所以虽然可以流水并行,但后期还是出现了气泡时间。 (b)则是GPipe的训练方法,GPipe划分mini-batch为micro-batch,彼此之间不需要等待参数更新,可以认为训练是独立的,所以减少了后期流水的气泡时间。 (c)是Merak的计算时间图,Merak的关键路径迁移方法即避免最后一个子图(stage)的激活重计算,因为最后一个子图并不需要等待任何节点就可以直接反向传播,所以直接使用前向计算得到的激活即可。

所以这里突出的关键路径下移的的方法就是避免最后子图的重计算。

这种方法被很多论文都提到了,似乎这里是改名重提?

Stage-Aware Recomputation

前面已知激活重计算是一种显存和计算时间的权衡,这里的主题就是根据子图(stage)来判断是否应该重计算激活或者重计算多少的激活。

首先给出数学上的推导,定义\(m\)为micro-batch的数量,\(s\)为子图总数,\(M_r\)为进行前向计算和反向传播在不保存激活时的显存占用大小,\(M_a\)为一个micro-batch里一个激活的大小。当第\(i\)级的子图节点上有\(a_i\)个模块不使用激活重新计算时,它们需要额外的内存占用\((s−i)\alpha_i M_a\)。这里的\(s-i\)来自于流水线图中前面的节点需要多保存一些micro-batch的激活。

优化目标为最大化显存的利用率,为此论文做出了一个假设(似乎没有合理性证明):假设在最佳情况下,每个子图的设备都有相同的显存消耗,并且每个设备都已满载。基于这个假设,论文推导在i和j级子图节点的最佳重计算比例满足:

\[M_r + (s−i)\alpha_i M_a=M_r + (s−j)\alpha_j M_a \]

由此,可以推导出一个迭代计算公式来更新所有的\(a_i\),如下:

其中的\(a_1\)是对第一个子图从1开始逐渐增加,直到保存的激活总量超过了内存限制为止得到的。

Sub-Pipelined TMP

在子流水线TMP中,Merak将TMP块的每个micro-batch均匀地划分为两个sub-micro-batch,它们的训练过程是相互独立的,这样可以在计算和通信交替的资源间隙中填充空闲时间,如图所示:

实验

基本实验部分都齐全,例如总体表现,模块表现等等,并且使用的都是GPT结构模型。这里列出来一个实验结果:

如图所示,在GPU节点较少的时候关键路径下移的效果略好于子图重计算规划,随着GPU数量的增加,效果逐渐反过来了。

讨论

子图重计算的推导,假设资源使用全部一致为最优情况以及几个统一的(无视子图结构)变量\(M_a, M_r\)是否合理?

标签:训练,并行,模型,子图,笔记,计算,Merak,节点
From: https://www.cnblogs.com/medianet-ytc/p/17379359.html

相关文章

  • Django笔记三十七之多数据库操作(补充版)
    本文首发于公众号:Hunter后端原文链接:Django笔记三十七之多数据库操作(补充版)这一篇笔记介绍一下Django里使用多数据库操作。在第二十二篇笔记中只介绍了多数据库的定义、同步命令和使用方式,这一篇笔记作为补充详细介绍如何对Django系统的多个数据库进行针对的建表同步操......
  • 学习笔记:MySQL常用的一些SQL语句
    本文谈谈MySQL的开发必会的sql语句创建数据库createdatabasedb1;删除数据库dropdatabasedb1;创建数据表createtabletb1用户表(idintnotnullauto_increment primarykey,namechar(10),                     department_idint,            ......
  • AC 自动机学习笔记
    前置知识:\(\texttt{trie}\)树。不会的话到这篇博客看看吧。前置知识:\(\texttt{kmp}\)。不会的话到这篇博客看看吧。字符串好的题单。下面设所有字符串的大小之和为\(|\Sigma|\)。\(\texttt{AC}\)自动机(也叫\(\texttt{ACAM}\))\(\texttt{ACAM}\)时为了解决\(\lceil\)多个......
  • yazi框架学习笔记
    主线程监听和建立客户端的连接接收客户端的请求数据,创建一个任务,该任务携带请求数据,并把该任务放入任务队列告诉分发线程,有请求任务过来了,叫他赶紧去处理重复上面三个步骤注意:主线程不处理具体请求分发线程查看任务队列,看是否有请求任务?没有任务则继续睡觉,否则把任务取......
  • 「学习笔记」双连通分量、割点与桥
    文章图片全部来自Oi-wiki,部分图片加以修改前面我们在学tarjan算法时,提到过强连通分量,即有向图上的环,那么无向图上是否也有强连通分量呢?很遗憾,没有但是,无向图有双连通分量!分为点双连通和边双连通(下面简称点双和边双)。边双连通分量概念在一张联通的无向图中,对于两个点\(x......
  • 茶文化笔记
    绪论茶文化概念“文化”是人类社会生活中提炼出来的,以精神领域为主要内容的成果。中国茶文化正是在中华民族持久茶饮活动中所提炼出来的具有中华民族特色的成果。茶形成茶文化的特殊性茶具有从物质文明到精神文明的特性茶的利用过程见证中国历史及各阶层茶文化:广义是指......
  • Linux驱动开发笔记(一):helloworld驱动源码编写、makefile编写以及驱动编译基本流程
    前言  基于linux的驱动开发学习笔记,本篇是描述了一个字符驱动的基础开发流程,以便做嵌入式开发多年的应用或者系统学习驱动开发。 笔者自身情况  笔者拥有硬件基础,单片机软硬基础,linux系统基础等各种,就是没有linux驱动框架基础,未做过linux系统移植和驱动移植开发了......
  • C/C++网络编程笔记Socket
    https://www.bilibili.com/video/BV11Z4y157RY/?vd_source=d0030c72c95e04a14c5614c1c0e6159b上面链接是B站的博主教程,源代码来自上面视频,侵删,这里只是做笔记,以供复习和分享。上一篇博客我记录了配置环境并且跑通了,以及碰到的一些问题这篇文章是对socket的代码解读笔记。先把服务端......
  • 【笔记】跟吴恩达和IsaFulford学提示词工程(初级开发者入门课程)
    标签:#Prompt#LLM创建时间:2023-04-2817:05:45链接:课程(含JupyterNotebook),中文版讲师:AndrewNg,IsaFulford阅读提示这是一篇入门的教程,入门的意思是指大部分内容,可能你都已经知道了,但是知道不等于掌握,Prompt是一门实践经验主义科学,LLM是个黑盒,你只要不断去“实践”才能......
  • [附课程学习笔记]CS231N assignment 3#1 _ RNN 学习笔记 & 解析
    欢迎来到assignment3从现在开始,网上的博客数量就少了很多.毕竟从现在,我们开始了更具体网络的学习.这里的组织形式可能会比较怪,我会将RNN相关的课程内容和代码混在一起,这样也可以同时作为学习笔记,也是考虑到RNN之后没有官方讲义之后自己概括性的评说,感觉比较好组织.......