Class TestModle(nn.Module):
def __init__(self):
self.conv = nn.Conv(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
...
def forward(self, x):
...
....
假如有这样一个模型
一、使用状态字典保存模型参数(官方推荐用法)
保存模型
torch.save(model.state_dict(), PATH)
模型一般选择pt后缀结尾
载入模型
由于我们仅仅保存模型的权重参数,没有模型的结构是无法载入参数的
model = TestModel()
model.load_state_dict(torch.load(PATH))
model.eval()
- 注意:这里我们载入参数后使用eval()后再对输入的内容进行推理,因为eval会把模型内的标准化和dropout等功能给禁用了。才能输出正确的推理结果
二、保存整个模型,载入模型(无需加载模型的结构)
保存模型
torhc.save(model, PATH)
加载模型
model = torch.load(PATH)
以上是本人常用的两种方法,实测有效。但是由于pytorch并不熟练,如果想了解更多,可移步到这
https://zhuanlan.zhihu.com/p/82038049