首页 > 其他分享 >torch.save(),torch.load(),state_dict(),load_state_dict()

torch.save(),torch.load(),state_dict(),load_state_dict()

时间:2024-04-16 10:23:32浏览次数:23  
标签:load torch state dict model 加载

这些函数是PyTorch中用于模型保存和加载的重要函数。下面是对它们的详细解析:

  1. torch.save(obj, file):

    • 作用:将PyTorch模型保存到文件中。

    • 参数:

      • obj: 要保存的对象,可以是模型、张量或字典。
      • file: 要保存到的文件路径。
    • 示例:

      torch.save(model.state_dict(), 'model.pth')
      
  2. torch.load(file):

    • 作用:从文件中加载保存的PyTorch模型。

    • 参数:

      • file: 要加载的文件路径。
    • 返回值:加载的对象。

    • 示例:

      model.load_state_dict(torch.load('model.pth'))
      
  3. state_dict():

    • 作用:返回包含模型所有参数的字典对象。

    • 示例:

      model_state = model.state_dict()
      
  4. load_state_dict(state_dict, strict=True):

    • 作用:加载预训练的参数字典到模型中。

    • 参数:

      • state_dict: 要加载的参数字典。
      • strict(可选): 如果为True(默认值),则要求state_dict中的键与模型的参数名完全匹配。
    • 示例:

      model.load_state_dict(torch.load('pretrained.pth'))
      

这些函数在训练过程中非常有用,可以帮助保存模型的状态以及加载预训练的参数,使得模型的训练和部署更加方便。



标签:load,torch,state,dict,model,加载
From: https://www.cnblogs.com/keye/p/18137550

相关文章

  • Pytorch Dataset入门
    ​Dataset入门PytorchDatasetcode:torch/utils/data/dataset.py#L17PytorchDatasettutorial:tutorials/beginner/basics/data_tutorial.html 理论:PyTorch中的Dataset是一个抽象类,用来表示数据集的接口,所有其他数据集都需要继承这个类,并且覆写以下三个方法:__init__:......
  • Pytorch计算机视觉实战(更新中)
    第一章人工神经网络基础1.1人工智能与传统机器学习学习心得:传统机器学习(ML):需要专业的主题专家人工提取特征,并通过一个编写良好的算法来破译给定的特征,从而判断这幅图像中的内容。输入-->人工提取特征-->特征-->具有浅层结构的分类器-->输出当存在欺骗性的图片出现时可能会......
  • Pytorch DistributedDataParallel(DDP)教程二:快速入门实践篇
    一、简要回顾DDP在上一篇文章中,简单介绍了Pytorch分布式训练的一些基础原理和基本概念。简要回顾如下:1,DDP采用Ring-All-Reduce架构,其核心思想为:所有的GPU设备安排在一个逻辑环中,每个GPU应该有一个左邻和一个右邻,设备从它的左邻居接收数据,并将数据汇总后发送给右邻。通过N轮迭代......
  • Pytorch分类模型的训练框架
    Pytorch分类模型的训练框架PhotoDataset数据集是自己定义的数据集,数据集存放方式为:----image文件夹--------0文件夹--------------img1.jpg--------------img2.jpg--------1文件夹--------------img1.jpg--------------img2.jpg....如果是cpu训练的话,就把代码中的.cu......
  • 05、IS-IS Overload
    IS-ISOverloadIS-ISOverload使用IS-IS过载标记位来标识过载状态。IS-IS过载标志位是指IS-ISLSP报文中的OL字段。对设备设置过载标志位后,其它设备在进行SPF计算时不会使用这台设备做转发,只计算该设备上的直连路由。图1 IS-IS过载示意图 如图1所示,RouterA到10.1.1.0/24......
  • 时空图神经网络ST-GNN的概念以及Pytorch实现
    在我们周围的各个领域,从分子结构到社交网络,再到城市设计结构,到处都有相互关联的图数据。图神经网络(GNN)作为一种强大的方法,正在用于建模和学习这类数据的空间和图结构。它已经被应用于蛋白质结构和其他分子应用,例如药物发现,以及模拟系统,如社交网络。标准的GNN可以结合来自其他机器......
  • Pytorch DistributedDataParallel(DDP)教程一:快速入门理论篇
    一、写在前面随着深度学习技术的不断发展,模型的训练成本也越来越高。训练一个高效的通用模型,需要大量的训练数据和算力。在很多非大模型相关的常规任务上,往往也需要使用多卡来进行并行训练。在多卡训练中,最为常用的就是分布式数据并行(DistributedDataParallel,DDP)。但是现有的......
  • GRPC - Load testing Production-Grade APIs
      https://ghz.sh/  ......
  • Conditional AutoEncoder的Pytorch完全实现
    一个完整的深度学习流程,必须包含的部分有:参数配置、Dataset和DataLoader构建、模型与optimizer与Loss函数创建、训练、验证、保存模型,以及读取模型、测试集验证等,对于生成模型来说,还应该有重构测试、生成测试。AutoEncoder进能够重构见过的数据、VAE可以通过采样生成新数据,对于MN......
  • GRPC - Distributing requests with load balancing
         ......