开端
若在网络的 forward 过程中使用 clamp 函数对数据进行截断,可能会阻断梯度传播。即,梯度变成零。
不妨先做一个实验。定义一个全连接网络 fc
,通过输入 input_t
获得结果 pred
,其值为 \(0.02\):
from torch.nn import functional as F
import torch.nn as nn
import torch
fc = nn.Linear(in_features=1, out_features=1, bias=True)
fc.weight.data = torch.tensor([[0.01]])
fc.bias.data = torch.tensor([[0.01]])
input_t = torch.tensor([[1.0]], dtype=torch.float32)
pred = fc(input_t)
print(pred) # pred = 0.02
对 pred
进行反向传播,可以看到网络的权重都有梯度:
pred.backward()
print(fc.weight.grad) # grad = 1.0
print(fc.bias.grad) # grad = 1.0
如果使用 torch.clamp()
将 pred
结果截断在 \((0.1,0.9)\) 范围内,会发现梯度消失了:
fc.zero_grad()
pred = fc(input_t)
pred = torch.clamp(pred, min=0.1, max=0.9) # pred = 0.1
pred.backward()
print(fc.weight.grad) # grad = 0.0
print(fc.bias.grad) # grad = 0.0
解决方法
我们需要跳过 torch.clamp()
的梯度运算。
若是用 TensorFlow,使用 tf.stop_gradient
可以很方便地让 clamp 不参与梯度计算。在 PyTorch 就有点绕了。
最后我总结出了一个相对优雅的解决方案:
def nclamp(input, min, max):
return input.clamp(min=min, max=max).detach() + input - input.detach()
参考来源
- PyTorch: clamp函数与梯度的关系 https://blog.csdn.net/DragonGirI/article/details/132319710
- https://discuss.pytorch.org/t/exluding-torch-clamp-from-backpropagation-as-tf-stop-gradient-in-tensorflow/52404
- https://www.tensorflow.org/api_docs/python/tf/stop_gradient