首页 > 其他分享 >[pytorch] 训练时冻结一部分模型的参数 —— module.requires_grad_(False)

[pytorch] 训练时冻结一部分模型的参数 —— module.requires_grad_(False)

时间:2023-10-17 20:15:56浏览次数:42  
标签:False weight requires module seed m1 m2 grad

prologue

title: [pytorch] 训练时冻结一部分模型的参数 —— module.requires_grad_(False)
代码用到一个解码器\(dec\),希望用它预测生成结果\(g\)的counting encode并用以计算损失,以此约束生成器生成合理的结果(能解码出正确的counting encode)
但考虑到\(g\)并不准确,如果不冻结\(dec\)的参数,就会被\(g\)带偏

idea

实际上可以用 nn.Module.requires_grad_(False)

dec.requires_grad_(False)
logits = dec.classify(g)
dec.requires_grad_(True)
loss(logits, label)

但是dec还需要处理其他特征(比如输入x,用以训练dec本身),需要更新参数,不确定上面那样在梯度更新之前就重新设为True是否可行

validate

下面就用一个两个MLP来模拟\(dec\):

import random
import torch
import torch.nn as nn
from torch import optim
import numpy as np


def setup_seed(seed, strict=True):
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  np.random.seed(seed)
  random.seed(seed)
  if strict:
    torch.backends.cudnn.deterministic = True


class MLP(nn.Module):

  def __init__(self):
    super().__init__()
    self.c_fc = nn.Linear(6, 4 * 6, bias=False)
    self.gelu = nn.GELU()
    self.c_proj = nn.Linear(4 * 6, 6, bias=False)
    self.dropout = nn.Dropout(0.1)

  def forward(self, x):
    x = self.c_fc(x)
    x = self.gelu(x)
    x = self.c_proj(x)
    x = self.dropout(x)
    return x

接着先固定seed,创建两个mlp并查看初始参数(以其中一个线性层为例即可),输入x随机初始化,再用SGD更新参数

setup_seed(233)
m1 = MLP()
m2 = MLP()
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])
x = torch.randn((3, 6), requires_grad=True)
# print(x.requires_grad, x)
opimizer1 = optim.SGD(m1.parameters(), lr=0.1)
opimizer2 = optim.SGD(m2.parameters(), lr=0.1)

上述输出:

m1-weight
 tensor([[-0.1926, -0.3506, -0.0758,  0.2318, -0.1711,  0.3389],
        [-0.1360, -0.3253, -0.1471,  0.3951,  0.4054, -0.0247],
        [-0.1603,  0.1496, -0.2566,  0.3095,  0.2247, -0.3124]],
       grad_fn=<SliceBackward0>)
m2-weight
 tensor([[ 0.3739,  0.3307, -0.1514,  0.0787,  0.3436, -0.1428],
        [-0.1487,  0.1236,  0.4002, -0.2563, -0.0266, -0.2860],
        [ 0.3197, -0.1728, -0.1770, -0.2492,  0.2864, -0.3191]],
       grad_fn=<SliceBackward0>)

将x扔进模型,然后得到最终输出x3,retain_grad用以维持这些中间变量的梯度,方便输出查看,m2调用时冻结参数,与m1形成对照:

x1 = m1(x)
m2.requires_grad_(False)
x2 = m2(x1)
m2.requires_grad_(True)
x3 = x2 * 0.3 - 3
x1.retain_grad()
x2.retain_grad()

下面的t不太清楚,似乎是给backward的起始梯度,总之能算就行。

t = torch.randn((3, 6)) * 10
# t = torch.ones_like(x)
opimizer1.zero_grad()
opimizer2.zero_grad()
x3.backward(t)
opimizer1.step()
opimizer2.step()
# print('x\n', x.grad)
# print('x1\n', x1.grad)
# print('x2\n', x2.grad)
print('m1-weight\n', m1.c_fc.weight[:3, :7])
print('m2-weight\n', m2.c_fc.weight[:3, :7])

上述输出:

m1-weight
 tensor([[-0.1897, -0.3464, -0.0873,  0.2439, -0.1725,  0.3384],
        [-0.1663, -0.3467, -0.2410,  0.4209,  0.4023, -0.0558],
        [-0.1909,  0.1313, -0.3833,  0.3582,  0.2189, -0.3493]],
       grad_fn=<SliceBackward0>)
