prologue
title: [pytorch] Freezing part of a model's parameters during training with module.requires_grad_(False)
The code uses a decoder \(dec\) to predict the counting encode of a generated result \(g\) and compute a loss from it, so as to constrain the generator to produce reasonable results (ones from which the correct counting encode can be decoded).
But since \(g\) is not accurate, if \(dec\)'s parameters are not frozen they will be dragged off course by \(g\).
idea
In fact, nn.Module.requires_grad_(False) can be used like this:
dec.requires_grad_(False)   # freeze the decoder's parameters
logits = dec.classify(g)    # run dec on the generated result g
dec.requires_grad_(True)    # un-freeze so dec can still be trained elsewhere
loss(logits, label)
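To make this freeze/un-freeze toggle less error-prone, one option is a small context manager that restores each parameter's original flag afterwards. This is my own sketch, not from the original code, and the helper name frozen is made up:
from contextlib import contextmanager

import torch.nn as nn

@contextmanager
def frozen(module: nn.Module):
    # remember each parameter's current flag, freeze, then restore on exit
    saved = [p.requires_grad for p in module.parameters()]
    module.requires_grad_(False)
    try:
        yield module
    finally:
        for p, flag in zip(module.parameters(), saved):
            p.requires_grad_(flag)

# usage:
# with frozen(dec):
#     logits = dec.classify(g)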
However, dec also has to process other features (e.g. the input x, which is used to train dec itself) and therefore still needs its parameters updated, so I was not sure whether setting the flag back to True before the gradient update, as above, actually works.
validate
Below, two MLPs are used to simulate the setup (m2 plays the role of \(dec\)):
import random
import torch
import torch.nn as nn
from torch import optim
import numpy as np

def setup_seed(seed, strict=True):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    if strict:
        torch.backends.cudnn.deterministic = True

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(6, 4 * 6, bias=False)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * 6, 6, bias=False)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
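As a quick sanity check (my own addition, not part of the original post), the only parameters this MLP registers are the two bias-free linear weights, which is exactly what the experiment below inspects:
m = MLP()
for name, p in m.named_parameters():
    print(name, tuple(p.shape), p.requires_grad)
# expected: c_fc.weight (24, 6) True and c_proj.weight (6, 24) True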
Next, fix the seed, create the two MLPs and inspect their initial parameters (looking at one of the linear layers is enough); the input x is randomly initialized, and SGD is used to update the parameters:
setup_seed(233)
m1 = MLP()
m2 = MLP()
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])
x = torch.randn((3, 6), requires_grad=True)
# print(x.requires_grad, x)
optimizer1 = optim.SGD(m1.parameters(), lr=0.1)
optimizer2 = optim.SGD(m2.parameters(), lr=0.1)
The output:
m1-weight
tensor([[-0.1926, -0.3506, -0.0758, 0.2318, -0.1711, 0.3389],
[-0.1360, -0.3253, -0.1471, 0.3951, 0.4054, -0.0247],
[-0.1603, 0.1496, -0.2566, 0.3095, 0.2247, -0.3124]],
grad_fn=<SliceBackward0>)
m2-weight
tensor([[ 0.3739, 0.3307, -0.1514, 0.0787, 0.3436, -0.1428],
[-0.1487, 0.1236, 0.4002, -0.2563, -0.0266, -0.2860],
[ 0.3197, -0.1728, -0.1770, -0.2492, 0.2864, -0.3191]],
grad_fn=<SliceBackward0>)
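A side note (my own addition): it is fine that optimizer2 holds m2's parameters even though m2 will be frozen during its forward pass, because SGD's step skips any parameter whose .grad is still None after backward. A minimal sketch of that behaviour:
import torch
from torch import optim

p_frozen = torch.randn(3, requires_grad=True)
p_live = torch.randn(3, requires_grad=True)
opt = optim.SGD([p_frozen, p_live], lr=0.1)

p_live.sum().backward()               # only p_live gets a .grad; p_frozen.grad stays None
before = p_frozen.detach().clone()
opt.step()                            # parameters with grad None are skipped
print(torch.equal(before, p_frozen))  # True: the parameter without a grad did not move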
Feed x through the models to get the final output x3; retain_grad is called to keep the gradients of these intermediate variables so they can be printed and inspected. m2 has its parameters frozen around its call, as a contrast to m1:
x1 = m1(x)
m2.requires_grad_(False)   # freeze m2 only for this forward pass
x2 = m2(x1)
m2.requires_grad_(True)    # restore the flag before backward
x3 = x2 * 0.3 - 3
x1.retain_grad()
x2.retain_grad()
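One detail worth noting here (my own addition): even though m2's parameters were frozen while it ran, x2 still requires grad because its input x1 does, so gradients can flow through m2's operations back to m1 and x; only m2's own parameters are cut out of the graph:
print(x2.requires_grad)        # True: the graph continues through x1
print(x2.grad_fn is not None)  # True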
The t below is the gradient argument given to backward: since x3 is not a scalar, backward needs a gradient tensor of the same shape as x3 to seed the vector-Jacobian product; any values will do for this check.
t = torch.randn((3, 6)) * 10
# t = torch.ones_like(x)
optimizer1.zero_grad()
optimizer2.zero_grad()
x3.backward(t)
optimizer1.step()
optimizer2.step()
# print('x\n', x.grad)
# print('x1\n', x1.grad)
# print('x2\n', x2.grad)
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])
The output:
m1-weight
tensor([[-0.1897, -0.3464, -0.0873, 0.2439, -0.1725, 0.3384],
[-0.1663, -0.3467, -0.2410, 0.4209, 0.4023, -0.0558],
[-0.1909, 0.1313, -0.3833, 0.3582, 0.2189, -0.3493]],
grad_fn=<SliceBackward0>)
m2-weight
tensor([[ 0.3739, 0.3307, -0.1514, 0.0787, 0.3436, -0.1428],
[-0.1487, 0.1236, 0.4002, -0.2563, -0.0266, -0.2860],
[ 0.3197, -0.1728, -0.1770, -0.2492, 0.2864, -0.3191]],
grad_fn=<SliceBackward0>)
As expected, m2's parameters are unchanged while m1's have been updated, so the scheme works. The reason is that whether a gradient edge to a parameter is recorded is decided while the graph is built during the forward pass, so flipping requires_grad back to True after the forward but before backward does not re-enable gradients for that call.
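A more direct check than comparing weights (my own addition) is to look at the accumulated gradients after backward:
print(m1.c_fc.weight.grad is None)  # False: m1's parameters received gradients
print(m2.c_fc.weight.grad is None)  # True: nothing flowed into m2's parameters
print(x.grad is None)               # False: the gradient still reaches x through m2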