首页 > 其他分享 >训练时的显存优化

训练时的显存优化

时间:2024-04-18 16:44:05浏览次数:26  
标签:显存 loss optimizer 训练 step backward 优化

总览

HuggingFace 的这篇文章总结了一系列节约显存的方法,非常全面。

训练时显存占用的组成:

  • 模型参数
  • 优化器状态
  • 输入张量和其他临时张量
  • 激活值
  • 梯度
  • 通信缓冲

“激活值” 可能有点难理解。这是指像是 dropout 的 mask、LayerNorm 的 \(\mu\ \sigma^2\) 等,不是梯度但参加到梯度计算的张量。

除了用混合精度等方法降低整体显存占用,从 降低显存占用峰值 入手也是有效的。

融合 backward pass 和 optimizer step

通常的训练过程:计算 loss、反向传播、使用优化器 然后 清除梯度。

loss = loss_fn(model(inputs, targets))
loss.backward()
optimizer.step()
optimizer.zero_grad()

这就意味着,我们一次性计算了所有梯度,然后一并应用优化器参数更新。

如果能边算梯度边更新参数,就不需要用大量空间去存储梯度数据了。这就是融合 backward pass 和 optimizer step 的原理,能够有效降低显存占用峰值。

对于 PyTorch Lightning,需要借助 fsdp_overlap_step_with_backward 处理优化器逻辑:

from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward
optimizers = [Optimizer([p], ...) for p in model.parameters()]
...
for inputs, targets in epoch:
    loss = loss_fn(model(inputs), targets)
    with fsdp_overlap_step_with_backward(optimizers, model):
        loss.backward()

若需要原生 PyTorch 实现,可以借助 register_post_accumulate_grad_hook

optimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}

def optimizer_hook(parameter) -> None:
  optimizer_dict[parameter].step()
  optimizer_dict[parameter].zero_grad()

# Register the hook onto every parameter
for p in model.parameters():
   p.register_post_accumulate_grad_hook(optimizer_hook)

# Now remember our previous ``train()`` function? Since the optimizer has been
# fused into the backward, we can remove the optimizer step and zero_grad calls.
def train(model):
  fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()

  loss = model.forward(fake_image)
  loss.sum().backward()

本节参考:

优化器的选择

AdamW 优化器最为常用,调参简单效果好。要说缺点,就是每个参数都需要额外 8 字节的显存。

Adafactor 优化器改变 Adam 的动量思路,将空间占用降低到了 4 字节。但实际使用中发现 Adafactor 可能会导致训练不稳定。

Bitsandbytes 库提供了一系列 8-bit 优化器。其实现的 AdamW8bit 只需占用 2 字节空间。

这个 issue 是包含各种优化器的 benchmark。可以看出,各优化器的训练损失都差不多。这么说,大胆使用 AdamW8bit 节省显存是个不错的主意。

对于参数少、激活多的网络(例如卷积网络),8-bit 优化器的效果不是很明显。

Bitsandbytes 库推荐在使用 8-bit 优化器训练 NLP 模型时,将 embedding 层换为 bitsandbytes.nn.StableEmbedding 以保证训练稳定性。对于其他不稳定的参数,也可以使用 这个文档 提到的方法对那些参数单独使用 32-bit 优化器。

这个知乎问题下 提到 8-bit 优化器可能会让模型容易过拟合。注意一下。

PyTorch Lightning 对 Bitsandbytes 库有支持,可以自动替换用上 Bitsandbytes 的 8-bit 线性层。具体可看官方文档

关闭优化器的 foreach

PyTorch 的优化器默认启用了一个叫 foreach 的 trick,能加快训练。但随之而来的是额外的优化器中间变量占用,会导致峰值显存占用变高。若要关闭 foreach,在定义优化器时传入参数 foreach=False 即可。

本节参考:

标签:显存,loss,optimizer,训练,step,backward,优化
From: https://www.cnblogs.com/chirp/p/18143829

