首页 > 其他分享 >Pytorch只更新预训练模型的部分参数

Pytorch只更新预训练模型的部分参数

时间:2023-06-09 15:08:20浏览次数:39  
标签:name parameters 模型 update b4 Pytorch 参数 lr model


Pytorch只更新预训练模型的部分参数

假设有一个训练好的模型,并且我们只想微调部分参数。
比如,这里我们只想更新最后一部分的参数:
可以看到,这里的模块叫b4。

Pytorch只更新预训练模型的部分参数_pytorch


我们可以直接通过获取模块的名字来进行更新:

方法1

def update(model,flag=True):
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag

也就是说 只要模块名字包含b4 就会让他跟新网络。
对应的optimizer 的设置如下:

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_)

然后直接训练就行。

方法二

也可以直接 把这些符合条件的 parameters 加入 list中,并传给 optimizer

def update(model,flag=True):
    paras = []
    for name,p in model.named_parameters():
        if "b4" in name:
            print("update only",name)
            p.requires_grad = flag
            paras.append(p)
    return paras
optimizer = torch.optim.Adam(paras, lr=lr_)

直接训练就行。##


标签:name,parameters,模型,update,b4,Pytorch,参数,lr,model
From: https://blog.51cto.com/u_11384719/6447819

相关文章

  • 实例讲解Flink 流处理程序编程模型
    摘要:在深入了解Flink实时数据处理程序的开发之前,先通过一个简单示例来了解使用Flink的DataStreamAPI构建有状态流应用程序的过程。本文分享自华为云社区《Flink实例:Flink流处理程序编程模型》,作者:TiAmoZhang。在深入了解Flink实时数据处理程序的开发之前,先通过一个简单......
  • MariaDB 10.11 参数变化一览
    在MariaDB10.11中,有一些参数发生了变化,下面就一起来看一下。slowquery在mariadb10.11中,与慢查询相关的参数共13个,相比于mariadb10.6,有几个参数发生了变化。MariaDB[(none)]>showvariableslike'%slow%';+---------------------------------+------------------------......
  • [Java SE] 彻底搞懂Java程序的三大参数配置途径:系统变量与JVM参数(VM Option)/环境变
    0序言一次没搞懂,处处受影响。这个问题属于基础问题,但又经常踩坑,不得不重视一下了。1Java程序动态参数的配置途径:系统变量与JVM参数(VMOption)vs环境变量vs启动程序参数argsIDEA中的配置位置参数使用方式示例代码获取方式系统属性由操作系统、JVM、应用......
  • pandas中的read_csv参数详解
    来自:https://blog.csdn.net/weixin_44852067/article/details/122366383,感谢作者。 pandas中的read_csv参数详解独影月下酌酒于2022-01-0715:57:29发布40866收藏204分类专栏:pandas文章标签:python数据挖掘pandas版权华为云开发者联盟该内容已被华为云开发者联盟......
  • N6、seq2seq翻译实战-Pytorch复现
    ......
  • python tkinter 动态批量建立Widget时,combobox 或 entry传递参数问题
    terminal_combobox.bind('<<ComboboxSelected>>',lambdaevent,arg=key_dict:self.terminal_select(key_dict=arg))#注意,传递参数方法defterminal_select(self,key_dict,*args):var=self.dict_widget[key_d......
  • Java各种路径和参数
    1.JSP中获得当前应用的相对路径和绝对路径:根目录所对应的绝对路径:request.getRequestURI()文件的绝对路径:application.getRealPath(request.getRequestURI());当前web应用的绝对路径:application.getRealPath("/");取得请求文件的上层目录:newFile(application.getRealP......
  • [人工智能-NLP]使用GPT-2预训练模型进行微调
    下面是一个使用GPT-2进行微调的示例。以文本生成为例,我们将微调GPT-2来生成新闻标题。此外,我们将使用PyTorch作为深度学习框架,以便于构建和训练模型。安装PyTorch和Transformers首先需要安装PyTorch和Transformers库。在终端中输入以下命令:pipinstalltorchtransformers......
  • 【人人懂AI】用chatGPT学会大模型GPT
    1.一句话掌握最新关键知识点1.1什么是chatGPT?chatGPT是基于OpenAI公司的人工智能大模型GPT系列开发出的一个网页版的对话机器人。用户可以在网页登录与chatGPT进行语言交流,支持多种主流语言,chatGPT与传统大的智能对话机器人不同,它可以几乎接近人类的理解和表达能力,在对话中扮......
  • 面试官:你会哪些JVM调优参数?
    你好,我是田哥。上周一位朋友去面试被问到JVM参数,本文咱们就来聊聊。面试造火箭.......,我们很多人干了三、五年的Java开发,其实压根儿没使用过JVM调优参数。但是,面试官可不管你有没有用过,面试官心里想的是“这问题回答不出来,证明你很lowB,还想要那么高的薪资,没门”。话不多说,我们开始......