首页 > 其他分享 >实操教程|PyTorch实现断点继续训练

实操教程|PyTorch实现断点继续训练

时间:2024-06-13 10:57:52浏览次数:8  
标签:optimizer state epoch checkpoint PyTorch dict 实操 lr 断点

作者丨HUST小菜鸡(已授权)

编辑丨极市平台

最近在尝试用CIFAR10训练分类问题的时候,由于数据集体量比较大,训练的过程中时间比较长,有时候想给停下来,但是停下来了之后就得重新训练,之前师兄让我们学习断点继续训练及继续训练的时候注意epoch的改变等,今天上午给大致整理了一下,不全面仅供参考

Epoch:  9 | train loss: 0.3517 | test accuracy: 0.7184 | train time: 14215.1018  sEpoch:  9 | train loss: 0.2471 | test accuracy: 0.7252 | train time: 14309.1216  sEpoch:  9 | train loss: 0.4335 | test accuracy: 0.7201 | train time: 14403.2398  sEpoch:  9 | train loss: 0.2186 | test accuracy: 0.7242 | train time: 14497.1921  sEpoch:  9 | train loss: 0.2127 | test accuracy: 0.7196 | train time: 14591.4974  sEpoch:  9 | train loss: 0.1624 | test accuracy: 0.7142 | train time: 14685.7034  sEpoch:  9 | train loss: 0.1795 | test accuracy: 0.7170 | train time: 14780.2831  s绝望!!!!!训练到了一定次数发现训练次数少了,或者中途断了又得重新开始训练

一、模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

torch.save主要参数:obj:对象 、f:输出路径

torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu

模型的保存的两种方法:

1、保存整个Module

torch.save(net, path)

2、保存模型参数

state_dict = net.state_dict()torch.save(state_dict , path)

二、模型的训练过程中保存

checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }

将网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复

在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。

checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }    if not os.path.isdir("./models/checkpoint"):        os.mkdir("./models/checkpoint")    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
通过上述的过程可以在训练过程自动在指定位置创建文件夹,并保存断点文件

图片

三、模型的断点继续训练

if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch

指出这里的是否继续训练,及训练的checkpoint的文件位置等可以通过argparse从命令行直接读取,也可以通过log文件直接加载,也可以自己在代码中进行修改。

四、重点在于epoch的恢复

start_epoch = -1

if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch

for epoch in  range(start_epoch + 1 ,EPOCH):    # print('EPOCH:',epoch)    for step, (b_img,b_label) in enumerate(train_loader):        train_output = model(b_img)        loss = loss_func(train_output,b_label)        # losses.append(loss)        optimizer.zero_grad()        loss.backward()        optimizer.step()

通过定义start_epoch变量来保证继续训练的时候epoch不会变化

图片

断点继续训练

一、初始化随机数种子

import torchimport randomimport numpy as np
def set_random_seed(seed = 10,deterministic=False,benchmark=False):    random.seed(seed)    np.random(seed)    torch.manual_seed(seed)    torch.cuda.manual_seed_all(seed)    if deterministic:        torch.backends.cudnn.deterministic = True    if benchmark:        torch.backends.cudnn.benchmark = True

关于torch.backends.cudnn.deterministic和torch.backends.cudnn.benchmark详见

Pytorch学习0.01:cudnn.benchmark= True的设置

https://www.cnblogs.com/captain-dl/p/11938864.html

pytorch---之cudnn.benchmark和cudnn.deterministic_人工智能_zxyhhjs2017的博客

https://blog.csdn.net/zxyhhjs2017/article/details/91348108

图片

benchmark用在输入尺寸一致,可以加速训练,deterministic用来固定内部随机性

二、多步长SGD继续训练

在简单的任务中,我们使用固定步长(也就是学习率LR)进行训练,但是如果学习率lr设置的过小的话,则会导致很难收敛,如果学习率很大的时候,就会导致在最小值附近,总会错过最小值,loss产生震荡,无法收敛。所以这要求我们要对于不同的训练阶段使用不同的学习率,一方面可以加快训练的过程,另一方面可以加快网络收敛。

采用多步长 torch.optim.lr_scheduler的多种步长设置方式来实现步长的控制,lr_scheduler的各种使用推荐参考如下教程:

【转载】 Pytorch中的学习率调整lr_scheduler,ReduceLROnPlateau

https://www.cnblogs.com/devilmaycry812839668/p/10630302.html

