首页 > 其他分享 >PyTorch保存模型断点以及加载断点继续训练

PyTorch保存模型断点以及加载断点继续训练

时间:2023-04-27 15:01:21浏览次数:36  
标签:load ckpt state epoch PyTorch dict 断点 加载

 

 

 

在训练神经网络时,用到的数据量可能很大,训练周期较长,如果半途中断了训练,下次从头训练就会很费时间,这时我们就想断点续训。

一、神经网络模型的保存,基本两种方式:
1. 保存完整模型model, torch.save(model, save_path) 

2. 只保存模型的参数, torch.save(model.state_dict(), save_path) ,多卡训练的话,在保存参数时,使用 model.module.state_dict( ) 。

二、保存模型的断点checkpoint

断点dictionary中一般保存训练的网络的权重参数、优化器的状态、学习率 lr_scheduler 的状态以及epoch 。

checkpoint = {'parameter': model.module.state_dict(),
              'optimizer': optimizer.state_dict(),
              'scheduler': scheduler.state_dict(),
              'epoch': epoch}
 torch.save(checkpoint, './models/checkpoint/ckpt_{}.pth'.format(epoch+1))

三、加载断点继续训练

if resume: # True
load_ckpt = torch.load(ckpt_dir, map_location=device)
load_weights_dict = {k: v for k, v in load_ckpt['parameter'].items()
                                      if model.state_dict()[k].numel() == v.numel()}  # 简单验证
model.load_state_dict(load_weights_dict, strict=False) 

optimizer.load_state_dict(load_ckpt['optimizer']) scheduler.load_state_dict(load_ckpt['scheduler'])
start_epoch = load_ckpt['epoch']+1 iter_epochs = range(start_epoch, args.epochs)

 

标签:load,ckpt,state,epoch,PyTorch,dict,断点,加载
From: https://www.cnblogs.com/booturbo/p/17358917.html

相关文章

  • 利用pytorch深度学习框架验证骰子的合格性
    利用pytorch深度学习框架验证骰子的合格性骰子生产的合格性可以用概率来表达,比如每个面出现的概率大概都是1/6。importtorchfromd2limporttorchasd2lfromtorch.distributionsimportmultinomial#多次扔骰子出现每个面的概率服从多项式分布fair_probs=torch.ones(......
  • 一个有趣的图片加载效果
    日常在业务中会经常使用到图片,而涉及到一些大图的加载等待的时间较长,一般为了用户更好的体验,会使用一些不同的图片加载效果,比如以下几种情况:骨架屏:在页面上用占位框架代替图片,展示出图片的大致结构和区域,给用户一种“正在加载”的视觉体验。进度条:用进度条的形式展示图片的加......
  • element-ui el-dialog中引用组件,为何组件只加载一次
    最近开发项目,页面中引入组件,2次展示,组件中生命周期都不调取,导致网组件中传的值不更新;<el-dialogv-dialogDragtitle="巡检记录":visible.sync="patrolItemVisible":show-close="true":close-on-press-escape="true":close-on-click-modal="true":appen......
  • 第二十三章:动态加载脚本和样式
    学习要点:1.元素位置2.动态脚本3.动态样式本章主要讲解上一章剩余的获取位置的DOM方法、动态加载脚本和样式。一.元素位置上一章已经通过几组属性可以获取元素所需的位置,那么这节课补充一个DOM的方法:getBoundingClientRect()。这个方法返回一个矩形对象,包含四个属性:left、top、ri......
  • 超大文件上传和断点续传的控件
    ​ PHP用超级全局变量数组$_FILES来记录文件上传相关信息的。1.file_uploads=on/off 是否允许通过http方式上传文件2.max_execution_time=30 允许脚本最大执行时间,超过这个时间就会报错3.memory_limit=50M 设置脚本可以分配的最大内存量,防止失控脚本占用过多内存,此......
  • vue-router4 配置懒加载 页面加载时展示loading
     懒加载写法{path:"/",name:"index",component:()=>import("../views/Home.vue"),}创建Loading组件并引入到顶层组件中使用store控制loading组件是否展示包装懒加载写法constlazyLoad=(componentPath)=>{//注意:componentPath不......
  • 超大文件上传和断点续传的组件
    ​ 以ASP.NETCoreWebAPI 作后端 API ,用 Vue 构建前端页面,用 Axios 从前端访问后端 API,包括文件的上传和下载。 准备文件上传的API #region 文件上传  可以带参数        [HttpPost("upload")]        publicJsonResultuploadProject(I......
  • cesium-1-加载影像数据和影像数据基础知识
    1、影像数据的图层类有哪些viewer-->imageryLayers(ImageryLayerCollection类型)-->ImageryLayer类型-->ImageryProvider抽象类viewer下有ImageryLayerCollection类型的imageryLayers用来存放影像数据(可多个),只能是ImageryLayer类型变量(包含影像数据但除了影像数据之外还有......
  • 不同语言加载不同字号,设置到资源文件中,进行引用
    在资源文件夹创建一个类在App.xaml文件中引用这个类的空间命名,并把这个类添加到资源在页面中应用在使用其他语言时,开启新的子线程依然会使用区域语言.net4.5后使用可以一次性解决varculture=newCultureInfo("en-US");    CultureInfo.DefaultThreadCurrent......
  • 类加载器
    类与类加载器任意一个类,都由加载它的类加载器和这个类本身一同确立其在Java虚拟机中的唯一性,每一个类加载器,都有一个独立的类名称空间。因此,比较两个类是否“相等”,只有在这两个类是由同一个类加载器加载的前提下才有意义,否则,即使这两个类来源于同一个Class文件,被同一个虚拟......