首页 > 其他分享 >2-3动态计算图

2-3动态计算图

时间:2024-02-05 13:45:54浏览次数:25  
标签:loss y2 tensor torch 计算 y1 动态 grad

本节我们将介绍Pytorch的动态计算图。

包括:

  • 动态计算图简介
  • 计算图中的Function
  • 计算图和反向传播
  • 叶子节点和非叶子节点
  • 计算图在TensorBorad中的可视化

1.动态计算图简介

Pytorch的计算图是由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。

Pytorch中的计算图是动态图。这里的动态主要有两重含义。

第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到结果。

第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用autograd.grad方法计算 了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

  • 计算图的正向传播是立即执行的
import torch

w = torch.tensor([[3.0, 1.0]], requires_grad=True)
b = torch.tensor([[3.0]], requires_grad=True)
X = torch.randn(10, 2)
Y = torch.randn(10, 1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关  # 转置
loss = torch.mean(torch.pow(Y_hat - Y, 2))

print(loss.data)
print(Y_hat.data)

"""
tensor(24.1702)
tensor([[-1.6971],
        [ 6.2795],
        [ 8.3266],
        [-0.0096],
        [ 1.7790],
        [ 3.9209],
        [ 2.9720],
        [ 7.0749],
        [ 4.8034],
        [ 3.9983]])
"""
  • 计算图在反向传播后立即销毁
import torch

w = torch.tensor([[3.0, 1.0]], requires_grad=True)
b = torch.tensor([[3.0]], requires_grad=True)
X = torch.randn(10, 2)
Y = torch.randn(10, 1)
Y_hat = X@w.t() + b
loss = torch.mean(torch.pow(Y_hat - Y, 2))

# 计算图在反向传播后立即销毁,如果需要保留计算图,需要设置retain_graph=True
loss.backward()
# loss.backward()  # 如果再次执行反向传播,将会报错

2.计算图中的Function

计算图中的张量我们已经比较熟悉了,计算图中的另外一种节点是Function,实际上就是Pytorch中各种对张量操作的函数

这些Function和我们Python中的函数有一个较大的区别,那就是它同时包括正向计算逻辑和反向传播的逻辑。

我们可以通过继承torch.autograd.Function来创建这种支持反向传播的Function

class MyReLU(torch.autograd.Function):

    # 正向传播逻辑,可以用ctx存储一些值,供反向传播使用
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)  # 限幅。将input的值限制在[min, max]之间,并返回结果。o

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
import torch

w = torch.tensor([[3.0, 1.0]], requires_grad=True)
b = torch.tensor([[3.0]], requires_grad=True)
X = torch.tensor([[-1.0, -1.0], [1.0, 1.0]])
Y = torch.tensor([[2.0, 3.0]])

relu = MyReLU.apply  # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b)
loss = torch.mean(torch.pow(Y_hat - Y, 2))

loss.backward()

print(w.grad)
print(b.grad)

"""
tensor([[4.5000, 4.5000]])
tensor([[4.5000]])
"""

# Y_hat的梯度函数即是我们自己所定义的MyReLU.backward

print(Y_hat.grad_fn)

"""
<torch.autograd.function.MyReLUBackward object at 0x0000025B574E4840>
"""

3.计算图与反向传播

import torch

x = torch.tensor(3.0, requires_grad=True)
y1 = x + 1
y2 = 2 * x
loss = (y1-y2) ** 2

loss.backward()

loss.backward()语句调用后,依次发生以下计算过程。

1、loss自己的grad梯度赋值为1,即对自身的梯度为1。

2、loss根据其自身梯度以及关联的backward方法,计算出其对应的自变量即y1和y2的梯度,将该值赋值到y1.grad和y2.grad

3、y2和y1根据其自身梯度以及关联的backward方法,分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加

注意:123步骤的求梯度顺序和对多个梯度值的累加规则恰好是求导链式法则的程序表述

正因为求导链式法则衍生的梯度累加规则,张量的grad梯度不会自动清零,在需要的时候手动置零。

4.叶子节点和非叶子节点

执行下面代码,我们会发现loss.grad并不是我们期望的1,而是None

类似的,y1.grad以及y2.grad也是None

这是为什么呢?这是由于它们不是叶子节点张量

在反向传播过程中,只有is_leaf=True的叶子节点,需要求导的张量的导数结果才会被最后保留下来

那什么是叶子节点张量呢?叶子节点张量需要满足两个条件。

1、叶子节点张量是由用户直接创建的张量,而非由某个Function通过计算得到的张量

2、叶子节点张量的requires_grad属性必须为True

Pytorch设计这样的规则主要是为了节点内存或者显存空间,因为几乎所有的时候,用户只会关心他自己直接创建的张量的梯度

所有依赖于叶子节点张量的张量,其requires_grad属性必定是True的,但其梯度值只在计算过程中被用到,不会最终存储到grad属性中

如果需要保留中间计算结果的梯度到grad属性中,可以使用retain_grad方法。如果仅仅是为了调试代码查看梯度值,可以利用register_hook打印日志

import torch

x = torch.tensor(3.0, requires_grad=True)
y1 = x + 1
y2 = 2 * x
loss = (y1-y2) ** 2

loss.backward()

print('loss.grad:', loss.grad)
print('y1.grad:', y1.grad)
print('y2.grad:', y2.grad)
print(x.grad)

"""
loss.grad: None
y1.grad: None
y2.grad: None
tensor(4.)
"""

print(x.is_leaf)
print(y1.is_leaf)
print(y2.is_leaf)
print(loss.is_leaf)

