首页 > 其他分享 >【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

时间:2023-12-06 15:05:32浏览次数:32  
标签:版本号 梯度 self mask label context 128 backward



【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误

  • 报错详情
  • 错误产生背景
  • 原理
  • 解决方案


RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

报错详情

  模型在backward时,发现如下报错:

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_人工智能


  即RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

  其大概意思是说,当在计算梯度时,某个变量已经被操作修改了,这会导致随后的计算梯度的过程中该变量的值发生变化,从而导致计算梯度出现问题。

错误产生背景

  起因是我要复现一种层级多标签分类的网络结构:

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_梯度更新_02


  当输入序列【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_深度学习_03经过一次BERT模型之后,得到当前预测的一级标签,然后拼接到输入序列【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_深度学习_03上,再次输入到BERT模型里以预测二级标签。

  出错版本的模型结构如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context[:, -3:] = extra
        mask[:, -3:] = 1

        d2 = self.bert(context, attention_mask=mask)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2

  在计算梯度时,由于contextmask的值被中间修改过一次,所以会报错。

原理

【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_深度学习_05


【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_反向传播_06的梯度计算如上图,损失函数为【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_深度学习_07,最终【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_反向传播_06的梯度里是需要用到原始输入【backward解决方案与原理】网络模型在梯度更新时出现变量版本号机制错误_反向传播_09的。

  所以在上面贴的模型结构代码中,输入在经过神经网络之后,又作了一次改动,然后再经过神经网络。但是梯度计算会计算两次的梯度,可是发现输入只有改动后的值了,改动前的值已经被覆盖。

计算梯度时的版本号机制是PyTorch中用于跟踪张量操作历史的一种机制。它允许PyTorch在需要计算梯度时有效地管理和跟踪相关的操作,以便进行自动微分。每个张量都有一个版本号,记录了该张量的操作历史。当对一个张量执行就地操作(inplace operation)时,例如修改张量的值或重新排列元素的顺序,版本号会增加。这种就地操作可能导致计算梯度时出现问题,因为梯度计算依赖于操作历史。

解决方案

  把即将改动的变量深拷贝一份,最终优化的代码如下:

def forward(self, x, label_A_emb):
        context = x[0]  # 输入的句子
        mask = x[2]  

        d1 = self.bert(context, attention_mask=mask)
        logit1 = self.fc1(d1[1])  # [batch_size, label_A_num] = [128, 34]
        idx = torch.max(logit1.data, 1)[1] # [batch_size] = [128]
        extra = label_A_emb[idx]

        context_B = copy.deepcopy(context)
        mask_B = copy.deepcopy(mask)

        context_B[:, -3:] = extra
        mask_B[:, -3:] = 1

        d2 = self.bert_A(context_B, attention_mask=mask_B)
        logit2 = self.fc2(d2[1])  # [batch_size, label_B_num] = [128, 34]

        return logit1, logit2


标签:版本号,梯度,self,mask,label,context,128,backward
From: https://blog.51cto.com/u_15942590/8704944

相关文章

  • electron项目同一壳版本号(目录)实现安装信息和内容不同(少量不同)
    一、通过electron层的scripts中的build.nsi文件修改安装生成的set.ini文件内容SetShellVarContextall/*把当前安装包的名字写入set.ini,便于程序读取并设置{setupname}参数*/IfFileExists"$INSTDIR\set.ini"0file_not_foundWriteINIStr"$INSTDIR\R......
  • 获取git版本号写入到DLL文件
    stringbaseDirectory=System.AppDomain.CurrentDomain.BaseDirectory;stringprojectDirectory=baseDirectory.Substring(0,baseDirectory.LastIndexOf("\\aspnet-core"));stringfilePath=projectDirectory+"\\aspnet-core\\co......
  • 数据分享|python分类预测职员离职:逻辑回归、梯度提升、随机森林、XGB、CatBoost、LGB
    全文链接:https://tecdat.cn/?p=34434原文出处:拓端数据部落公众号分析师:ShilinChen离职率是企业保留人才能力的体现。分析预测职员是否有离职趋向有利于企业的人才管理,提升组织职员的心理健康,从而更有利于企业未来的发展。解决方案任务/目标采用分类这一方法构建6种模型对职......
  • Matlab中gradient函数 梯度计算原理
    ​Gradient(F)函数求的是数值上的梯度,假设F为矩阵.Gradient算法>>x=[6,9,3,4,0;5,4,1,2,5;6,7,7,8,0;7,8,9,10,0]x=6934054125677807891......
  • 解锁机器学习-梯度下降:从技术到实战的全面指南
    本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。关注TechLead,分享AI全维度知识。作者拥有10+年互......
  • 解锁机器学习-梯度下降:从技术到实战的全面指南
    本文全面深入地探讨了梯度下降及其变体——批量梯度下降、随机梯度下降和小批量梯度下降的原理和应用。通过数学表达式和基于PyTorch的代码示例,本文旨在为读者提供一种直观且实用的视角,以理解这些优化算法的工作原理和应用场景。关注TechLead,分享AI全维度知识。作者拥有10+年......
  • 更新react版本号
    [求更新react版本号·Issue#52·DataV-Team/DataV-React](https://github.com/DataV-Team/DataV-React/issues/52)package.json设置overrides可以解决这个问题"overrides":{"@jiaminghi/data-view-react":{"react":"^18......
  • 神经网络入门篇之深层神经网络:详解前向传播和反向传播(Forward and backward propagati
    深层神经网络(DeepL-layerneuralnetwork)复习下前面的内容:1.逻辑回归,结构如下图左边。一个隐藏层的神经网络,结构下图右边:注意,神经网络的层数是这么定义的:从左到右,由0开始定义,比如上边右图,\({x}_{1}\)、\({x}_{2}\)、\({x}_{3}\),这层是第0层,这层左边的隐藏层是第1层,由此类推......
  • Maven 插件统一修改聚合工程项目版本号
    ......
  • linux: debian的数字版本号与别名
    1、首先查看操作系统的版本cat /etc/debian_version2、然后可以查看Debian系统版本与codename之间的关系https://wiki.debian.org/DebianReleases ......