首页 > 其他分享 >CUDA教学(2):反向传播

CUDA教学(2):反向传播

时间:2024-05-28 20:11:08浏览次数:31  
标签:dL interp 传播 points 反向 CUDA 教学 feats cuda

cuda 没有提供自动求导机制,因此我们需要完成以下两步,实现反向传播。

image

一、计算所有 trainable 参数的偏微分

判断哪些参数是 trainable 的?

本例中,输入 f 的坐标是固定的,所以 uvw 的值也是固定的,因此只需要求 feats_interp对各个顶点的特征量 \(f_i\) 的偏微分即可。

如何进行反向传播?

思路:先计算正向传播的 Loss 值,然后对各个顶点的特征量 \(f_i\) 求偏微分。

image

也就是利用了“链式法则”进行计算,但是必须要知道 Loss 对于 feats_interp 的偏导结果,因此我们自己实现反向传播函数需要传入这个参数。

二、代码实现

实现反向传播函数的注意事项?

(1)求偏导得到的维度和输入维度保持一致,因此 dL_dfeats 的维度是 [N,8,F]dL_dfeats_interp 的维度是 [N,F]

(2)需要把前向传播和反向传播函数都包裹在新类中,它是 torch.autograd.Function 的子类,如下面代码所示:

class Trilinear_interpolation_cuda(torch.autograd.Function):
    @staticmethod
    def forward(ctx, feats, points):# ctx 即 context 保存了传播过程中的状态量
        feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points)

        ctx.save_for_backward(feats, points)

        return feat_interp

    @staticmethod
    def backward(ctx, dL_dfeat_interp):
        feats, points = ctx.saved_tensors

        dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points)

        # forward 输入有几个参数,这里就要回传几个参数,如果没有则写 None
        # 我们这里求的是对 feats 的偏导,所以写在第一个位置
        return dL_dfeats, None

最后的返回值和前向传播的参数保持一致,如果不需要求偏导则需要对应写为 None。

(3)传入的参数都要添加求导支持。

N = 65536; F = 256
rand = torch.rand(N, 8, F, device='cuda')
feats = rand.clone().requires_grad_()

(4)调用:前向传播需要使用 apply 方法,后向传播直接调用:

# 1. 前向传播
out_cuda = Trilinear_interpolation_cuda.apply(feats, points)

# 2. 计算 Loss(只是简单的加起来作为损失)
loss = out_cuda.sum()

# 2. Pytorch 会自动计算 dL_dfeat_interp,也就是 Loss 关于 out_cuda 的梯度,传递给 backward 函数作为参数
loss.backward()

标签:dL,interp,传播,points,反向,CUDA,教学,feats,cuda
From: https://www.cnblogs.com/7ytr5/p/18218756

相关文章

  • AI绘画整合包最新Stable Diffusion安装包+教程+模型+插件+动作来了(纯教学)
    首先了解一下AI绘画工具,介绍一下什么是StableDiffusion,模型的主要功能和作用StableDiffusion(简称SD),是一种先进的人工智能技术。这项技术的核心能力在于,它能够根据用户提供的文字描述,生成丰富且细致的图像内容。不仅如此,SD还能够处理图像修补、扩展以及基于文本指导的图像转......
  • ChatGPT-Next-Web一键部署搭建教学:Github开源+Vercel+API 快速部署
    ChatGPT-Next-Web一键部署搭建教学:Github开源+Vercel+API快速部署文章目录ChatGPT-Next-Web一键部署搭建教学:Github开源+Vercel+API快速部署导语:需要用到的链接汇总1、github项目直达地址2、vercel服务器直达地址3、三方API获取一、Github项目`star`+Vercel......
  • ASP+ACCESS教学评估系统
    摘 要:本文从计算机系的实际情况出发,经过对计算机系本科评估事项的一番考察和分析,确立了计算机系本科评估网站具体实现功能。并阐述网站的结构设计和功能设计,实现用户的分类显示、最近新闻的提示、留言板功能等。管理员用户可以通过Web浏览器,以人机交互式的客户端程序实现对本......
  • 使用 Unity Barracuda 和 Compute Shader,Yolov2 进行高效物体识别
    前言通过整合UnityBarracuda和TinyYOLOv2模型,开发者可以在Unity中实现高效的实时物体识别功能。这种技术不仅可以增强游戏和应用的交互性,还可以应用于虚拟现实(VR)和增强现实(AR)等创新项目中,为用户创造更加沉浸和动态的体验。TinyYOLOv2模型概述TinyYOLOv2是YOLO(You......
  • ubuntu24.04安装cuda12.5版本
    概述最近新学习的JAX在使用时,提示:2024-05-2619:46:32.016388:Wexternal/xla/xla/service/gpu/nvptx_compiler.cc:760]TheNVIDIAdriver'sCUDAversionis12.2whichisolderthantheptxasCUDAversion(12.5.40).Becausethedriverisolderthantheptxasvers......
  • 【C语言】C语言零基础纯干货教学(下)
    个人主页~C语言零基础纯干货教学(上)C语言零基础纯干货教学(中)C语言入门四、数组1、概念2、一维数组(1)一维数组创建(2)数组的初始化3、一维数组的使用(1)访问下标(2)数组输入和打印4、一维数组在内存中的存储5、sizeof计算数组元素个数6、二维数组(1)概念(2)二维数组的创建7、......
  • 2024电工杯数学建模B题Python代码+结果表数据教学
    2024电工杯B题保姆级分析完整思路+代码+数据教学B题题目:大学生平衡膳食食谱的优化设计及评价 以下仅展示部分,完整版看文末的文章importpandasaspddf1=pd.read_excel('附件1:1名男大学生的一日食谱.xlsx')df1#获取所有工作表名称excel_file=pd.ExcelFile('附件1......
  • 【教学类-58-04】黑白三角拼图04(2-10宫格,每个宫格随机1张-6张,带空格纸)
    背景需求:前期制作了黑白三角拼图2*2、3*3、4*4,确定了基本模板,就可以批量制作更多格子数【教学类-58-01】黑白三角拼图01(2*2宫格)固定256种+随机抽取10张-CSDN博客文章浏览阅读522次,点赞13次,收藏16次。【教学类-58-01】黑白三角拼图01(2*2宫格)固定256种+随机抽取10张https://bl......
  • 计算机毕业设计项目推荐,82131基于SSM的流浪动物救助网站的设计与实现(开题答辩+程序定
    SSM流浪动物救助网站摘要随着生活水平的持续提高和家庭规模的缩小,宠物已经成为越来越多都市人生活的一部分,随着宠物的增多,流浪的动物的日益增多,中国的流浪动物领养和救助也随之形成规模,同时展现巨大潜力。本次系统的是基于SSM框架的流浪动物救助网站管理系统,平台用户可以......
  • (免费领取源码)计算机毕业设计项目:07558基于Python的校园宿舍(开题答辩+程序定制+全套文
    摘要本论文主要论述了如何使用django开发一个校园宿舍管理系统,本系统将严格按照软件开发流程进行各个阶段的工作,采用B/S架构,面向对象编程思想进行项目开发。在引言中,作者将论述校园宿舍管理系统的当前背景以及系统开发的目的,后续章节将严格按照软件开发流程,对系统进行各......