首页 > 其他分享 >万字综述:全面梳理 FP8 训练和推理技术 -- 附录

万字综述:全面梳理 FP8 训练和推理技术 -- 附录

时间:2024-07-23 16:59:22浏览次数:8  
标签:Scale Tensor 综述 -- 梯度 Scaling 所示 FP8

万字综述:全面梳理 FP8 训练和推理技术 -- 附录

原创 AI闲谈 AI闲谈 2024年07月21日 20:02 北京

一、背景

在上一篇文章(万字综述:全面梳理 FP8 训练和推理技术)中我们通过几篇论文具体介绍了 FP8 的发展历程以及在 AI 模型训练和推理中的应用。然而由于篇幅的原因,部分内容并没有具体展开,这篇文章中我们对其补充,并结合代码来介绍。

二、FP8-LM:FP8 梯度和 AllReduce 通信

我们在介绍  [2310.18313] FP8-LM: Training FP8 Large Language Models 时提到其有 FP8 通信,FP8 优化器,以及 FP8 分布式并行训练 3 个方面的优化,但没有具体介绍 FP8 通信是怎么实现的(这个部分比较晦涩),这里进行补充。

如果有 N 个 GPU 要进行梯度聚合,直接使用 FP8 梯度进行梯度聚合会导致精度降低。

如下图所示为 pre-scaling 方案,在求和之前先分别除以 N,此时容易出现 Underflow 的问题:

图片

如下图所示为 post-scaling,其主要区别是在求和之后再除以 N,此时容易出现 Overflow 的问题:

图片

作者提出了相应的优化方案,假设有 4 个 GPU 要进行梯度的聚合,分别有 FP16 的梯度:

图片

假设最终聚合后的 FP16 梯度为:

图片

每一个 FP16 的 Tensor 要转换为 FP8,都是由(FP8 的 Tensor,Scale 值)共同表示。上述 FP16 的 Tensor 对应的 FP8 表示为:

图片

如果想要直接使用 FP8 进行 AllReduce,则需要有一个全局的 Scale 值,否则 Reduce 就不等价了。对应的全局 scale 变量如下所示,其中使用 min 可以避免 Overflow:

图片

复原后的 FP8 梯度如下所示(等价于 FP16 梯度使用相同的 Scale):

图片

通过 AllReduce 聚合后的 FP8 梯度可以表示为:

图片

所以,相当于 ge 的 FP8 表示如下所示,这里有个 Trick,聚合后的梯度并没有除以 N,而是让 Scale 值乘以 N:

图片

作者论文中为什么会介绍 μ 值(Auto Scaling)呢?应该是想要控制

图片

 都尽量在 FP8 的范围内(PS:开源代码里并未实现),比如:

  • 如果发现 

    图片

      中有超过 0.001% 超过了 FP8 表示的最大值(PS:这里是因为 Delayed Scaling 导致?如果每个 Tensor 都使用 Just-in-time Scaling 是不是就都不会超过?),则下一次迭代的时候让 

    图片

     都变为原来的 1/2,降低 Overflow 的风险。

  • 如果在接下来的 1000 次迭代中都没有出现超过 0.001% 的情况,则让 

    图片

     的指数增加 2,降低 Underflow 的风险。

三、FP8 梯度聚合代码实现

上述对应的代码位于 :msamp/nn/distributed.py#L124-L193。如下图所示:

第一步:collect 相应的 Gradient,并初始化 Meta(维护 Scale,amax 相关信息):

图片

第二步:分别求每个 Gradient 的 amax(绝对值的最大值),然后 AllReuce 操作获得全局的最大值,这里求 max 是因为 Scale 一般为 FP8 可表示的最大值 / amax,也就等价于求 Scale 的 min:

图片

第三步:根据全局最大 amax,FP8 可表示的最大值等计算全局 Scale 值,其中的 world_size 也就是 N:

图片

需要说明的是,这里本来应该是 Auto Scaling,也就是对应上述 μ 值的部分,然而作者实际上并没有集成 Auto Scaling,而是使用了经验值 1/sqrt(N),以缓解 Underflow 和 Overflow。可以参考:https://github.com/Azure/MS-AMP/issues/117:

图片

第四步:将 FP16 的 Tensor 转换为 FP8 的 Tensor:

图片

如上图所示存在 Gradient 除以 N 的操作,然而,因为 Gradient 为 ScalingTensor,所以实际除的时候是操作的 Scale 值,如下图所示:

图片

第五步:对 FP8 Gradient 进行 AllReduce 操作:

图片

四、FP8 训练和推理过程

Transformer 模型中最主要的操作就是矩阵乘,也就是 Linear Layer,如下图所示(来自 [2309.17224] Training and inference of large language models using 8-bit floating point)为一个 Linear 操作的伪代码,其核心思路就是在 FP8 矩阵乘之前需要转换的 Tensor 转换为 FP8 类型,如下图红框所示;然后在矩阵乘之后 Unscale 回 FP16,如下图蓝框所示:

图片

而在推理阶段只用离线的的对 Weight 转换一次,Forward 的时候只需对 x 进行相应的 Scale 操作:

图片

五、Scaling 实现方式

