首页 > 其他分享 >Gradient checkpointing 核心流程详细讲解

Gradient checkpointing 核心流程详细讲解

时间:2024-12-05 19:29:30浏览次数:13  
标签:Gradient 流程 传播 反向 内存 计算 讲解 checkpointing 节点

文章目录


0. 概述

Gradient checkpointing 的核心思想是不保存所有层的激活值,而是只保存一部分关键点的激活值。当需要计算某个特定层的梯度时,如果该层的激活值没有被直接保存,那么可以通过重新计算从最近的关键点到该层的前向传播来获得这些激活值。这样做的代价是增加了计算量,因为部分前向传播过程需要重复执行,但可以显著降低内存使用。

下图是一个具有n层的简单前馈神经网络计算图:

其中:

  • f f f 表示前向传播的激活计算节点
  • b b b 表示反向传播的梯度计算节点

1. 简单反向传播

1.1 整体流程

简单反向传播(Vanilla backpropagation)中:

  1. 为了在反向传播阶段能够高效地计算梯度,前向传播时所有 f f f 节点都保存在内存中;
  2. 只有当反向传播进行到足以计算出 f f f 节点的所有依赖项或子节点时,才能将其从内存中删除。

这种方式意味着简单的反向传播所需的内存随着神经网络层数n成线性增长。

执行顺序和使用内存如下:
在这里插入图片描述

其中:

  • 紫色的圆圈表示保留在内存的计算节点;
  • 箭头的指向表示节点的依赖关系,例如:A --> B 表示 B 节点的计算依赖 A 的数据。

1.2 详细说明

为了方便说明清晰,我们将上面流程拆开来看:

(1)前向传播


前向传播的所有节点全部保存在内存中,如图中的紫色节点所示。

(2)反向传播

注意上图中的最后流程,节点1和节点2被用于计算出节点3的梯度后,不再会被使用到,因此在这个流程中可以释放内存。

后续流程类似,如下图所示:


在计算出最后节点的梯度之后,所有的节点都不再被依赖,因此可以全部释放。

1.3 总结

尽管简单反向传播在概念上简单且有效,在计算方面是最佳的(只计算每个节点一次),但它可能不是最高效的实现方式,尤其是在处理大规模或深度网络时,因为其内存需求较高。

2. 初步优化版本

2.1 整体流程

在这个版本中,在每次需要时,重新计算前向传播中的每个节点。

执行顺序和使用内存如下:

在这里插入图片描述

2.2 详细说明

为了方便说明清晰,我们还是将上面流程拆开来看:

(1)前向传播


注意上图中最后流程的节点2,因为节点2可以被节点1重新计算出来,因此先可以释放节点2。

后续流程类似,如下图所示:

(2)反向传播

注意上图中最后流程,节点2和节点5同上述一样,可以释放内存。下一步应该计算节点3的梯度,而节点3依赖于节点1和节点4的数据,此时缺失节点1的数据。

因此下面需要重新计算节点1的数据。

后续流程类似,如下图所示:

在这里插入图片描述

2.3 总结

使用这种策略,每次前向计算都把当前不需要的节点内存释放了,因此内存使用是最少的。但是计算梯度时每次都需要重新开始计算,计算效率为 O ( n 2 ) O(n^2) O(n2) ,n 为网络层数,计算速度要慢得多,这使得该方法不太适用于大型的深度学习计算。

3. Checkpointed反向传播

3.1 整体流程

在这个版本中,为了在内存和计算之间取得平衡,制定一个允许重新计算节点的策略,但不要太频繁。例如选择的 checkpoint 如下图:

这些 checkpoint 节点在正向传播后保留在内存中,而其余节点最多重新计算一次。重新计算后,非 checkpoint 节点将保留在内存中,直到不再需要它们为止。

执行顺序和使用内存如下:
在这里插入图片描述

3.2 详细说明

为了方便说明清晰,我们还是将上面流程拆开来看:

(1)前向传播


注意,我们选择的 checkpoint 节点在前向传播中不释放,一直保留在内存中。

(2)反向传播


注意上图中最后流程,节点2和节点5同上述一样,可以释放内存。下一步应该计算节点3的梯度,而节点3依赖于节点1和节点4的数据,此时缺失节点1的数据。

由于我们保存了 checkpoint 节点,因此此时计算节点1的数据就不用重头开始计算了,可以借助 checkpoint 节点直接计算。

后续流程类似,如下图所示:

checkpoint 节点在反向传播时,如果后续不再需要时同样需要释放。

3.3 总结

该策略在内存和计算之间取得平衡,对于示例中的简单前馈网络,最佳选择是将 n \sqrt{n} n ​ 个节点标记为 checkpoint。由于每个节点最多重新计算一次,所需的额外计算相当于一次前向传播。

4. 补充:内存分配算法

在2016年陈天奇团队在论文《Training Deep Nets with Sublinear Memory Cost》提出了亚线性内存优化相关的 gradient checkpointing 技术,主要思路在上面章节已经阐述了。

论文原文:https://arxiv.org/pdf/1604.06174

在这篇论文中,还提到了计算图中的内存分配算法:

内存分配算法目的是优化内存的动态分配,降低运行过程中实际占用的内存总量。(ps:乍一看还以为是 go 语言垃圾回收的三色标记法,实际上完全不一样)

图中出现的图示说明如下所示:

下面具体对上面的算法展开分析:

(1)计算图的初始化


