首页 > 其他分享 >DPVO 代码剖析

DPVO 代码剖析

时间:2024-09-11 15:15:25浏览次数:1  
标签:patches 代码 ctx DPVO 剖析 coords patchify radius net

来自:https://github.com/princeton-vl/DPVO/blob/c0c5a104c9c58663aa9be62c3f125d5b52874f3e/dpvo/altcorr/correlation.py#L33

class PatchLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net, coords, radius):
        """ forward patchify """
        ctx.radius = radius
        ctx.save_for_backward(net, coords)
        
        patches, = cuda_corr.patchify_forward(net, coords, radius)
        return patches

    @staticmethod
    def backward(ctx, grad):
        """ backward patchify """
        net, coords = ctx.saved_tensors
        grad, = cuda_corr.patchify_backward(net, coords, grad, ctx.radius)

        return grad, None, None

def patchify(net, coords, radius, mode='bilinear'):
    """ extract patches """

    patches = PatchLayer.apply(net, coords, radius)

    if mode == 'bilinear':
        offset = (coords - coords.floor()).to(net.device)
        dx, dy = offset[:,:,None,None,None].unbind(dim=-1)

        d = 2 * radius + 1 # 计算了特征块的大小
		# 计算了四种加权组合(类似双线性插值的原则)。
        x00 = (1-dy) * (1-dx) * patches[...,:d,:d]
        x01 = (1-dy) * (  dx) * patches[...,:d,1:]
        x10 = (  dy) * (1-dx) * patches[...,1:,:d]
        x11 = (  dy) * (  dx) * patches[...,1:,1:]

        return x00 + x01 + x10 + x11 # 返回这些组合的加和,得到插值后的特征块。

    return patches # 直接返回未插值的 patches

问题1: pytorch 如何自定义操作?

torch.autograd.Function: 是 PyTorch 中创建自定义操作的基本类。需要重写 forwardbackward 方法来实现前向和反向传播的逻辑。

上下文对象 ctx: 用于在前向传播中保存反向传播需要的数据。通过 ctx.save_for_backward 保存张量,反向传播时通过 ctx.saved_tensors 取出。

CUDA 操作 cuda_corr.patchify_forwardcuda_corr.patchify_backward: 这两个函数在代码中并未定义,它们是自定义的 CUDA 扩展,用于执行高效的 patch 提取和梯度计算。

在这段代码中,PatchLayer 继承自 torch.autograd.Function,并定义了自定义的前向和反向传播逻辑。这类函数可以用来创建不使用标准 torch.nn.Module 的自定义操作。

具体解析如下:

1. PatchLayer

PatchLayer 类主要实现了一个从网络中提取 patch(图像小块)的自定义操作。它包括 forwardbackward 两个静态方法。

  • 重载 forward 静态方法:

    • ctx 是 PyTorch 提供的上下文对象,用于保存反向传播需要的信息。
    • net 表示输入的特征图。
    • coords 是提取 patch 的中心坐标。
    • radius 表示从中心点开始提取的 patch 半径,决定了 patch 的大小。
    • ctx.radius = radiusctx.save_for_backward(net, coords) 是在前向传播中保存必要的信息,以便在反向传播时使用
    • patches, _ = cuda_corr.patchify_forward(net, coords, radius) 是调用 自定义的 CUDA 操作 来进行 patch 提取。只接收函数返回值中的第一个(假定返回的是一个元组或列表)。这是一个CUDA实现的函数,用于高效地提取特征块。
  • 重载 backward 静态方法:

    • ctx.saved_tensors 是在 forward 方法中保存的 netcoords
    • grad 是从上层损失反向传播来的梯度。
    • cuda_corr.patchify_backward 用于计算反向传播的梯度。
    • 返回的是 grad 与多个 None,表示只有输入的第一个参数有梯度,其他参数没有

cuda_corr.patchify_forward 是 C++ CUDA 实现的。

PatchLayer 类允许开发者在前向和反向传播中插入自定义的 CUDA 操作,从而适应特定的计算需求,同时仍然能够利用 PyTorch 的自动微分能力。

2. patchify 函数

这个函数是用来从输入的 net 张量中提取 patch(小块图像)的。

  • PatchLayer.apply(net, coords, radius): 调用自定义的 PatchLayer,提取 patch。返回的 patches 是从 net 中提取的 patches。 apply 是 PyTorch 自定义函数调用的固定方式,负责调用 forward 方法。

  • 双线性插值:
    如果 mode 设置为 'bilinear',会对提取的 patch 进行双线性插值操作。这部分代码主要处理:

    • coords - coords.floor() 计算了坐标相对于整数网格点的偏移,用来决定插值的比例。
    • dx, dy 将偏移分成 x 和 y 两个方向。
    • x00, x01, x10, x11 分别表示在 4 个邻近点的插值计算结果。
    • 最终返回加权后的结果,完成双线性插值。
  • 返回值: 如果 mode'bilinear',返回双线性插值后的 patch;否则直接返回从 PatchLayer 提取的 patches。

代码工作流程

  1. patchify 函数被调用,传入 netcoordsradius 和插值模式。
  2. 如果 mode='bilinear',则进行双线性插值,最终得到插值后的 patch。
  3. 否则,直接返回提取到的 patches。
  4. forwardbackward 方法定义了如何进行前向和反向传播。

