首页 > 其他分享 >torch.detach()、torch.detach_()

torch.detach()、torch.detach_()

时间:2023-03-06 23:22:05浏览次数:27  
标签:tensor torch print detach grad out

训练网络的时候希望保持一部分网络参数不变,只对其中一部分的参数进行调整;或训练部分分支网络,并不让其梯度对主网络的梯度造成影响,这时可以使用detach()切断一些分支的反向传播。



1. tensor.detach()

返回一个新的 \(tensor\),从当前计算图中分离下来,但仍指向原 \(tensor\) 的存放位置,不同之处是requires_grad参数为 \(False\),得到的这个 \(tensor\) 永远不需要计算梯度。不具有 \(grad\)。

注意:

  • 即使之后将requires_grad设置为 \(True\) ,也不会具有梯度。
  • 使用detach返回的 \(tensor\) 和原始的 \(tensor\) 共同一个内存,即一个修改另一个也会跟着改变。

示例:参数为requires_grad=True,不具有梯度;反向传播 \(backward()\) 有梯度。

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)

out = a.sigmoid()
out.sum().backward()
print(a.grad)
None

tensor([0.1966, 0.1050, 0.0452])

示例:使用detach()分离 \(tensor\), 原始 \(tensor\) 可进行 \(backward()\)。

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)

out = a.sigmoid()
print(out)

# 添加 detach(), c 的 requires_grad 为 False
c = out.detach()
print(c)

# c 是添加 detach() 后的,不影响 out 的 backward()
out.sum().backward()
print(a.grad)
None

tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward0>)

tensor([0.7311, 0.8808, 0.9526])

tensor([0.1966, 0.1050, 0.0452])

示例:使用detach()分离 \(tensor\),新的 \(tensor\) 不能进行 \(backward()\)。

\(c\)、\(out\) 的区别是 \(c\) 没有梯度,\(out\) 有梯度。

#使用 c 进行反向传播
c.sum().backward()
print(a.grad)
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

示例:使用detach()分离 \(tensor\),对新的 \(tensor\) 更改,原始 \(tensor\) 也会更改,两者都不能进行 \(backward()\)。

import torch

a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)

out = a.sigmoid()
print(out)

# 添加 detach(), c 的 requires_grad 为 False
c = out.detach()
print(c)

# 对 c 进行更改,会影响 out
c.zero_()
print(c)
print(out)

# 对 c 进行更改,会影响 out 进行 backward()
out.sum().backward()
print(a.grad)
None

tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward0>)

tensor([0.7311, 0.8808, 0.9526])

tensor([0., 0., 0.])
tensor([0., 0., 0.], grad_fn=<SigmoidBackward0>)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation:...


2. tensor.detach_()

detach()detach_()的区别:detach_()对本身的更改,detach()则生成新的的 \(tensor\),后面想反悔只需对原来的计算图进行操作即可。



标签:tensor,torch,print,detach,grad,out
From: https://www.cnblogs.com/keye/p/17185923.html

相关文章

  • Pytorch中norm(几种范数norm的详细介绍)
    1.范数(norm)的简单介绍概念:距离的定义是一个宽泛的概念,只要满足非负,自反,三角不等式就可以称之为距离。范数是一种强化了的距离概念,它在定义上比距离多了一条数乘的运算法......
  • Windows Torch 安装
    首先,电脑要有显卡(没有显卡建议查cpu版本Torch安装和使用)一、基础装备(一)、Pycharm下载地址:DownloadPyCharm:PythonIDEforProfessionalDevelopersbyJetBrains......
  • PyTorch中的dim
    PyTorch中对tensor的很多操作如sum,softmax等都可以设置dim参数用来指定操作在哪一维进行。PyTorch中的dim类似于numpy中的axis。dim与方括号的关系创建一个矩阵a=to......
  • 机器学习日志 手写数字识别 pytorch 神经网络
    我是链接第一次用pytorch写机器学习,不得不说是真的好用pytorch的学习可以看这里,看看基本用法就行,个人感觉主要还是要看着实践代码来学习总结了几个点:1.loss出现nan这......
  • torch.nn.Embedding使用详解
    torch.nn.Embedding:随机初始化词向量,词向量值在正态分布N(0,1)中随机取值。输入:torch.nn.Embedding(num_embeddings,–词典的大小尺寸,比如总共出现5000个词,那就输入5000......
  • 在Windows上安装torch遇到的部分问题
    1、版本问题老师新买的这台机器是RTX3060,没动显卡驱动,直接安装的CUDA,装的11.4,完全按照这篇blog来的,非常舒服:https://blog.csdn.net/qq_45041871/article/details/1279500......
  • 安装pytorch报错 ERROR: Could not install packages due to an OSError: [Errno 28]
    windos安装,报错如下  看了不少回答,大概是缓存和内存满了我的C盘只给了70G,然后意外发现只剩下3G多了,先用系统自带的清理工具清理了一下,然后腾讯电脑管家“工具箱”中......
  • pytorch_debug
    1、报错信息1.1、出错位置1image=Image.open('./img.png')2#图像预处理3transforms=transforms.Compose([transforms.Resize(256),4......
  • Yolov5环境报错解决:No labels found in 与 Could not run 'torchvision::nms' with ar
    问题记录yolov5环境1Nolabelsfoundin(Done)报错内容F:\WorkSpace\GitSpace\yolov5>pythontrain-self.pytrain-self:weights=weights/yolov5s.pt,cfg=models/y......
  • 01 Pytorch的数据载体张量与线性回归
    Pytorch的数据载体张量与线性回归Pytorch官方英文文档:https://pytorch.org/docs/stable/torch.html?Pytorch中文文档:https://pytorch-cn.readthedocs.io/zh/latest/1.......