首页 > 其他分享 >pytorch保存模型及加载模型

pytorch保存模型及加载模型

时间:2023-07-11 17:47:37浏览次数:51  
标签:模型 保存 载入 pytorch PATH model self 加载

Class TestModle(nn.Module):
	def __init__(self):
		self.conv = nn.Conv(3, 6, 5)
		self.pool = nn.MaxPool2d(2, 2)
		...
	def forward(self, x):
		...
	....

假如有这样一个模型

一、使用状态字典保存模型参数(官方推荐用法)

保存模型

torch.save(model.state_dict(), PATH)
模型一般选择pt后缀结尾

载入模型

由于我们仅仅保存模型的权重参数,没有模型的结构是无法载入参数的

model = TestModel()
model.load_state_dict(torch.load(PATH))
model.eval()
  • 注意:这里我们载入参数后使用eval()后再对输入的内容进行推理,因为eval会把模型内的标准化和dropout等功能给禁用了。才能输出正确的推理结果

二、保存整个模型,载入模型(无需加载模型的结构)

保存模型
torhc.save(model, PATH)
加载模型
model = torch.load(PATH)

以上是本人常用的两种方法,实测有效。但是由于pytorch并不熟练,如果想了解更多,可移步到这
https://zhuanlan.zhihu.com/p/82038049

标签:模型,保存,载入,pytorch,PATH,model,self,加载
From: https://www.cnblogs.com/ohj666/p/17545456.html

相关文章