图中,B 被 C 和 F 依赖,因此引用计数为2。其中 G 没有被任何节点所依赖,但为最后节点,这里初始化设置为1。

(2)分配内存(按照拓扑排序的顺序依次进行)

第一步,给 B 分配内存,标识为红色;

第二步,给 C 分配内存,标识为绿色,同时 B 引用计数减1;

第三步,给 D 分配内存,标识为蓝色,同时 B 引用计数减1,此时为0,则 B 的红色标识的内存回收到空闲内存集中;

第四步,给 E 分配内存,由于空闲内存集中存在已有的红色标识的内存块,可以直接给 E,因此 E 内存也标识为红色,代表内存复用共享。此时,C 的引用计数减1后为0,因此 C 的绿色标识的内存回收到空闲内存集中;

(3)原地操作(Inplace operation)

最后一步需要给 G 分配内存。

如果按照之前的逻辑,那么空闲内存集中存在已有的绿标识的内存块,则应该给 G 复用共享绿色的内存块,然后把 E、F 回收到空闲内存集中,那么结果应该如下图:

当然这么做也没有什么问题,但论文中用到了原地操作的方法,即将输入的内存直接用于保存输出的结果。

具体来讲,将 G 的结果直接覆盖 E 的内存,不再额外引入其他内存块,因此这里 G 的内存块也标识为红色。同时因为计算到最后节点了,F 的内存块也不再回收到空闲内存集中了,可以直接释放。

(4)内存分配计划

最终根据上述流程的内存分配,同样的颜色表示可以共享内存,得到如下的分配计划:

参考

[1] https://github.com/cybertronai/gradient-checkpointing?tab=readme-ov-file

[2] https://arxiv.org/pdf/1604.06174


欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

在这里插入图片描述

标签:Gradient,流程,传播,反向,内存,计算,讲解,checkpointing,节点
From: https://blog.csdn.net/qq_36803941/article/details/144240508

相关文章

  • 基于微信小程序的医院挂号就医系统的设计与实现(源码+SQL脚本+LW+部署讲解等)
    专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。主要内容:免费功能设计、开题报告、任务书、中......
  • CMUX-CMUX协议讲解
    目录  1.CMUX协议  2.CMUX协议帧结构  3.示例1.CMUX协议CMUX(ConnectionMultiplexing),是一种串口多路复用协议,其功能主要在一个真实的物理通道上虚拟多个并行的逻辑通信通道的能力,一般应用于TE(TerminalEquipment)与MS(MobileStation)之间,TE相当于智......
  • 基于大数据的滴滴出行数据分析与可视化系统(源码+vue+可视化大屏展示+爬虫分析+讲解等
    收藏关注不迷路!!......
  • XSS漏洞详细讲解(初学者必看)
    前言:要深入理解XSS漏洞,掌握Web应用的基本原理非常关键。XSS攻击本质上是通过注入恶意的JavaScript代码到Web页面中,从而使得攻击者可以在用户浏览页面时执行恶意脚本。因此,理解Web应用如何处理输入、渲染、执行脚本等方面的基本原理非常重要。一、Web应用的基本原理:1.HTM......
  • 哈希表(【通俗易懂】知识点讲解,可速通,小白友好)
    一、哈希表的目的哈希表是用在查找问题中的。我们知道,一条数据包含了关键字和其他信息,所以一般查找问题的流程是:根据某条数据的关键字(key),在一个数据结构中(可能是线性表,也可能其他存储数据的结构),查找这条数据全部的内容。哈希表的目的是,只要知道了要查找数据的关键字,那么就可......
  • 基于微信小程序的校园二手书籍交易平台的设计与实现(源码+LW+讲解和调试)
     目录:博主介绍:  完整视频演示:系统技术介绍:后端Java介绍前端框架Vue介绍具体功能截图:部分代码参考:  Mysql表设计参考:项目测试:项目论文:​为什么选择我:源码获取:博主介绍:  ......
  • 基于SpringBoot的智能旅游网站系统的设计与实现(源码+LW+讲解和调试)
     目录:博主介绍:  完整视频演示:系统技术介绍:后端Java介绍前端框架Vue介绍具体功能截图:部分代码参考:  Mysql表设计参考:项目测试:项目论文:​为什么选择我:源码获取:博主介绍:  ......
  • 基于SpringBoot的游戏分享网站的设计与实现(源码+LW+讲解和调试)
    目录:博主介绍:  完整视频演示:系统技术介绍:后端Java介绍前端框架Vue介绍具体功能截图:部分代码参考:  Mysql表设计参考:项目测试:项目论文:​为什么选择我:源码获取:博主介绍:  ......
  • rcu的实例、注意事项及原理讲解
    一、背景在之前的内核模块里获取当前进程和父进程的cmdline的方法及注意事项,涉及父子进程管理,和rcu的初步介绍-CSDN博客里我们讲到了如何在rcu锁保护的情况下获取一个进程的父进程的pid和comm,另外也贴了一张浓缩了rcu相关概念精华的整理的思维导图。这篇博客里,我们先不涉及rcu......
  • 用比喻的方法大白话讲解SerDes(串行/并行转换器)
    想象你正在参加一个超级忙碌的派对,场面很热闹,每个人都在不停地讲话,但大家使用的是不同的语言。如果你是一个外语专家,你的任务就是把所有这些语言转换成你自己听得懂的语言,然后再把它们传递给你周围的人,让大家都能理解。这时,你就像一个SerDes(串行/并行转换器)!比喻:语言翻译器......