所以我们在保存网络中的训练的参数的过程中,还需要保存lr_scheduler的state_dict,然后断点继续训练的时候恢复

#这里我设置了不同的epoch对应不同的学习率衰减,在10->20->30,学习率依次衰减为原来的0.1,即一个数量级lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)optimizer = torch.optim.SGD(model.parameters(),lr=0.1)
for epoch in range(start_epoch+1,80):    optimizer.zero_grad()    optimizer.step()    lr_schedule.step()
    if epoch %10 ==0:        print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])
lr的变化过程如下:
epoch: 10learning rate: 0.1epoch: 20learning rate: 0.010000000000000002epoch: 30learning rate: 0.0010000000000000002epoch: 40learning rate: 0.00010000000000000003epoch: 50learning rate: 1.0000000000000004e-05epoch: 60learning rate: 1.0000000000000004e-06epoch: 70learning rate: 1.0000000000000004e-06

我们在保存的时候,也需要对lr_scheduler的state_dict进行保存,断点继续训练的时候也需要恢复lr_scheduler

#加载恢复if RESUME:    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch    lr_schedule.load_state_dict(checkpoint['lr_schedule'])#加载lr_scheduler

#保存for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()    lr_schedule.step()

    if epoch %10 ==0:        print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])        checkpoint = {            "net": model.state_dict(),            'optimizer': optimizer.state_dict(),            "epoch": epoch,            'lr_schedule': lr_schedule.state_dict()        }        if not os.path.isdir("./model_parameter/test"):            os.mkdir("./model_parameter/test")        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(epoch)))

三、保存最好的结果

每一个epoch中的每个step会有不同的结果,可以保存每一代最好的结果,用于后续的训练

第一次实验代码

RESUME = True
EPOCH = 40LR = 0.0005

model = cifar10_cnn.CIFAR10_CNN()
print(model)optimizer = torch.optim.Adam(model.parameters(),lr=LR)loss_func = nn.CrossEntropyLoss()
start_epoch = -1

if RESUME:    path_checkpoint = "./models/checkpoint/ckpt_best_1.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch


for epoch in  range(start_epoch + 1 ,EPOCH):    # print('EPOCH:',epoch)    for step, (b_img,b_label) in enumerate(train_loader):        train_output = model(b_img)        loss = loss_func(train_output,b_label)        # losses.append(loss)        optimizer.zero_grad()        loss.backward()        optimizer.step()
        if step % 100 == 0:            now = time.time()            print('EPOCH:',epoch,'| step :',step,'| loss :',loss.data.numpy(),'| train time: %.4f'%(now-start_time))
    checkpoint = {        "net": model.state_dict(),        'optimizer':optimizer.state_dict(),        "epoch": epoch    }    if not os.path.isdir("./models/checkpoint"):        os.mkdir("./models/checkpoint")    torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))

更新实验代码

optimizer = torch.optim.SGD(model.parameters(),lr=0.1)lr_schedule = torch.optim.lr_scheduler.MultiStepLR(optimizer,milestones=[10,20,30,40,50],gamma=0.1)start_epoch = 9# print(schedule)

if RESUME:    path_checkpoint = "./model_parameter/test/ckpt_best_50.pth"  # 断点路径    checkpoint = torch.load(path_checkpoint)  # 加载断点
    model.load_state_dict(checkpoint['net'])  # 加载模型可学习参数
    optimizer.load_state_dict(checkpoint['optimizer'])  # 加载优化器参数    start_epoch = checkpoint['epoch']  # 设置开始的epoch    lr_schedule.load_state_dict(checkpoint['lr_schedule'])
