1、模型的保存
1 '''1、模型的保存''' 2 import torch 3 import torchvision 4 5 vgg16=torchvision.models.vgg16(pretrained=False) 6 #保存方式1: 保存网络模型结构也+模型参数 7 torch.save(vgg16,'models/vgg16_method1.pth') #两个参数,第一个是要保存的模型,第二个是保存的路径,后缀名以pth结尾; 8 9 #保存方式2:将网络模型的参数状态保存成字典的形式 (官方推荐) 10 torch.save(vgg16.state_dict(),'models/vgg16_method2.pth')
'''2、模型的加载''' import torch #第一种保存方式所对应的加载模型方式 model1=torch.load('models/vgg16_method1.pth') print(model1) #加载2; model2=torch.load('models/vgg16_method2.pth') print(model2)
第一种方式的输出:
第二种方式的输出:
###如果想把第二种方式的输出恢复成网络结构的形式 #新建一个网络模型,-----然后使用load_state_dict加载保存好的模型的参数 vgg16=torchvision.models.vgg16(pretrained=False) vgg16.load_state_dict(torch.load('models/vgg16_method2.pth')) print(vgg16)
3、第一种方式的缺点
就是当自己定义的网络模型的时候,在加载的时候会报错
#第一种方式的缺点情况:自己创建的网络模型在加载时报错 class lianran(nn.Module): def __init__(self): super().__init__() self.conv=Conv2d(in_channels=3,out_channels=64,kernel_size=3) def forward(self,x): output=self.conv(x) return output li=lianran() torch.save(li,'models/lianran.pth')
#缺点情况的加载: model3=torch.load('models/lianran.pth')
解决:
#缺点情况的加载会报错,解决办法是把模型拿过来,但是不需要构建 #解决办法一:把模型拿过来,但是不需要构建 class lianran(nn.Module): def __init__(self): super().__init__() self.conv=Conv2d(in_channels=3,out_channels=64,kernel_size=3) def forward(self,x): output=self.conv(x) return output model3=torch.load('models/lianran.pth')
#解决办法二:把定义网络模型的py文件加载过来 from models import * model3=torch.load('models/lianran.pth') print(model3)
标签:load,读取,22,pth,vgg16,模型,torch,models From: https://www.cnblogs.com/ar-boke/p/17155286.html