Pytorch只更新预训练模型的部分参数
假设有一个训练好的模型,并且我们只想微调部分参数。
比如,这里我们只想更新最后一部分的参数:
可以看到,这里的模块叫b4。
我们可以直接通过获取模块的名字来进行更新:
方法1
def update(model,flag=True):
for name,p in model.named_parameters():
if "b4" in name:
print("update only",name)
p.requires_grad = flag
也就是说 只要模块名字包含b4 就会让他跟新网络。
对应的optimizer 的设置如下:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_)
然后直接训练就行。
方法二
也可以直接 把这些符合条件的 parameters 加入 list中,并传给 optimizer
def update(model,flag=True):
paras = []
for name,p in model.named_parameters():
if "b4" in name:
print("update only",name)
p.requires_grad = flag
paras.append(p)
return paras
optimizer = torch.optim.Adam(paras, lr=lr_)
直接训练就行。##