首页 > 其他分享 >PyTorch 1.0 中文文档:torch.utils.checkpoint

PyTorch 1.0 中文文档:torch.utils.checkpoint

时间:2023-05-05 11:32:15浏览次数:53  
标签:function 1.0 utils torch RNG checkpoint 模型 输入


译者: belonHan

注意

checkpointing的实现方法是在向后传播期间重新运行已被checkpint的前向传播段。 所以会导致像RNG这类(模型)的持久化的状态比实际更超前。默认情况下,checkpoint包含了使用RNG状态的逻辑(例如通过dropout),与non-checkpointed传递相比,checkpointed具有更确定的输出。RNG状态的存储逻辑可能会导致一定的性能损失。如果不需要确定的输出,设置全局标志(global flag) torch.utils.checkpoint.preserve_rng_state=False 忽略RNG状态在checkpoint时的存取。

torch.utils.checkpoint.checkpoint(function, *args)

checkpoint模型或模型的一部分

checkpoint通过计算换内存空间来工作。与向后传播中存储整个计算图的所有中间激活不同的是,checkpoint不会保存中间激活部分,而是在反向传递中重新计算它们。它被应用于模型的任何部分。

具体来说,在正向传播中,function将以torch.no_grad()方式运行 ,即不存储中间激活,但保存输入元组和 function的参数。在向后传播中,保存的输入变量以及 function会被取回,并且function在正向传播中被重新计算.现在跟踪中间激活,然后使用这些激活值来计算梯度。

Warning
警告

Checkpointing 在 torch.autograd.grad()中不起作用, 仅作用于 torch.autograd.backward().

警告

如果function在向后执行和前向执行不同,例如,由于某个全局变量,checkpoint版本将会不同,并且无法被检测到。

参数:

  • function - 描述在模型的正向传递或模型的一部分中运行的内容。它也应该知道如何处理作为元组传递的输入。例如,在LSTM中,如果用户通过 ,应正确使用第一个输入作为第二个输入(activation, hidden)functionactivationhidden
  • args – 包含输入的元组function

阅读全文/改进本文


标签:function,1.0,utils,torch,RNG,checkpoint,模型,输入
From: https://blog.51cto.com/wizardforcel/6245148

相关文章

  • PyTorch 1.0 中文文档:torch.distributed
    译者:univeryinli后端torch.distributed支持三个后端,每个后端具有不同的功能。下表显示哪些功能可用于CPU/CUDA张量。仅当用于构建PyTorch的实现支持时,MPI才支持CUDA。后端gloompinccl设备CPUGPUCPU————发送✓✘✓接收✓✘✓广播✓✓✓all_reduce✓✓✓reduce✓✘✓all_gather......
  • Pytorch-模型的保存/复用/迁移
    模型的保存与复用模型定义和参数打印#定义模型结构classLenNet(nn.Module):def__init__(self):super(LenNet,self).__init__()self.conv=nn.Sequential(#[batch,1,28,28]nn.Conv2d(1,8,5,2),#[batch,1,28,28]......
  • 无CUDA安装PyTorch
    1.官网选择2.加国内镜像快速下载pip3installtorchtorchvisiontorchaudio-ihttps://pypi.tuna.tsinghua.edu.cn/simple3.验证是否安装成功importtorchprint(torch.__version__)......
  • 超越 PyTorch 和 TensorFlow,这个国产框架有点东西
    By超神经内容概要:都已经有这么多深度学习框架了,为什么还要搞个OneFlow?在机器学习领域,袁进辉看的比90%的人都长远。 关键词:开源  深度学习框架  OneFlow在深度学习领域,PyTorch、TensorFlow等主流框架,毫无疑问占据绝大部分市场份额,就连百度这样级别的公司,也是花费了大量......
  • windows 配置 cuda pytorch
    1.进入 https://pytorch.org,依次选择 PyTorchBuild->YourOS->Package->Language->ComputePlatform,然后会生成安装命令或下载链接,执行或下载安装即可如果没有GPU,ComputePlatform选CPU即可  对于CUDA版本,可以执行cmd命令查看本地显卡支持的版本:nvidia-smi......
  • 【2023 · CANN训练营第一季】昇腾AI入门Pytorch
    昇腾AI全栈架构华为AI全栈全场景解决方案为4层,分别为芯片层、芯片使能层、AI框架层和应用使能层。芯片基于统一、可扩展架构的系列化AIIP和芯片,为上层加速提供硬件基础。芯片产品:昇腾310和昇腾910的独立芯片,Nano-Tiny-Lite的非独立芯片。Ascend层,一切集成电路的核心,主要作用......
  • pytorch模型降低计算成本和计算量
    下面是如何使用PyTorch降低计算成本和计算量的一些方法:压缩模型:使用模型压缩技术,如剪枝、量化和哈希等方法,来减小模型的大小和复杂度,从而降低计算量和运行成本。分布式训练:使用多台机器进行分布式训练,可以将模型训练时间大大缩短,提高训练效率,同时还可以降低成本。硬件加......
  • Python数据库连接池DBUtils
    DBUtils是Python的一个用于实现数据库连接池的模块。安装pip3instal1dbutilspip3instal1pymysql 此连接池有两种连接模式:模式一:为每个线程创建一个连接,线程即使调用了close方法,也不会关闭,只是把连接重新放到连接池,供自己线程再次使用。当线程终止时,连接自动关闭。......
  • 【pytorch】为什么 ToTensor 后紧接 Normalize 操作?
    学习pytorch的transforms一节中产生疑问:ToTensor操作中图像数据满足[0,255]条件会进行线性归一化,映射到[0,1]。在ToTensor操作后一般紧接着Nomalize操作,又进行了一次标准差归一化。既然已经归一化了一次,为什么还要再来一次?以下是我在网络上找到的一些答案:数据如果......
  • 【pytorch】土堆pytorch教程学习(四)Transforms 的使用
    transforms在工具包torchvision下,用来对图像进行预处理:数据中心化、数据标准化、缩放、裁剪、旋转、翻转、填充、噪声添加、灰度变换、线性变换、仿射变换、亮度/饱和度/对比度变换等。transforms本质就是一个python文件,相当于一个工具箱,里面包含诸如Resize、ToTensor、Nor......