首页 > 其他分享 >torch保存模型

torch保存模型

时间:2023-11-24 11:13:55浏览次数:34  
标签:nn 模型 torch 保存 PyTorch SimpleModel model

保存模型有两种方式,方式不同,在调用模型的时候也不同

我更建议用torch.jit。。。这样不需要在写模型的参数

torch.save

保存模型:
import torch
import torch.nn as nn

# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

model = SimpleModel()

# 保存模型到文件
torch.save(model.state_dict(), 'model.pth')
解释:
model.state_dict() 返回模型的参数字典,torch.save 将这个字典保存到名为 model.pth 的文件中。

  

调用模型:
import torch
import torch.nn as nn

# 假设 model 是你的 PyTorch 模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

model = SimpleModel()

# 加载模型参数
model.load_state_dict(torch.load('model.pth'))

# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())

  

torch.jit.script

TorchScript — PyTorch 2.1 documentation

torch.jit 模块是 PyTorch 中的即时(just-in-time)编译模块,提供了一种将 PyTorch 模型转换为脚本(script)或 Torch 脚本(TorchScript)的方法。Torch 脚本是一种中间表示形式,可以在不依赖 Python 解释器的情况下在 PyTorch 中运行。

可以将整个模型保存为一个 Torch 脚本文件,而不仅仅是模型的参数。这样做可以更轻松地保存和加载整个模型。

保存模型:
import torch
import torch.jit

# model 是我的 PyTorch 模型
class SimpleModel(torch.nn.Module):
    def forward(self, x):
        return x + 1

model = SimpleModel()

# 将模型转换为 Torch 脚本
scripted_model = torch.jit.script(model)
# 保存 Torch 脚本到文件
scripted_model.save("scripted_model.pt")

 

# 调用模型 
loaded_model = torch.jit.load("scripted_model.pt")

# 将模型设为评估模式(如果是测试模型)
model.eval()
outputs = model(data.float())

  

 

标签:nn,模型,torch,保存,PyTorch,SimpleModel,model
From: https://www.cnblogs.com/mxleader/p/17853291.html

相关文章

  • 深度学习中实现PyTorch和NumPy之间的数据转换知多少?
    在深度学习中,PyTorch和NumPy是两个常用的工具,用于处理和转换数据。PyTorch是一个基于Python的科学计算库,用于构建神经网络和深度学习模型。NumPy是一个用于科学计算的Python库,提供了一个强大的多维数组对象和用于处理这些数组的函数。在深度学习中,通常需要将数据从NumPy数组转换......
  • 使用Python在Tkinter中保存异常
    我为其他使用Tkinter接收用户输入的人开发了几个Python程序。为了保持简单和用户友好,命令行或python控制台永远不会打开(即。.pyw文件),因此,当出现异常时,我正在研究如何使用日志库向文件写入错误文本。然而,我很难让它真正捕获异常。例如:我们编写一个会导致错误的函数:defcause_a......
  • R语言集成模型:提升树boosting、随机森林、约束最小二乘法加权平均模型融合分析时间序
    原文链接:http://tecdat.cn/?p=24148原文出处:拓端数据部落公众号 最近我们被要求撰写关于集成模型的研究报告,包括一些图形和统计输出。特别是在经济学/计量经济学中,建模者不相信他们的模型能反映现实。比如:收益率曲线并不遵循三因素的Nelson-Siegel模型,股票与其相关因素之间的......
  • 在r语言中使用GAM(广义相加模型)进行电力负荷时间序列分析|附代码数据
    原文链接:http://tecdat.cn/?p=9024原文出处:拓端数据部落公众号  最近我们被要求撰写关于GAM的研究报告,包括一些图形和统计输出。用GAM进行建模时间序列我已经准备了一个文件,其中包含四个用电时间序列来进行分析。数据操作将由data.table程序包完成。将提及的智能电表数据......
  • torch用法--张量操作
    创建张量:torch.tensor(data):从数据中创建张量。用列表创建,numpy创建维度只看[]#一维张量data_1d=[1,2,3]tensor_1d=torch.tensor(data_1d)#结果tensor([1,2,3])#二维张量data_2d=[[1,2,3],[4,5,6],[4,5,6]]tensor_2d=torch.tensor(dat......
  • 简单的用Python采集股票数据,保存表格后分析历史数据
    前言字节跳动如果上市,那么钟老板将成为我国第一个世界首富趁着现在还没上市,咱们提前学习一下用Python分析股票历史数据,抱住粗大腿坐等起飞~好了话不多说,我们直接开始正文准备工作环境使用Python3.10解释器Pycharm编辑器模块使用requests—>数据......
  • Torch张量是什么
    定义:在PyTorch中,张量(tensor)是一种类似于多维数组的数据结构,它是PyTorch的核心数据类型。张量可以具有不同的维度,例如标量(0维张量,类似于一个数字)、向量(1维张量,类似于一维数组)、矩阵(2维张量,类似于二维数组)以及更高维度的数组。张量的维度,矩阵的维度主要看第一个数,也就是看行,几行代......
  • 大模型基础
    学习以下文章:揭密Transformer:大模型背后的硬核技术人人都需要掌握的PromptEngineering技巧通俗解读大模型微调(FineTuning)大模型时代的应用创新范式如何理解大模型中的参数?大模型可以看作是数据转换问题,即输入\(X\)序列,输出\(Y\)序列,其中\(Y=WX\),这里的W矩阵就可以......
  • 查看电脑已经保存过的WiFi以及密码
    查询无线网卡所有保存过密码的ssidnetshwlanshowprofiles “所有用户配置文件:”后面接的就是无线网卡连接过的SSID,也就是WiFi名查询对应的SSID的密码信息netshwlanshowprofiles"龙回首"key=clear“”中是SSID的名称, 图片中,关键内容指的就是对应......
  • Mybatis保存多记录,导致SQL过长,保存失败,按指定次数切分,多次保存。
     privatestaticfinalIntegerWORKITEM_MAX_NUMBER=200;privateintsavePbhProblemworkitem(List<ProblemWorkitemVm>problemworkitem){try{intcount=0;intlimit=countStep(problemworkitem.size(),WORKI......