文章目录
0. 概述
Gradient checkpointing 的核心思想是不保存所有层的激活值,而是只保存一部分关键点的激活值。当需要计算某个特定层的梯度时,如果该层的激活值没有被直接保存,那么可以通过重新计算从最近的关键点到该层的前向传播来获得这些激活值。这样做的代价是增加了计算量,因为部分前向传播过程需要重复执行,但可以显著降低内存使用。
下图是一个具有n层的简单前馈神经网络计算图:
其中:
- f f f 表示前向传播的激活计算节点
- b b b 表示反向传播的梯度计算节点
1. 简单反向传播
1.1 整体流程
简单反向传播(Vanilla backpropagation)中:
- 为了在反向传播阶段能够高效地计算梯度,前向传播时所有 f f f 节点都保存在内存中;
- 只有当反向传播进行到足以计算出 f f f 节点的所有依赖项或子节点时,才能将其从内存中删除。
这种方式意味着简单的反向传播所需的内存随着神经网络层数n成线性增长。
执行顺序和使用内存如下:
其中:
- 紫色的圆圈表示保留在内存的计算节点;
- 箭头的指向表示节点的依赖关系,例如:A --> B 表示 B 节点的计算依赖 A 的数据。
1.2 详细说明
为了方便说明清晰,我们将上面流程拆开来看:
(1)前向传播
前向传播的所有节点全部保存在内存中,如图中的紫色节点所示。
(2)反向传播
注意上图中的最后流程,节点1和节点2被用于计算出节点3的梯度后,不再会被使用到,因此在这个流程中可以释放内存。
后续流程类似,如下图所示:
在计算出最后节点的梯度之后,所有的节点都不再被依赖,因此可以全部释放。
1.3 总结
尽管简单反向传播在概念上简单且有效,在计算方面是最佳的(只计算每个节点一次),但它可能不是最高效的实现方式,尤其是在处理大规模或深度网络时,因为其内存需求较高。
2. 初步优化版本
2.1 整体流程
在这个版本中,在每次需要时,重新计算前向传播中的每个节点。
执行顺序和使用内存如下:
2.2 详细说明
为了方便说明清晰,我们还是将上面流程拆开来看:
(1)前向传播
注意上图中最后流程的节点2,因为节点2可以被节点1重新计算出来,因此先可以释放节点2。
后续流程类似,如下图所示:
(2)反向传播
注意上图中最后流程,节点2和节点5同上述一样,可以释放内存。下一步应该计算节点3的梯度,而节点3依赖于节点1和节点4的数据,此时缺失节点1的数据。
因此下面需要重新计算节点1的数据。
后续流程类似,如下图所示:
2.3 总结
使用这种策略,每次前向计算都把当前不需要的节点内存释放了,因此内存使用是最少的。但是计算梯度时每次都需要重新开始计算,计算效率为 O ( n 2 ) O(n^2) O(n2) ,n 为网络层数,计算速度要慢得多,这使得该方法不太适用于大型的深度学习计算。
3. Checkpointed反向传播
3.1 整体流程
在这个版本中,为了在内存和计算之间取得平衡,制定一个允许重新计算节点的策略,但不要太频繁。例如选择的 checkpoint 如下图:
这些 checkpoint 节点在正向传播后保留在内存中,而其余节点最多重新计算一次。重新计算后,非 checkpoint 节点将保留在内存中,直到不再需要它们为止。
执行顺序和使用内存如下:
3.2 详细说明
为了方便说明清晰,我们还是将上面流程拆开来看:
(1)前向传播
注意,我们选择的 checkpoint 节点在前向传播中不释放,一直保留在内存中。
(2)反向传播
注意上图中最后流程,节点2和节点5同上述一样,可以释放内存。下一步应该计算节点3的梯度,而节点3依赖于节点1和节点4的数据,此时缺失节点1的数据。
由于我们保存了 checkpoint 节点,因此此时计算节点1的数据就不用重头开始计算了,可以借助 checkpoint 节点直接计算。
后续流程类似,如下图所示:
checkpoint 节点在反向传播时,如果后续不再需要时同样需要释放。
3.3 总结
该策略在内存和计算之间取得平衡,对于示例中的简单前馈网络,最佳选择是将 n \sqrt{n} n 个节点标记为 checkpoint。由于每个节点最多重新计算一次,所需的额外计算相当于一次前向传播。
4. 补充:内存分配算法
在2016年陈天奇团队在论文《Training Deep Nets with Sublinear Memory Cost》提出了亚线性内存优化相关的 gradient checkpointing 技术,主要思路在上面章节已经阐述了。
论文原文:https://arxiv.org/pdf/1604.06174
在这篇论文中,还提到了计算图中的内存分配算法:
内存分配算法目的是优化内存的动态分配,降低运行过程中实际占用的内存总量。(ps:乍一看还以为是 go 语言垃圾回收的三色标记法,实际上完全不一样)
图中出现的图示说明如下所示:
下面具体对上面的算法展开分析:
(1)计算图的初始化
图中,B 被 C 和 F 依赖,因此引用计数为2。其中 G 没有被任何节点所依赖,但为最后节点,这里初始化设置为1。
(2)分配内存(按照拓扑排序的顺序依次进行)
第一步,给 B 分配内存,标识为红色;
第二步,给 C 分配内存,标识为绿色,同时 B 引用计数减1;
第三步,给 D 分配内存,标识为蓝色,同时 B 引用计数减1,此时为0,则 B 的红色标识的内存回收到空闲内存集中;
第四步,给 E 分配内存,由于空闲内存集中存在已有的红色标识的内存块,可以直接给 E,因此 E 内存也标识为红色,代表内存复用共享。此时,C 的引用计数减1后为0,因此 C 的绿色标识的内存回收到空闲内存集中;
(3)原地操作(Inplace operation)
最后一步需要给 G 分配内存。
如果按照之前的逻辑,那么空闲内存集中存在已有的绿标识的内存块,则应该给 G 复用共享绿色的内存块,然后把 E、F 回收到空闲内存集中,那么结果应该如下图:
当然这么做也没有什么问题,但论文中用到了原地操作的方法,即将输入的内存直接用于保存输出的结果。
具体来讲,将 G 的结果直接覆盖 E 的内存,不再额外引入其他内存块,因此这里 G 的内存块也标识为红色。同时因为计算到最后节点了,F 的内存块也不再回收到空闲内存集中了,可以直接释放。
(4)内存分配计划
最终根据上述流程的内存分配,同样的颜色表示可以共享内存,得到如下的分配计划:
参考
[1] https://github.com/cybertronai/gradient-checkpointing?tab=readme-ov-file
[2] https://arxiv.org/pdf/1604.06174
欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;
欢迎关注知乎/CSDN:SmallerFL
也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤
标签:Gradient,流程,传播,反向,内存,计算,讲解,checkpointing,节点 From: https://blog.csdn.net/qq_36803941/article/details/144240508