文章目录
- 前言
- 第一步 加载预训练模型并修改类别数
- 第二步 选择模型所有层/最后一层进行反向传播优化
- 探讨:如何确定模型最后一层的名字是什么
- 方法一: 查询源代码
- 方法二: 查询模型的子模块名字
前言
首先,这里不讲迁移学习的理论,只讲实践,因为理论已经全网飞了~~,不懂得大家先去学理论,理论学了再来实操。
今天,在这里只想给大家介绍一种代码写法,适用于基于pytorch的迁移学习。
迁移学习主要用在分类模型上,把原本在ImageNet或其他数据集上训练好的模型,迁移到自己的项目上来。所以对于分类模型,我们要把模型最后一层(通常是全连接层)的输出分类类别数量改了。比如,原本在ImageNet上是分1000类,而我们的目标是分3类,就要把最后的类别数改为3。
整个过程分为两步:
第一步 加载预训练模型并修改类别数
from torchvision.models import densenet169, resnet50
import torch.nn as nn
import torch.optim as optim
### 加载模型, 并修改模型的最后一层 ####
model = densenet169(pretrained=True)
# 设定 pretrained = True 就会加载训练好的模型
# 修改模型的最后一层。不同的模型,最后一层的修改略有差异
arch = 'densenet169'
classes = 2 # 分类的数量
if 'resnet' in arch:
# for param in model.layer4.parameters():
model.fc = nn.Linear(2048, classes)
if 'dense' in arch:
if '121' in arch:
# (classifier): Linear(in_features=1024)
model.classifier = nn.Linear(1024, classes)
elif '169' in arch:
# (classifier): Linear(in_features=1664)
model.classifier = nn.Linear(1664, classes)
第二步 选择模型所有层/最后一层进行反向传播优化
迁移学习有两种模型:一种是对预训练好的模型,再次从头训练,模型的每一层都要重新优化。另一种是只重新训练模型最后一层,其余层的参数固定。第一种方法适用于大多数迁移学习,预训练好的模型是在自然图像上训练的,如果迁移到医学图像上来,那么特征之间的差异很大,这时就选择第一种,重头训练。但假如有一个任务是分别猫和狗这种自然图像,且在ImageNet这个数据集中已经包含的,那么就只优化最后一层参数即可。
#### 选择模型所有层都要进行反向传播优化 还是 只优化最优一层 #####
fullretrain = True # True: 表示所有层都要进行优化,为False: 只优化最后一层
if fullretrain:
print("=> optimizing all layers")
for param in model.parameters():
param.requires_grad = True
optimizer = optim.Adam(model.parameters(), lr=0.03, weight_decay=1e-4)
# model.parameters(): 把模型所有参数都传进去
else:
print("=> optimizing fc/classifier layers")
optimizer = optim.Adam(model.module.fc.parameters(), lr=0.03, weight_decay=1e-4)
# model.module.fc.parameters(): 只传最后一个分类层的参数进去
# 注意: 不同模型,最后一层的名字不一样
我们从代码里面可以发现,两种方法的区别就是优化器(optimizer)接收的参数不一样,第一种方法是把模型的所有参数都传进去,第二种是只传模型最后一层的参数。
迁移学习这部分的代码就讲完了。其余的就跟平时训练模型一样的写法。
如果是做分类实验,经验来看,采用预训练的模型都比你自己从0开始训练的效果好。不信的话,可以自己对比对比。
接下来,对其中的部分细节进行进一步的探讨~~
探讨:如何确定模型最后一层的名字是什么
如上述代码里,在resnet这个模型中,它最后一层叫: fc
在Densenet模型中,最后一层叫: classifier
那我怎么知道它最后一层叫什么呢?
方法一: 查询源代码
最直接的办法就是进源代码里面去查看。
方法二: 查询模型的子模块名字
model = densenet169(pretrained=False)
for name in model.named_modules():
print(name)
这种方法不仅可以知道模型最后一层的名字。