在之前的文章中我们介绍过 Scaling 的实现方式,这里在简单概括一下:

  • Static Scaling:提前离线计算好每个 Tensor 的 Scale,然后一直不变。为了保证精度,这种方式通常用于推理阶段的 Weight,如下图所示:

    图片

  • Dynamic Scaling:每次都实时计算每个 Tensor 的 Scale,好处是比较精确,不足是这些计算全部是同步的,如下图所示:

    图片

  • Delayed Scaling:会保存一些之前的多个 Scale 值,计算当前 Tensor 时根据以前的多个 Scale 预估当前 Scale,然后进行 Scaling 操作,同时异步的计算当前的 Scale 值,但是其实现也比较复杂,无状态变为有状态,如下图所示:

    图片

在 Pytorch 的 FP8 实现中(https://github.com/pytorch-labs/float8_experimental/tree/main),早期的测试表明 Delayed Scaling 反而比 Dynamic Scaling 慢,当然也非常接近:

图片

六、参考链接

  1. https://arxiv.org/abs/2310.18313

  2. https://github.com/Azure/MS-AMP/blob/main/msamp/nn/distributed.py#L124-L193

  3. https://github.com/Azure/MS-AMP/issues/117

  4. https://arxiv.org/abs/2309.17224

  5. https://github.com/pytorch-labs/float8_experimental/tree/main

标签:Scale,Tensor,综述,--,梯度,Scaling,所示,FP8
From: https://blog.csdn.net/sinat_37574187/article/details/140635919

相关文章

  • 第四十八天 第十章 单调栈part01 739. 每日温度 496.下一个更大元素 I 503.下一个更大
     739.每日温度 使用单调栈:注意栈中的递增递减顺序。classSolution{public:vector<int>dailyTemperatures(vector<int>&temperatures){vector<int>res(temperatures.size(),0);stack<int>sta;sta.push(0);for(int......
  • C++题目:DNA排序 代码
    题目描述现在有一些长度相等的 ......
  • 掌控 Spring Bean 的生命周期:`@Bean` 注解的执行顺序揭秘
    Java@Bean注解的Bean执行顺序控制引言在Spring框架中,@Bean注解是定义和管理bean的关键。理解如何控制这些bean的创建顺序对于维护复杂的Spring应用程序至关重要。基础知识SpringIoC容器:负责bean的创建、初始化和销毁。@Bean注解:用于在Spring配置类中声明一个方......
  • 美的空调全国售后服务热线电话/美的24小时官方客服热线号码
    美的空调售后服务客服电话:400-778-8380,美的空调24小时售后服务电话400-7788-380人工无转接提示操作,选择售后服务。美的电器服务无忧:家电设计、配送、安装、售后等服务都是由用户当地销售服务中心提供,有问题网上反馈或者拨打美的全国服务热线4007788380,24小时服务到位,不再担心服务......
  • 从零开始NEXT.js(五)——路由组和平行路由
    从零开始NEXT.s(四)——服务器组件上一章我们介绍了服务器组件的内部逻辑,这一章我们重点来讲一下NEXT,js中的页面路由。路由组在我们的app文件夹下,我们可以添加一个又一个文件夹去建立我们的页面路由,当页面过多时找起来就会很复杂,用路由组的形式可以很便捷的收纳我们的路由......
  • 每日一题-P1263
    一眼匈牙利,没有紫啊#include<bits/stdc++.h>usingnamespacestd;#definepbpush_backintn,m,res,a[205][205],p[40005];intid1[205][205],fr1[40005],cnt1,id2[205][205],fr2[40005],cnt2;boolvis[40005];structedge{ intv,nx;}e[40005];intcnt,hd[40005];vo......
  • Java面试八股之详细阐述Spring的DI和IOC
    详细阐述Spring的DI和IOCSpring框架的两大核心特性之一就是控制反转(InversionofControl,IoC),另一个密切相关的是依赖注入(DependencyInjection,DI)。这两个概念是Spring实现松耦合、可测试和可管理软件组件的关键机制。控制反转(InversionofControl,IoC)概念:IoC是一种设......
  • centos7 安装指定版本的chrome + chromedriver
    谷歌浏览器历史版本相关地址:https://www.chromedownloads.net/chrome64win/ 驱动下载地址:https://registry.npmmirror.com/binary.html?path=chromedriver   上传下载好的chrome和chromediriver到centos服务器中解压后上传  安装chromeyumlocalinstall-y......
  • 【Verilog入门】常见的可用于仿真不能综合成硬件的语句及其原因
    在Verilog设计中,不可综合的语句和结构主要是因为它们无法直接映射到实际的硬件实现。以下是详细的解释和每种不可综合语句或结构背后的原因:1.延迟控制语句(#)原因:延迟控制语句用于仿真环境中引入时间延迟,但在实际硬件中没有直接对应的实现。硬件电路的操作是由时钟边沿......
  • SpringBoot升级到3.3.2版本,JDK升级到17,引入Mybatis-plus后启动报错:Property 'sqlSessi
    【问题描述】2024-07-23T15:16:07.174+08:00WARN2604---[questionnaire][main]ConfigServletWebServerApplicationContext:Exceptionencounteredduringcontextinitialization-cancellingrefreshattempt:org.springframework.beans.factory.UnsatisfiedDependen......