m2-weight
 tensor([[ 0.3739,  0.3307, -0.1514,  0.0787,  0.3436, -0.1428],
        [-0.1487,  0.1236,  0.4002, -0.2563, -0.0266, -0.2860],
        [ 0.3197, -0.1728, -0.1770, -0.2492,  0.2864, -0.3191]],
       grad_fn=<SliceBackward0>)

发现m2参数果然没变,而m1已经更新,因此证明方案可行。

标签:False,weight,requires,module,seed,m1,m2,grad
From: https://www.cnblogs.com/Stareven233/p/17762895.html

相关文章

  • 解决AttributeError: module tensorflow has no attribute placeholder
    解决AttributeError:module'tensorflow'hasnoattribute'placeholder'如果你在使用TensorFlow时遇到了"AttributeError:module'tensorflow'hasnoattribute'placeholder'"的错误,这意味着你正在使用的TensorFlow版本与你的代码不兼容。这个错误通常是因为在Tens......
  • ES6 module模块
    概述ES6中的module指的是JavaScript模块化规范中的一种。它通过export和import语法来导出和导入模块中的变量、函数、类等内容。使用ES6模块化的好处包括:解决了模块化的问题。消除了全局变量。管理加载顺序。使用在ES6模块中,一个文件代表一个模块当使用script标签加载模块时,需要......
  • 解决AttributeError: module tensorflow has no attribute placeholder
    解决AttributeError:module'tensorflow'hasnoattribute'placeholder'如果你在使用TensorFlow时遇到了"AttributeError:module'tensorflow'hasnoattribute'placeholder'"的错误,这意味着你正在使用的TensorFlow版本与你的代码不兼容。这个错误通常是因为在Tens......
  • Gradle导致Lombok不生效问题
    现象从debug看是可以查询到数据的,但是返回起前端是没有数据的解决办法//引入lombok注解处理器annotationProcessor,不然lombok不会生效annotationProcessor('org.projectlombok:lombok')结果......
  • Error: Vue packages version mismatch: - [email protected] (D:\前端\vue01\node_module
    Error:Vuepackagesversionmismatch:[email protected](D:\\前端\vue01\node_modules\vue\dist\vue.runtime.common.js)[email protected](D:\前端\vue01\node_modules\vue-template-compiler\package.json)根据提示信息,是版本不匹配的问题,可以直接找到vu......
  • 安装odoo13出现relation "ir_module_module" does not exist
    全新安装的odoo,但启动时出现relation"ir_module_module"doesnotexist,以为是数据库要手动初始化,所以也在启动时加入-ibase-dodoo13的命令,但也无效,注释addons_path就ok,但路径检查过是没有问题的,待启动之后,再打开addons_path就行了,应该和addons_path里面有些插件有错误导致......
  • Argument for '--moduleResolution' option must be: 'node', Unknown compiler opt
    node_modules/@vue/tsconfig/tsconfig.json(12,25):errorTS6046:Argumentfor'--moduleResolution'optionmustbe:'node','classic','node16','nodenext'.node_modules/@vue/tsconfig/tsconfig.json(33,5):erro......
  • 通过 modules 创建 vuex 的模块
    模块拆分:1.在store文件夹下再新建文件夹modules,在modules下新建xxx.js文件:eg:新建user.js文件conststate={ userInfo:{  name:'zs',  age:18 }, score:80}constmutations={}constactions={}constgetters={}exportdefault......
  • ConfigureAwait(false) 原理以及注意事项总结
    解决什么问题?1、避免线程死锁2、可能的性能提升存在的问题:1、当代码在另一个线程上继续时,线程同步上下文将丢失,因为状态机改变。这里最大的损失是你会失去归属于线程的Culture和Language,其中包含了国家语言时区信息,以及来自原始线程的HttpContext.Current之类的信息。因此,如......
  • Go语言模块管理:GO111MODULE的含义
    在cmd中使用goenv命令可以查看到我们的GOPATH环境变量。其目录结构为:bin:存放代码编译后的二进制文件pkg:存放编译后的库文件src:存放自己编写的Go语言代码文件在Go1.11后新增了modules特性,模块是相关Go包的集合。如果在cmd中执行以下命令将GO111MODULE变量的值设为on:go......