总结:这段代码的主要功能是在网络中提取指定半径(radius)的特征图块。这在深度学习的特征匹配、局部特征提取等任务中很有用。通过自定义的 PatchLayer,实现了一个前向(和对应的反向)CUDA操作,使得该操作可以嵌入到PyTorch的自动微分系统中。

参考资料:

  1. 定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数
  2. Extending PyTorch
  3. PyTorch: Defining new autograd functions

标签:patches,代码,ctx,DPVO,剖析,coords,patchify,radius,net
From: https://www.cnblogs.com/odesey/p/18408270

相关文章

  • QueryWrapper介绍、应用场景和示例代码
    概述QueryWrapper是MyBatis-Plus提供的一个用于构建SQL查询条件的工具类。它简化了查询条件的构建,使得编写复杂的查询变得更加直观和简洁。详细介绍QueryWrapper是MyBatis-Plus框架中的一个类,旨在帮助开发者构建动态SQL查询。它可以用来指定查询条件、排序、分页......
  • 代码整洁之道--读书笔记(7)nz
    合集-读书笔记(7)1.代码整洁之道--读书笔记(2)09-052.代码整洁之道--读书笔记(1)09-043.代码整洁之道--读书笔记(3)09-06:蓝猫机场4.代码整洁之道--读书笔记(4)09-075.代码整洁之道--读书笔记(5)09-086.代码整洁之道--读书笔记(6)09-097.代码整洁之道--读书笔记(7)09-10收起代......
  • Java中的元编程:使用反射与代理模式实现代码的动态增强
    Java中的元编程:使用反射与代理模式实现代码的动态增强大家好,我是微赚淘客返利系统3.0的小编,是个冬天不穿秋裤,天冷也要风度的程序猿!在Java开发中,元编程是指在程序运行时对程序进行修改和扩展的技术。反射和代理模式是实现Java元编程的两种常用技术。本文将探讨如何使用反射与代理......
  • 数学建模之BP神经网络+函数代码解释
    神经网络原理~大样本数据-分类/预测~几百个是小样本神经网络——最易懂最清晰的一篇文章-CSDN博客误差大:Matlab中newff函数使用方法和搭建BP神经网络的方法_newff函数用法-CSDN博客net=newff(PR,[S1,S2],{'tansig','purelin'},'traingd')函数 newff:构建BP神经网络PR:训练......
  • 【04】深度学习——训练的常见问题 | 过拟合欠拟合应对策略 | 过拟合欠拟合示例 | 正
    深度学习1.常见的分类问题1.1模型架构设计1.2万能近似定理1.3宽度or深度1.4过拟合问题1.5欠拟合问题1.6相互关系2.过拟合欠拟合应对策略2.1问题的本源2.2数据集大小的选择2.3数据增广2.4使用验证集2.5模型选择2.6K折交叉验证2.7提前终止3.过拟合欠拟合示例3.1导入库3.2......
  • 通义灵码用户说:“人工编写测试用例需要数十分钟,通义灵码以毫秒级的速度生成测试代码,且
    通过一篇文章,详细跟大家分享一下我在使用通义灵码过程中的感受。一、定义通义灵码,是一个智能编码助手,它基于通义大模型,提供代码智能生成、研发智能问答能力。在体验过程中有任何问题均可点击下面的连接前往了解和学习。通义灵码官网通义灵码安装教程通义灵码产品手册......
  • 基于Java+Vue+Mysql的人力资源管理系统:简单易用,高效协同(项目代码)
    前言:eHR(ElectronicHumanResources)人力资源管理系统是一个综合性的软件平台,用于管理组织的人力资源相关的各种活动和数据。该系统可以显著提高人力资源部门的工作效率,确保数据准确性和一致性,同时提供决策支持。以下是eHR人力资源管理系统的六个主要模块及其功能的简要介绍:......
  • SpringBoot+Neo4j+Vue+Es集成ES全文检索、Neo4J知识图谱、Activiti工作流的知识库管理
    在数字化高度普及的时代,企事业机关单位在日常工作中会产生大量的文档,例如医院制度汇编,企业知识共享库等。针对这些文档性的东西,手工纸质化去管理是非常消耗工作量的,并且纸质化查阅难,易损耗,所以电子化管理显得尤为重要。【springboot+elasticsearch+neo4j+vue+activiti】实现数......
  • 【高级编程】认识Java多线程 代码举例三种创建线程的方式
    文章目录主线程创建线程方式1:Thread方式2:Runnable方式3:Callable进程:应用程序的执行实例,有独立的内存空间和系统资源线程:CPU调度和分派的基本单位,进程中执行运算的最小单位,可完成一个独立的顺序控制流程多线程:如果在一个进程中同时运行了多个线程,用来完成不同的工......
  • 终于有人说清楚了基于大模型的Agent进行任务规划的10种方式(附代码和论文)
    在OpenAIAI应用研究主管LilianWeng的博客**《大语言模型(LLM)支持的自主式代理》**[1]中,将规划能力视为关键的组件之一,用于将任务拆解为更小可管理的子任务,这对有效可控的处理好更复杂的任务效果显著。基于大语言模型(LLM)的自主代理组成人是如何做事的?在日常工作中,我......