01 现有网络模型的使用及修改
import torchvision
from torch import nn
# ImageNet is no longer publicly downloadable, so torchvision cannot fetch it:
# passing download=True always fails. To use it, place the ILSVRC2012 archives
# in "../data_image_net" by hand, then uncomment (no download argument):
# train_data = torchvision.datasets.ImageNet("../data_image_net", split='train',
#                                            transform=torchvision.transforms.ToTensor())

# weights=None -> randomly initialised network (formerly pretrained=False);
# VGG16_Weights.IMAGENET1K_V1 -> ImageNet-trained weights (formerly pretrained=True).
# The `pretrained=` flag is deprecated since torchvision 0.13.
vgg16_false = torchvision.models.vgg16(weights=None)
vgg16_true = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
print(vgg16_true)

# Way 1: append a layer mapping the 1000 ImageNet scores to 10 classes.
# It must go into `classifier`: VGG.forward only runs features -> avgpool ->
# classifier, so a module added at the top level would be printed but never used.
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)

# Way 2: replace the last classifier layer (4096 -> 1000) with 4096 -> 10.
print(vgg16_false)
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)
02 网络模型的保存
import torchvision
import torch
from torch import nn
# Untrained VGG16 (weights=None replaces the deprecated pretrained=False).
vgg16 = torchvision.models.vgg16(weights=None)

# Save method 1: pickle the whole module (structure + parameters).
# The file name matches the one loaded back as "vgg16_method1.pth".
torch.save(vgg16, "vgg16_method1.pth")

# Save method 2 (officially recommended): save only the parameters (state_dict).
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
# Trap: a model saved whole (method 1) is pickled by reference to its class, so
# this class definition must be importable wherever that file is loaded again.
class Tudui(nn.Module):
    """Minimal demo network: a single 3 -> 64 channel 3x3 convolution."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        """Apply ``conv1`` to ``x``, which must have 3 input channels.

        With Conv2d's default stride 1 and no padding, each spatial
        dimension shrinks by 2.
        """
        x = self.conv1(x)
        return x
# Save method 1 applied to the custom model: pickle the entire Tudui instance.
tudui = Tudui()
torch.save(obj=tudui, f="tudui_method1.pth")
03 网络模型的加载（读取）
import torch
import torchvision
# Method 1: load a whole pickled model (structure + parameters).
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
# pickled modules, so a full model needs weights_only=False. Only do this for
# files you trust: unpickling can execute arbitrary code.
model = torch.load("vgg16_method1.pth", weights_only=False)
# print(model)

# Method 2: rebuild the architecture, then load the saved parameters into it.
# A state_dict holds only tensors, so the safe weights_only=True mode suffices.
vgg16 = torchvision.models.vgg16(weights=None)
vgg16.load_state_dict(torch.load("vgg16_method2.pth", weights_only=True))
# torch.load on its own returns just the state_dict (parameter name -> tensor),
# not a model:
# model = torch.load("vgg16_method2.pth")
# print(model)

# Trap 1: a model saved whole with method 1 can only be unpickled where its class
# is importable. This script neither defines nor imports Tudui, so the load below
# fails (missing attribute 'Tudui'). Fix it by defining the class here or by
# importing it from the script that saved the model.
model = torch.load('tudui_method1.pth', weights_only=False)
print(model)
标签: torchvision, pth, vgg16, torch, pytorch, print, import
From: https://blog.csdn.net/weixin_53294261/article/details/143375235