"""
True
False
False
False
"""
# 利用retain_grad可以保留非叶子节点的梯度值,利用register_hook可以查看非叶子节点的梯度值

import torch

# 正向传播
x = torch.tensor(3.0, requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2) ** 2

# 非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad:', grad))
y2.register_hook(lambda grad: print('y2 grad:', grad))
loss.retain_grad()

# 反向传播
loss.backward()

print('loss.grad:', loss.grad)
print('x.grad:', x.grad)

"""
y2 grad: tensor(4.)
y1 grad: tensor(-4.)
loss.grad: tensor(1.)
x.grad: tensor(4.)
"""

5.计算图在TensorBoard中的可视化

可以利用torch.utils.tensorboard将计算图导出到TensorBoard进行可视化

from torch import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.randn(2, 1))
        self.b = nn.Parameter(torch.zeros(1, 1))

    def forward(self, x):
        y = x@self.w + self.b
        return y

net = Net()

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./data/tensorboard')
writer.add_graph(net, input_to_model=torch.rand(10, 2))
writer.close()

%load_ext tensorboard

from tensorboard import notebook
notebook.list()

"""
No known TensorBoard instances running.
"""

notebook.start('--logdir ./data/tensorboard/')

标签:loss,y2,tensor,torch,计算,y1,动态,grad
From: https://www.cnblogs.com/lotuslaw/p/18007800

相关文章

  • springboot之ImportBeanDefinitionRegistrar动态注入
    SpringBoot中的使用在SpringBoot内置容器的相关自动配置中有一个ServletWebServerFactoryAutoConfiguration类。该类的部分代码如下:@Configuration(proxyBeanMethods=false)@AutoConfigureOrder(Ordered.HIGHEST_PRECEDENCE)@ConditionalOnClass(ServletRequest.class)@Con......
  • 计算机体系结构
    计算机体系结构是指计算机系统的设计和组织方式,它包括计算机硬件、软件、数据存储和通信等方面。计算机体系结构的发展经历了多个阶段,从简单的单处理器系统到复杂的多核系统和分布式系统。在现代计算机体系结构中,处理器是计算机系统的核心组件,它负责执行指令和处理数据。处理器的......
  • 【动态规划】最长公共子串
    目录题目应用1:最长公共子串题目解题思路边界条件状态转移代码实现应用2:Leetcode718.最长重复子数组题目解题思路代码实现解题思路方法一:动态规划初始条件状态转移复杂度方法二:滑动窗口复杂度代码实现题目应用1:最长公共子串题目给定两个字符串text1和text2,返回这两个......
  • 第三章——计算机进行小数运算时出错的原因
    在使用小数运算时计算机也会出错,这是因为有些十进制的小数无法转换为二进制数———例如二进制数0.0000对应的十进制数是0,二进制数0.0001对应的十进制数为0.625,由此得之二进制数是连续的而十进制数不是连续的,那十进制数0~0.625之间的数就无法用二进制数表示,进而出现错误。那实际......
  • 安卓动态链接库文件体积优化探索实践
    背景介绍应用安装包的体积影响着用户下载量、安装时长、用户磁盘占用量等多个方面,据GooglePlay统计,应用体积每增加6MB,安装的转化率将下降1%。   安装包的体积受诸多方面影响,针对dex、资源文件、so文件都有不同的优化策略,在此不做一一展开,本文主要记录了在研发时针对动态......
  • Drvsetup.dll 是 Windows 操作系统中的一个动态链接库文件,用于设备驱动程序的安装和配
     Drvsetup.dll是Windows操作系统中的一个动态链接库文件,用于设备驱动程序的安装和配置过程中。该文件通常位于C:\Windows\System32文件夹下。Drvsetup.dll主要负责设备驱动程序的安装和配置过程中的一些核心功能,包括驱动程序的复制、注册、配置和卸载等。在设备驱动程序......
  • 我与计算机
    《我与计算机》以其细腻的笔触、真实的故事,为我们展示了一个与计算机相伴的人生。张淑雅与李文静,两位截然不同的主人公,她们的故事为我们揭示了计算机如何从一个神秘的新鲜事物,转变为生活中不可或缺的一部分。在第三章中,我们跟随张淑雅的脚步,回到了那个初识计算机的时代。计算机对......
  • drvstore.dll 是 Windows 操作系统中的一个动态链接库文件
    drvstore.dll是Windows操作系统中的一个动态链接库文件,用于存储和管理设备驱动程序的信息。它通常位于系统目录(如C:\Windows\System32)下。drvstore.dll的主要作用是维护设备驱动程序的备份和安装信息,以便在需要时能够快速找到并加载正确的驱动程序。当用户连接新设备或更新设......
  • 对于计算机程序的理解
    计算机程序是指一组计算机指令的集合,它是按照特定顺序排列的指令集合。程序的作用是根据输入数据或条件,经过一系列的计算和处理,输出所需的结果。程序通常可以分为系统程序和应用程序两大类。系统程序是计算机的基本软件,负责管理计算机的硬件资源和应用程序的运行。应用程序是为了......
  • 计算机进行小数运算时出错的原因
    看完第三章之后我知道了运算出错的原因是有一些十进制数的小数无法转换成二进制数还有就是小数是使用浮点数表示,浮点数是指符号尾数基数和指数这四个部分组成浮点数的表示右很多种其中最为普遍的是IEEE标准符号部分是指使用一个数据位来表示数值的符号;位数部分使用的是正则表达......