首页 > 其他分享 >Pytorch相关(第二篇)

Pytorch相关(第二篇)

时间:2024-09-07 11:52:23浏览次数:4  
标签:自定义 grad 梯度 传播 Pytorch input 相关 backward 第二篇

Pytorch自动梯度法,实现自定义向前 向后传播方法

在 PyTorch 中,自定义自动求导的功能可以通过实现继承自 torch.autograd.Function 的类来实现。这允许您定义自己的前向传播(forward)和反向传播(backward)逻辑。下面是如何自定义实现向前和向后传播的详细步骤和示例代码。

自定义 autograd 制作步骤

  1. 创建继承自 torch.autograd.Function 的类。
  2. 实现 forward 方法:计算前向传播并保存任何需要在反向传播中使用的张量。
  3. 实现 backward 方法:计算反向传播,使用 grad_output 计算每个输入的梯度。
  4. 使用自定义的 Function:在训练或评估中使用您的自定义操作。

示例代码:自定义平方函数

下面的示例自定义了一个平方操作的前向和反向传播:

import torch

class MySquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        """前向传播"""
        ctx.save_for_backward(input)  # 保存输入张量以用于反向传播
        return input ** 2  # 返回平方结果

    @staticmethod
    def backward(ctx, grad_output):
        """反向传播"""
        input, = ctx.saved_tensors  # 获取保存的输入张量
        grad_input = 2 * input * grad_output  # 计算关于输入的梯度
        return grad_input  # 返回梯度

# 使用自定义的 Function
x = torch.tensor([2.0, 3.0, 4.0], requires_grad=True)  # 创建输入张量
square = MySquareFunction.apply  # 获取自定义函数的引用

# 前向传播
y = square(x)
print("Output:", y)  # 输出: tensor([4.0, 9.0, 16.0], grad_fn=<MySquareFunctionBackward>)

# 计算损失并进行反向传播
loss = y.sum()
loss.backward()  # 计算梯度

# 打印输入的梯度
print("Gradient:", x.grad)  # 输出: tensor([4.0, 6.0, 8.0])

代码解释

  1. 自定义类 MySquareFunction

    • 继承了 torch.autograd.Function
    • 实现了两个静态方法 forward() 和 backward()
  2. 前向传播:

    • 在 forward 方法中,计算输入的平方并将输入张量保存到上下文中,以便在反向传播中使用。
    • 使用 ctx.save_for_backward(input) 保存输入。
  3. 反向传播:

    • 在 backward 方法中,从上下文中获取保存的输入张量。
    • 计算关于输入的梯度公式 d(output)d(input)=2⋅inputd(input)d(output)​=2⋅input。
    • 注意,grad_output 是由后续层传播来的梯度。
  4. 使用自定义的 Function:

    • 创建一个需要计算梯度的张量 x,调用自定义的 square 函数进行前向传播。
    • 计算损失并进行反向传播,调用 loss.backward() 计算输入的梯度。

总结

通过以上示例,您可以看到如何在 PyTorch 中自定义前向和反向传播的逻辑。自定义 torch.autograd.Function 允许您实现复杂的操作和梯度计算,同时保留 PyTorch 的自动求导功能。这种方式在编写新的模型或需要特定行为的操作时尤为有用。您可以根据具体需求修改上述示例,实现自己的自定义操作。

标签:自定义,grad,梯度,传播,Pytorch,input,相关,backward,第二篇
From: https://www.cnblogs.com/lovebay/p/18401507

相关文章

  • Pytorch相关(第三篇)
    torch.nn.Module定义简单神经网络模型在PyTorch中,torch.nn.Module 是构建神经网络的基本构件。每一个用于构建神经网络的类都通常应该继承自 torch.nn.Module。该类提供了许多便利的功能,其中之一就是实现了 __call__ 方法。__call__ 方法的作用__call__ 方法使得 tor......
  • Pytorch相关(第一篇)
    torch.autograd.Function使用方法torch.autograd.Function 是PyTorch提供的一个接口,用于自定义自动求导的操作。通过继承这个类,你能够定义自定义的前向和反向传播逻辑。下面是使用 torch.autograd.Function 的基本步骤以及示例。自定义 Function 的步骤继承 torch.au......
  • rk3566 rk3588 Android11/13 给内置APP添加相关权限,无需手动同意APP权限
    现象:打开APP会跳出权限弹窗,给APP相关权限才能够使用APP。目录1、adb查看logcat2、在SystemUIService.java内给APP添加加权限3、开机自启动APP4、executeCMD函数1、adb查看logcat打开APP,logcat会打印APP包名。我这边包名是com.jhooit.endoscope2、在SystemUIService.......
  • SSM相关面试题
    1Spring1.1什么是SpringIOC和DI?①控制反转(IOC):Spring容器使用了工厂模式为我们创建了所需要的对象,我们使用时不需要自己去创建,直接调用Spring为我们提供的对象即可,这就是控制反转的思想。② 依赖注入(DI):Spring使用JavaBean对象的Set方法或者构造方法为我们在创建......
  • 微服务相关面试题
     1Springboot1.1讲一讲SpringBoot自动装配的原理在SpringBoot项目中的引导类上有一个注解@SpringBootApplication,这个注解是对三个注解进行了封装,分别是: @SpringBootConfiguration @EnableAutoConfiguration @ComponentScan其中@EnableAutoConfiguration是实现自......
  • 五子棋AI:实现逻辑与相关背景探讨(上)
    绪论本合集将详细讲述如何实现基于群只能遗传算法的五子棋AI,采用C++作为底层编程语言本篇将简要讨论实现思路,并在后续的文中逐一展开了解五子棋五子棋规则五子棋是一种经典的棋类游戏,规则简单却充满策略性。游戏在一个19×19的棋盘上进行(也可以使用13×13或15×15的棋盘)。......
  • PyTorch深度学习教程第二章-PyTorch 简介
    文章目录前言一、高层次理解PyTorch二、开始使用PyTorch清单2.1:使用pip安装PyTorch清单2.2:使用conda安装PyTorch清单2.3:一个简单的PyTorch测试程序输出2.1:测试程序的输出结果三、PyTorch的应用四、PyTorch的优点和限制五、PyTorch与TensorFlow的比......
  • C++vector类相关OJ练习
    个人主页:C++忠实粉丝欢迎点赞......
  • 脑机接口定义及相关概念
    1什么是脑机接口脑机接口(Brain-ComputerInterface,简称,BCI)是指一种系统或设备,它通过解码大脑的电生理信号来与外部计算机或设备进行直接的通讯。BCI的目的是在不依赖身体运动的情况下实现大脑与计算机之间的信息交换。2相关概念2.1脑电图(EEG)最常用的脑机接......
  • [Linux][防火墙]Centos7 防火墙相关操作以及 添加开放端口
    1、firewalld的基本使用启动:     systemctl   startfirewalld查看状态: systemctl   statusfirewalld 停止:    systemctl   disablefirewalld禁用:     systemctl   stopfirewalld2.systemctl是CentOS7的服务管理......