首页 > 其他分享 >解决 clamp 函数会阻断梯度传播

解决 clamp 函数会阻断梯度传播

时间:2024-01-10 17:34:54浏览次数:23  
标签:clamp 梯度 torch fc pred input grad 阻断

开端

若在网络的 forward 过程中使用 clamp 函数对数据进行截断,可能会阻断梯度传播。即,梯度变成零。

不妨先做一个实验。定义一个全连接网络 fc,通过输入 input_t 获得结果 pred,其值为 \(0.02\):

from torch.nn import functional as F
import torch.nn as nn
import torch

fc = nn.Linear(in_features=1, out_features=1, bias=True)
fc.weight.data = torch.tensor([[0.01]])
fc.bias.data = torch.tensor([[0.01]])

input_t = torch.tensor([[1.0]], dtype=torch.float32)
pred = fc(input_t)
print(pred)  # pred = 0.02

pred 进行反向传播,可以看到网络的权重都有梯度:

pred.backward()
print(fc.weight.grad)  # grad = 1.0
print(fc.bias.grad)  # grad = 1.0

如果使用 torch.clamp()pred 结果截断在 \((0.1,0.9)\) 范围内,会发现梯度消失了:

fc.zero_grad()
pred = fc(input_t)
pred = torch.clamp(pred, min=0.1, max=0.9)  # pred = 0.1
pred.backward()
print(fc.weight.grad)  # grad = 0.0
print(fc.bias.grad)  # grad = 0.0

解决方法

我们需要跳过 torch.clamp() 的梯度运算。

若是用 TensorFlow,使用 tf.stop_gradient 可以很方便地让 clamp 不参与梯度计算。在 PyTorch 就有点绕了。

最后我总结出了一个相对优雅的解决方案:

def nclamp(input, min, max):
    return input.clamp(min=min, max=max).detach() + input - input.detach()

参考来源

标签:clamp,梯度,torch,fc,pred,input,grad,阻断
From: https://www.cnblogs.com/chirp/p/17956968

相关文章

  • css多行文本省略 line-clamp
    css多行文本省略line-clamp一行文本内容溢出的省略例子:<divclass="container"style="width:200px;outline:1pxsolidred"><divclass="description"style="overflow:hidden;text-overflow:ellipsis;white-space:nowrap;"&g......
  • 策略梯度
    策略梯度呢,顾名思义,策略就是一个状态或者是action的分布,梯度就是我们的老朋友,梯度上升或者梯度下降。  就是说,J函数的自变量是西塔,然后对J求梯度,进而去更新西塔,比如说,J西塔,是一个该策略下预测状态值,也可以说是策略值,那么我们当然希望这个策略值越大越好,于是就要使用梯度上升......
  • 机器学习笔记(一)从波士顿房价预测开始,梯度下降
    从波士顿房价开始目标其实这一章节比较简单,主要是概念,首先在波士顿房价这个问题中,我们假设了一组线性关系,也就是如图所示我们假定结果房价和这些参数之间有线性关系,即:然后我们假定这个函数的损失函数为均方差,即:那么就是说,我们现在是已知y和x,来求使得这个损失函数Loss最小......
  • 共轭梯度法
    共轭梯度法适应于求解非线性优化问题线性共轭梯度法和非线性共轭梯度法1共轭方向梯度下降法和共轭方向法优过程的区别:可以发现:共轭方向法分别按两个轴的方向搜索(逐维搜索)每次搜索只更新迭代点的一个维度保证每次迭代的那个维度达最优共轭方向法的两个搜索方向正交(特......
  • 梯度下降法
    1梯度下降法\(\qquad\)梯度下降法又称最速下降法,是最优化方法中最基本的一种方法。所有的无约束最优化问题都是在求解如下的无约束优化问题:$$\min_{x\inR^n}f(x)$$将初始点\(x_0\)逐步迭代到最优解所在的点\(x^*\),那么考虑搜索点迭代过程:$$x_{t+1}=x_t+\gamma_td_t$$......
  • [最优化方法笔记] 梯度下降法
    1.梯度下降法无约束最优化问题一般可以概括为:\[\min_{x\in\mathbb{R}^n}f(x)\]通过不断迭代到达最优点\(x^*\),迭代过程为:\[x^{k+1}=x^k+\alpha_kd^k\]其中\(d^k\)为当前的搜索方向,\(\alpha_k\)为当前沿着搜索方向的步长。我们需要寻找可以不断使得\(f(x^{......
  • 利用率夹紧(Utilization Clamping) 【ChatGPT】
    https://www.kernel.org/doc/html/v6.6/scheduler/sched-util-clamp.html利用率夹紧1.简介利用率夹紧,也称为utilclamp或uclamp,是一种调度器功能,允许用户空间帮助管理任务的性能需求。它是在v5.3版本中引入的。CGroup支持在v5.4中合并。Uclamp是一种提示机制,允许调度器了解......
  • 机器学习-线性回归-小批量-梯度下降法-04
    1.随机梯度下降法梯度计算的时候随机抽取一条importnumpyasnpX=2*np.random.rand(100,1)y=4+3*X+np.random.randn(100,1)X_b=np.c_[np.ones((100,1)),X]n_epochs=10000learn_rate=0.001m=100theta=np.random.randn(2,1)forepoch......
  • 机器学习-线性回归-梯度下降法-03
    1.梯度下降法梯度:是一个theta与一条样本x组成的映射公式可以看出梯度的计算量主要来自于左边部分所有样本参与--批量梯度下降法随机抽取一条样本参与--随机梯度下降法一小部分样本参与--小批量梯度下降法2.epoch与batchepoch:一次迭代wt-->wt+1......
  • 12-梯度计算方法
    1.图像梯度-Sobel算子流程: 2.计算绝对值dx为1水平方向: 3.计算绝对值dy为1竖直方向: 4.求出x和y以后,再进行求和: 5.不建议直接设置dx为1,dy为1会造成图像不饱和: 6.推荐使用,dx和dy分别计算进行梯度计算处理: 7.不推荐使用,直接将dx(水平方向)和dy(竖直方向)同时设置为1......