1、网络模型在pytorch里面的torchvision里面torchvision.models,是关于图像类的网络模型
2、简单以一个分类模型为例子: VGG(最常用的是VGG16和VGG19)
pretrained:
如果是true的话,说明在ImageNet数据集上,模型的参数是都训练好的; 如果是False的话,说明模型的参数是初始化的,没有训练好。
vgg16_false=torchvision.models.vgg16(pretrained=False) #当pretrained为 False的时候只是加载网络模型。是不需要对网络模型的参数进行下载的 vgg16_true=torchvision.models.vgg16(pretrained=True) #pretrained=True时,需要下载网络模型,下载模型里的参数 print(vgg16_true)
progress:
如果是True,显示下载进度条; False则不显示
3、ImageNet数据集:
4、修改现有模型
train_data=torchvision.datasets.CIFAR10('../../dataset/CIFAR10',train=False, transform=torchvision.transforms.ToTensor(),download=True) '''如何利用现有的网络模型,去改动它的结构;比如说想让VGG是10分类任务,也就是让输出特征是10;可以有两种''' #1、再添加一个线性层 vgg16_true.add_module('add_linear',nn.Linear(in_features=1000,out_features=10)) #add_module()里面两个参数,一个是字符串型,给要加的模块起个名字,第二个是要加的模块,可以直接是一层网络,也可以是一个序列
print(vgg16_true)
输出:
# 2、如果想在序列里面添加可以这样网络模型.想要加的位置.add_moudle() vgg16_true.classifier.add_module('add_linear',nn.Linear(in_features=1000,out_features=10)) print(vgg16_true)
# 3、不想添加的话,可以进行修改 #对模型中的classifier中的第6层进行修改 vgg16_false.classifier[6]=nn.Linear(4096,10)
标签:21,vgg16,模型,网络,修改,add,true,torchvision From: https://www.cnblogs.com/ar-boke/p/17152774.html