首页 > 其他分享 >pytorch的梯度求导问题

pytorch的梯度求导问题

时间:2022-11-11 20:00:42浏览次数:38  
标签:tensor 梯度 x2 pytorch 求导 x3 grad

实验数据:

x1 = torch.tensor([1, 2], dtype=torch.float, requires_grad=True)
x2 = torch.tensor([3, 4], dtype=torch.float, requires_grad=True)
x3 = torch.tensor([5, 6], dtype=torch.float, requires_grad=True)
y = (torch.pow(x1, 3) + torch.pow(x2, 2) + x3).sum()
y.backward()
x4 = x2.clone()
print(x1.grad, x2.grad, x3.grad, x4.grad)
# 梯度分别是 3x^2, 2x, 1, None
# tensor([ 3., 12.]) tensor([6., 8.]) tensor([1., 1.]) None

求导是通过backward()来实现的,最后对象一定是一个scalar,比如y.backward()。这里的y是一个和一些需要求导的tensor相关的数值。
则可以通过x.grad去查看tensor x的梯度。

tensor有一个属性requires_grad,决定是否求梯度。
可以通过detach()或者detach_()放弃求导。注意,在pytorch中_代表是否修改自身。比如x.detach()只是返回一个放弃求导的tensor,x本身并没有放弃。但是需要注意下文所说的浅拷贝问题,即使y=x.detach()返回了一个放弃求导的tensor,此时x可求导y不可求,但是对y做修改仍然会影响到x。
注意python中的=是浅拷贝,如果用a = b去构造a的话,a的变化会在b上面进行修改。

x2.detach_()
print(x2)  # x2本身不可求导了  tensor([3., 4.])
x5 = x3.detach()
print(x3, x5)  # 一个可求导一个不可 tensor([5., 6.], requires_grad=True) tensor([5., 6.])
x5[0] = 100
print(x3, x5)  # 浅拷贝 还是会影响彼此 tensor([100.,   6.], requires_grad=True) tensor([100.,   6.])

所以如果想构造一个相同的tensor,可以通过clone()来实现,如a = b.clone()。需要注意的是,clone出来的tensor的requires_grad始终为True,并且不会复制原来的tensor的grad,并且丢失自己原来的grad。

# 原来的x3和x1都有梯度值
x3 = x1.clone()
print(x3, x3.grad)  # tensor([1., 2.], grad_fn=<CloneBackward>) None  # 可以看到x3可求导,但是梯度值清空为None了,丢失了原来的梯度

对于求出来的梯度grad是不清空的,如果多次求导,梯度会累加。如果想清空某个tensor的梯度,可以使用grad.zero(),比如x.grad.zero()
举例:

x2.grad.zero_()  # 清空梯度
z = (torch.pow(x1, 3) + torch.pow(x2, 2) + x3).sum()
z.backward()
print(x1.grad, x2.grad, x3.grad, x4.grad)
# 注意此时除了x2之外其他tensor梯度是原来梯度的两倍,因为其他都累加了一次新的,而x2梯度清空了
# tensor([ 6., 24.]) tensor([12., 16.]) tensor([2., 2.]) None

标签:tensor,梯度,x2,pytorch,求导,x3,grad
From: https://www.cnblogs.com/ReflexFox/p/16881562.html

相关文章

  • Pytorch随机种子设置及原理
    深度学习网络模型中初始的权值参数通常都是初始化成随机数,而使用梯度下降法最终得到的局部最优解对于初始位置点的选择很敏感,下面介绍Pytorch中随机种子的设置及其原理。......
  • DataLoader 每次迭代返回BatchEncoding还是dict类型依pytorch的版本而定
    发现DataLoader在不同的pytorch版本上,执行dataset的__item__会返回不同的效果。pytorch在1.12.1上,每一次迭代会返回BatchEncoding这个类型(可能会比这个版本低也......
  • Pytorch笔记:dataloader的collate_fn参数在加载数据集时的作用
    1.前言最近在复现MCNN时发现一个问题,ShanghaiTech数据集图片的尺寸不一,转换为tensor后的shape形状不一致,无法直接进行多batch_size的数据加载。经过查找资料,有人提到可以......
  • 是谁的请求导致我的系统一直抛异常?
    作者:屿山、十眠在线上环境中,请求错综复杂,如果有某个请求出现了不符合预期的情况,我们往往会先需要确定这个请求在实际环境中是由哪个Controller来处理的。通常情况下,我们......
  • conda 虚拟环境安装pytorch & d2l包
    conda虚拟环境安装pytorch1、首先,conda终端添加清华镜像源,可以加快安装速度。2、确认电脑匹配的CUDA型号,(例如,9.2)3、新建一个虚拟环境,在终端运行condacreate-nXXXp......
  • PyTorch中F.cross_entropy()函数
    对PyTorch中F.cross_entropy()的理解PyTorch提供了求交叉熵的两个常用函数:一个是F.cross_entropy(),另一个是F.nll_entropy(),是对F.cross_entropy(input,target)中参数targ......
  • pytorch张量索引
    一、pytorch返回最值索引1官方文档资料1.1torch.argmax()介绍 返回最大值的索引下标函数:torch.argmax(input,dim,keepdim=False)→LongTensor返回值:Retur......
  • pytorch tensor 张量常用方法介绍
    1. view()函数PyTorch 中的view()函数相当于numpy中的resize()函数,都是用来重构(或者调整)张量维度的,用法稍有不同。>>>importtorch>>>re=torch.tensor([1,......
  • pytorch TensorDataset和DataLoader区别
    TensorDatasetTensorDataset可以用来对tensor进行打包,就好像python中的zip功能。该类通过每一个tensor的第一个维度进行索引。因此,该类中的tensor第一维度必须......
  • pytorch入门
    初衷:看不懂论文开源代码参考:B站小土堆(土堆yyds~)   PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】_哔哩哔哩_bilibili 1.环境配置参考:(39条消息)win10......