总览
HuggingFace 的这篇文章总结了一系列节约显存的方法,非常全面。
训练时显存占用的组成:
- 模型参数
- 优化器状态
- 输入张量和其他临时张量
- 激活值
- 梯度
- 通信缓冲
“激活值” 可能有点难理解。这是指像是 dropout 的 mask、LayerNorm 的 \(\mu\ \sigma^2\) 等,不是梯度但参加到梯度计算的张量。
除了用混合精度等方法降低整体显存占用,从 降低显存占用峰值 入手也是有效的。
融合 backward pass 和 optimizer step
通常的训练过程:计算 loss、反向传播、使用优化器 然后 清除梯度。
loss = loss_fn(model(inputs, targets))
loss.backward()
optimizer.step()
optimizer.zero_grad()
这就意味着,我们一次性计算了所有梯度,然后一并应用优化器参数更新。
如果能边算梯度边更新参数,就不需要用大量空间去存储梯度数据了。这就是融合 backward pass 和 optimizer step 的原理,能够有效降低显存占用峰值。
对于 PyTorch Lightning,需要借助 fsdp_overlap_step_with_backward
处理优化器逻辑:
from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward
optimizers = [Optimizer([p], ...) for p in model.parameters()]
...
for inputs, targets in epoch:
loss = loss_fn(model(inputs), targets)
with fsdp_overlap_step_with_backward(optimizers, model):
loss.backward()
若需要原生 PyTorch 实现,可以借助 register_post_accumulate_grad_hook
:
optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}
def optimizer_hook(parameter) -> None:
optimizer_dict[parameter].step()
optimizer_dict[parameter].zero_grad()
# Register the hook onto every parameter
for p in model.parameters():
p.register_post_accumulate_grad_hook(optimizer_hook)
# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()
loss = model.forward(fake_image)
loss.sum().backward()
本节参考:
- https://lightning.ai/pages/community/tutorial/faster-pytorch-training-by-reducing-peak-memory/
- https://pytorch.org/tutorials/intermediate/optimizer_step_in_backward_tutorial.html
优化器的选择
AdamW 优化器最为常用,调参简单效果好。要说缺点,就是每个参数都需要额外 8 字节的显存。
Adafactor 优化器改变 Adam 的动量思路,将空间占用降低到了 4 字节。但实际使用中发现 Adafactor 可能会导致训练不稳定。
Bitsandbytes 库提供了一系列 8-bit 优化器。其实现的 AdamW8bit 只需占用 2 字节空间。
这个 issue 是包含各种优化器的 benchmark。可以看出,各优化器的训练损失都差不多。这么说,大胆使用 AdamW8bit 节省显存是个不错的主意。
对于参数少、激活多的网络(例如卷积网络),8-bit 优化器的效果不是很明显。
Bitsandbytes 库推荐在使用 8-bit 优化器训练 NLP 模型时,将 embedding 层换为
bitsandbytes.nn.StableEmbedding
以保证训练稳定性。对于其他不稳定的参数,也可以使用 这个文档 提到的方法对那些参数单独使用 32-bit 优化器。这个知乎问题下 提到 8-bit 优化器可能会让模型容易过拟合。注意一下。
PyTorch Lightning 对 Bitsandbytes 库有支持,可以自动替换用上 Bitsandbytes 的 8-bit 线性层。具体可看官方文档。
关闭优化器的 foreach
PyTorch 的优化器默认启用了一个叫 foreach 的 trick,能加快训练。但随之而来的是额外的优化器中间变量占用,会导致峰值显存占用变高。若要关闭 foreach,在定义优化器时传入参数 foreach=False
即可。
本节参考:
标签:显存,loss,optimizer,训练,step,backward,优化 From: https://www.cnblogs.com/chirp/p/18143829