常用函数
- 获取当前运行目录(类似c++)
import sys
curent_dir = sys.argv[0]
- 模型保存与读取
import torch
# 保存模型步骤
torch.save(model, 'net.pth') # 保存整个神经网络的模型结构以及参数
torch.save(model, 'net.pkl') # 同上
torch.save(model.state_dict(), 'net_params.pth') # 只保存模型参数
torch.save(model.state_dict(), 'net_params.pkl') # 同上
# 加载模型步骤
model = torch.load('net.pth') # 加载整个神经网络的模型结构以及参数
model = torch.load('net.pkl') # 同上
model.load_state_dict(torch.load('net_params.pth')) # 仅加载参数
model.load_state_dict(torch.load('net_params.pkl')) # 同上
标签:学习指南,load,torch,pth,py,深度,net,model,pkl
From: https://www.cnblogs.com/InsiApple/p/17300302.html