保存模型有两种方式,方式不同,在调用模型的时候也不同
我更建议用torch.jit。。。这样不需要在写模型的参数
torch.save
保存模型: import torch import torch.nn as nn # 假设 model 是你的 PyTorch 模型 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.fc = nn.Linear(10, 1) model = SimpleModel() # 保存模型到文件 torch.save(model.state_dict(), 'model.pth')
解释:model.state_dict()
返回模型的参数字典,torch.save
将这个字典保存到名为model.pth
的文件中。
调用模型: import torch import torch.nn as nn # 假设 model 是你的 PyTorch 模型 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.fc = nn.Linear(10, 1) model = SimpleModel() # 加载模型参数 model.load_state_dict(torch.load('model.pth')) # 将模型设为评估模式(如果是测试模型) model.eval() outputs = model(data.float())
torch.jit.script
TorchScript — PyTorch 2.1 documentation
torch.jit
模块是 PyTorch 中的即时(just-in-time)编译模块,提供了一种将 PyTorch 模型转换为脚本(script)或 Torch 脚本(TorchScript)的方法。Torch 脚本是一种中间表示形式,可以在不依赖 Python 解释器的情况下在 PyTorch 中运行。
可以将整个模型保存为一个 Torch 脚本文件,而不仅仅是模型的参数。这样做可以更轻松地保存和加载整个模型。
保存模型:
import torch import torch.jit # model 是我的 PyTorch 模型 class SimpleModel(torch.nn.Module): def forward(self, x): return x + 1 model = SimpleModel() # 将模型转换为 Torch 脚本 scripted_model = torch.jit.script(model)
# 保存 Torch 脚本到文件
scripted_model.save("scripted_model.pt")
# 调用模型 loaded_model = torch.jit.load("scripted_model.pt")
# 将模型设为评估模式(如果是测试模型) model.eval() outputs = model(data.float())
标签:nn,模型,torch,保存,PyTorch,SimpleModel,model From: https://www.cnblogs.com/mxleader/p/17853291.html