for epoch in range(start_epoch+1,80):
    optimizer.zero_grad()
    optimizer.step()    lr_schedule.step()

    if epoch %10 ==0:        print('epoch:',epoch)        print('learning rate:',optimizer.state_dict()['param_groups'][0]['lr'])        checkpoint = {            "net": model.state_dict(),            'optimizer': optimizer.state_dict(),            "epoch": epoch,            'lr_schedule': lr_schedule.state_dict()        }        if not os.path.isdir("./model_parameter/test"):            os.mkdir("./model_parameter/test")        torch.save(checkpoint, './model_parameter/test/ckpt_best_%s.pth' % (str(e


更多科研资源获取:

公·众·号「学长论文指导」回复关键字「156」

【一】上千篇CVPR、ICCV顶会论文
【二】动手学习深度学习、花书、西瓜书等AI必读书籍
【三】机器学习算法+深度学习神经网络基础教程
【四】OpenCV、Pytorch、YOLO等主流框架算法实战教程

标签:optimizer,state,epoch,checkpoint,PyTorch,dict,实操,lr,断点
From: https://blog.csdn.net/2401_83878212/article/details/139648323

相关文章

  • Python中用PyTorch机器学习神经网络分类预测银行客户流失模型|附代码数据
    阅读全文:http://tecdat.cn/?p=8522最近我们被客户要求撰写关于神经网络的研究报告,包括一些图形和统计输出。分类问题属于机器学习问题的类别,其中给定一组特征,任务是预测离散值。分类问题的一些常见示例是,预测肿瘤是否为癌症,或者学生是否可能通过考试在本文中,鉴于银行客户的某些......
  • 在Python中使用LSTM和PyTorch进行时间序列预测|附代码数据
    全文链接:http://tecdat.cn/?p=8145最近我们被客户要求撰写关于LSTM的研究报告,包括一些图形和统计输出。顾名思义,时间序列数据是一种随时间变化的数据类型。例如,24小时内的温度,一个月内各种产品的价格,一年中特定公司的股票价格诸如长期短期记忆网络(LSTM)之类的高级深度学习模型能......
  • pytorch--Matrix相关
    pytorch–Matrix相关1.矩阵生成Tensor,即张量,是PyTorch中的基本操作对象,可以看做是包含单一数据类型元素的多维矩阵。从使用角度来看,Tensor与NumPy的ndarrays非常类似,相互之间也可以自由转换,只不过Tensor还支持GPU的加速。1.1创建一个没有初始化的矩阵x=torch.empty(2,......
  • TIKTOK海外抖音实操班:下载注册/配置/养号/引流/发视频/等等(共17课)
    这个课程教你怎样在Tiktok上注册、设置、养号、吸引流量和发布视频。内容包括市场前景、与抖音的区别、下载和设置、网络环境、注册账号、打造优质号、找热门素材、使用热门标签和音乐、保持视频清晰度、上传技巧、避免账号被降权或封禁、优化流量下降、最佳发布时间和提高完播......
  • 使用PyTorch Profiler进行模型性能分析,改善并加速PyTorch训练
    如果所有机器学习工程师都想要一样东西,那就是更快的模型训练——也许在良好的测试指标之后加速机器学习模型训练是所有机器学习工程师想要的一件事。更快的训练等于更快的实验,更快的产品迭代,还有最重要的一点需要更少的资源,也就是更省钱。熟悉PyTorchProfiler然后就可以启动te......
  • Pytorch 实现简单的 线性回归 算法
    Pytorch实现简单的线性回归算法简单tensor的运算Pytorch涉及的基本数据类型是tensor(张量)和Autograd(自动微分变量)importtorchx=torch.rand(5,3)#产生一个5*3的tensor,在[0,1)之间随机取值y=torch.ones(5,3)#产生一个5*3的Tensor,元素都是1z=x+y......
  • “深度学习之巅:在 CentOS 7 上打造完美Python 3.10 与 PyTorch 2.3.0 环境”学习
    在CentOS7上安装Python3.10和PyTorch2.3.0时,为什么要首先升级OpenSSL?在CentOS7上,默认安装的OpenSSL版本可能不支持Python3.10所需的最新加密标准。因此,为了确保Python3.10能够正常工作,需要先升级OpenSSL到支持这些标准的版本。升级OpenSSL的具体步骤是什么?升级Ope......
  • 笔记本电脑(win11+3060+conda)安装PyTorch踩坑记录
    简而言之,先看你的显卡,打开CMD,输入nvidia-smi,右上角有一个CUDA:XX.X表示当前显卡及当前驱动支持的最高版本CUDA。输入nvidia-smi-q可以看到显卡架构(或者直接去Nvidia官网找你的显卡)。再打开这个连接,查看你显卡架构支持的最低版本CUDA。从中选择一个cuda版本torch是自带了cu......
  • 【Pytorch】一文向您详细介绍 torch.nn.DataParallel() 的作用和用法
    【Pytorch】一文向您详细介绍torch.nn.DataParallel()的作用和用法 下滑查看解决方法......
  • 【Pytorch】一文向您详细介绍 nn.MultiheadAttention() 的作用和用法
    【Pytorch】一文向您详细介绍nn.MultiheadAttention()的作用和用法 下滑查看解决方法......