相关文章

  • 京东内部研效架构师训练营,首次对外公开课,不可错过的研效之旅!
    五月繁花似锦,让我们带你走进京东,开启研效实战之旅! 四大单位联合发起本次活动由“全国云计算技术行业产教融合共同体”发起,联合工业和信息化部电子第五研究所、E³CI软件研发效能度量工作委员会、京东云共同主办,重磅推出“卓越研效架构师”研习营,邀请30名企业研发核心管理者......
  • 30天【代码随想录算法训练营34期】第七章 回溯算法part06 (● 332.重新安排行程 ● 51
    332.重新安排行程木有看懂,没视频所以也没看懂51.N皇后自己写出来还是有难度的classSolution:defsolveNQueens(self,n:int)->List[List[str]]:result=[]#存储最终结果的二维字符串数组chessboard=['.'*nfor_inrange(n)]#初始化......
  • Unity性能优化——资源优化(一)
    实际项目中发现的许多问题都是源自无心之过:临时的“测试”更改和疲惫不堪的开发人员的误点击可能会暗地里添加性能不良的资源或更改现有资源的导入设置。对于任何大规模的项目,最好是将防止人为错误作为第一道防线。编写一小段代码来禁止将4K未压缩纹理添加到项目中,是相对简单的......
  • 解决加载GPT2(Tensorflow预训练模型)的Linear权重到PyTorch的Linear权重 形状不匹配(互为
    解决报错内容:RuntimeError:Error(s)inloadingstate_dictforPyTorchBasedGPT2:sizemismatchfortransformer.h.0.attn.c_attn.weight:copyingaparamwithshapetorch.Size([768,2304])fromcheckpoint,theshapeincurrentmodelistorch.Size([2304,768]).........
  • CIFAR10の训练
    CIFAR10の训练一,CIFAR10CIFAR-10是一个更接近普适物体的彩色图像数据集。CIFAR-10是由Hinton的学生AlexKrizhevsky和IlyaSutskever整理的一个用于识别普适物体的小型数据集。一共包含10个类别的RGB彩色图片:飞机(airplane)、汽车(automobile)、鸟类(bird)、猫(cat)、鹿......
  • Ansible:更快点,执行过程分析、异步、效率优化【转】
    Ansible你快点:Ansible执行过程分析、异步、效率优化Ansible虽然方便,但有个"为人诟病"的问题:任务执行速度太慢了,在有大量任务、大量循环任务时,其速度之慢真的是会让人等到崩溃的。Ansible官方给了一些优化选项供用户选择,还可以去网上寻找优化Ansible相关的插件。但在调优Ansible......
  • 29天【代码随想录算法训练营34期】第七章 回溯算法part05 (491.递增子序列 * 46.全排
    491.递增子序列如果在最前面加一个uset=set(),这个就是给这一层一个usedset,很好用,不错classSolution:deffindSubsequences(self,nums:List[int])->List[List[int]]:result=[]self.backtracking(nums,[],result,0)returnresult......
  • HarmonyOS 优化布局性能
    背景介绍 用户界面(UI)布局是应用程序中至关重要的部分,它不仅影响应用的外观和用户体验,还直接影响应用的性能。不合理的布局可能会导致过度的布局计算和界面嵌套,从而增加渲染和计算的开销,导致性能下降。 常用布局方式 HarmonyOS的ArkUI框架提供了多种布局方式,包括线性布局......
  • HarmonyOS 性能优化
    如何合理使用动效来获得更好的性能组件转场动画使用transition:推荐使用转场动画(transition)而不是组件动画(animateTo),因为transition只需要在条件改变时更新一次,而animateTo需要在动画前后做两次属性更新,导致性能开销更大。反例:通过改变透明度属性并使用animateTo来......
  • 试用阿里云GPU服务器进行深度学习模型训练
    试用阿里云GPU服务器进行深度学习模型训练最近在用PyTorch时发现在本地训练模型速度一言难尽,然后发现阿里云可以白嫖gpu服务器,只要没有申请过PAI-DSW资源的新老用户都可以申请5000CU*H的免费额度,三个月内有效。阿里云免费试用活动页面一、申请试用并创建实例点击试用,完成注......