首页 > 其他分享 >22、模型的保存与读取

22、模型的保存与读取

时间:2023-02-25 20:36:18浏览次数:44  
标签:load 读取 22 pth vgg16 模型 torch models

1、模型的保存

 1 '''1、模型的保存'''
 2 import torch
 3 import torchvision
 4 
 5 vgg16=torchvision.models.vgg16(pretrained=False)
 6 #保存方式1: 保存网络模型结构也+模型参数
 7 torch.save(vgg16,'models/vgg16_method1.pth')   #两个参数,第一个是要保存的模型,第二个是保存的路径,后缀名以pth结尾;
 8 
 9 #保存方式2:将网络模型的参数状态保存成字典的形式  (官方推荐)
10 torch.save(vgg16.state_dict(),'models/vgg16_method2.pth')

 

'''2、模型的加载'''
import torch

#第一种保存方式所对应的加载模型方式
model1=torch.load('models/vgg16_method1.pth')
print(model1)

#加载2;
model2=torch.load('models/vgg16_method2.pth')
print(model2)

第一种方式的输出:

 

 第二种方式的输出:

 

 

###如果想把第二种方式的输出恢复成网络结构的形式
#新建一个网络模型,-----然后使用load_state_dict加载保存好的模型的参数
vgg16=torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('models/vgg16_method2.pth'))
print(vgg16)

3、第一种方式的缺点

就是当自己定义的网络模型的时候,在加载的时候会报错

#第一种方式的缺点情况:自己创建的网络模型在加载时报错
class lianran(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv=Conv2d(in_channels=3,out_channels=64,kernel_size=3)
    def forward(self,x):
        output=self.conv(x)
        return  output
li=lianran()
torch.save(li,'models/lianran.pth')
#缺点情况的加载:
model3=torch.load('models/lianran.pth')

 

 解决:

#缺点情况的加载会报错,解决办法是把模型拿过来,但是不需要构建
#解决办法一:把模型拿过来,但是不需要构建
class lianran(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv=Conv2d(in_channels=3,out_channels=64,kernel_size=3)
    def forward(self,x):
        output=self.conv(x)
        return  output
model3=torch.load('models/lianran.pth')
#解决办法二:把定义网络模型的py文件加载过来
from models import *
model3=torch.load('models/lianran.pth')
print(model3)

 

标签:load,读取,22,pth,vgg16,模型,torch,models
From: https://www.cnblogs.com/ar-boke/p/17155286.html

相关文章