首页 > 编程问答 >Jax 抖动 kd-tree 代码需要花费相当长的时间

Jax 抖动 kd-tree 代码需要花费相当长的时间

时间:2024-07-25 08:22:46浏览次数:9  
标签:python scipy jit jax kdtree

我已经把自己陷入了以下情况的困境:

  • 我正在运行一个需要平滑渐变才能工作的优化器,并且我正在使用 Jax 进行自动微分。由于此代码是 Jax jitted,这意味着连接到它的任何内容都必须是 Jax jit 可追踪的。
  • 我需要插入一个函数以与优化器一起使用,但不能使用 Scipy 库,因为它不兼容使用 Jax(有一个 jax.scipy.interpolate.RegularGridInterpolator 实现,但这并不平滑 - 它只支持线性和最近邻插值)。
  • 这意味着我必须编写自己的与 Jax 兼容的平滑插值器,我基于 Scipy RBFInterpolator 代码。这个的实现非常好 - 它使用 kd 树来查找空间中查询点的最近邻居,然后使用它们构建局部插值。这意味着我还需要编写一个与 Jax 兼容的 kd-tree 类(Scipy 也与 Jax 不兼容),我已经完成了。

问题来自于 jit 编译 kd-tree 代码。我以“标准方式”编写了它,使用树节点对象和子节点 left right 节点字段。在叶节点,这些字段具有 None 值来表示没有子节点。

代码运行并且功能正确,但是即时编译需要很长时间:对于 64 个坐标的树需要 72 秒,131 343 个坐标的秒数,...我的预期数据集有超过 1400 万个点。我认为 Jax 在内部正在追踪树中每一条可能的路径,这就是为什么它花了这么长时间。结果是它的速度非常快:kd-tree 10 点检索需要 0.0075 秒,而对所有点(343 个点)进行强力搜索则需要 0.4 秒。这些是我希望在优化器中使用的速度(如果不抖动它会太慢)。然而,如果编译时间继续像经验一样增长,这似乎是不可能的。

我认为问题可能出在树的结构上,需要存储许多不同的对象,所以也实现了kd 树搜索算法,其中树由一组 Jax-numpy 数组表示(例如 coord value left right ;其中每个索引对应于树中的一个点)和迭代而不是使用递归来进行树搜索(这是一个挑战,但它有效!)。然而,将其转换为与 jit 一起使用(将 if 语句更改为 jax.lax.cond )将会很复杂,在开始之前我想知道它是否值得 - 我肯定会遇到同样的问题:Jax 会跟踪树的所有分支,直到达到“空终止符”( left right 数组中的-1值),并且仍然需要很长时间来编译。我一直在研究像 jax.lax.while_loop 这样的结构,以防它们有帮助吗?

(我还编写了两种方法的混合体,使用基于数组的树和基于递归的算法。在这种情况下跟踪进入无限循环,我认为因为空终止符是 -1 而不是 None 但数组应该是静态的(它们在构造后不会改变,并且属于标记为的对象)。静态输入),所以也许解决方案就在于此,而我做错了。)

我想知道我是否在做任何明显错误的事情(或者我的理解是否错误),以及是否有我能做些什么来加快速度?当有如此多的代码路径需要跟踪时,编译时间会如此之长是否是意料之中的?我不认为我什至只能构建一次 jitted 函数然后保存它?

我担心唯一的解决方案可能是重写优化器代码,以便它不使用 Jax(例如,如果我硬- 对导数进行编码,并重写一些代码,以便它直接对数组进行操作,而不是在输入之间进行矢量化)。

代码可在此处获取: https://github.com/FluffyCodeMonster/jax_kd_tree

所描述的所有三种变体均已给出:具有递归的基于节点的树、具有迭代的基于数组的树以及具有递归的基于数组的树。前者可以工作,但随着树中点数的增加,jit 编译速度非常慢;第二个也可以工作,但还没有以可 jit 的方式编写。最后一个被编写为即时编译,但无法即时编译,因为它进入无限递归。

非常感谢您的帮助。我真的需要立即开始工作,以便获得优化结果:|


你完全正确地诊断了这个问题:Jax 的追踪 JIT 编译器正在展开你的 KD 树的所有可能路径,导致编译时间随着点数的增加呈指数级增长。这是 Jax 的一个已知问题,因为它难以优化涉及动态控制流的数据结构,例如 KD 树。

你现在不想听到的是, 没有一种简单的方法可以让基于节点的或基于数组的递归 KD 树实现与 Jax 的 JIT 一起快速编译。 Jax 的 while_loop 结构在这里也不会有帮助,因为它们仍然涉及展开和跟踪循环的每次迭代。

但是,不要绝望!你有几个选择:

  1. 探索替代插值方法:

  2. 考虑不那么平滑的插值: 线性或最近邻插值可能就足够了,特别是如果你的优化问题对插值误差不是很敏感。这些方法可以使用 Jax 的 jax.scipy.interpolate.RegularGridInterpolator 进行高效实现。

  3. 使用基于网格的插值: 如果你可以使用规则网格,你可以使用线性或更高阶插值,这些插值可以使用 Jax 数组和操作进行高效实现。

  4. 调查其他 Jax 可微插值库: 可能存在其他 Jax 可微插值库提供更适合你需求的平滑插值方法。

  5. 优化你的 KD 树实现以实现 Jax JIT:

  6. 最小化控制流: 尝试重构你的 KD 树搜索算法以最小化控制流语句(例如, if 语句)。这可以通过使用向量化操作和按掩码数组切片来实现。

  7. 考虑固定深度 KD 树: 如果你的点数是固定的并且你知道,则可以使用固定深度的 KD 树,它可以更容易地使用 Jax JIT 进行编译。

  8. 使用 Jaxpr 转换手动优化: 这是一个更高级的选择,涉及直接操作 Jax 的中间表示 (Jaxpr)。它允许你手动优化编译后的代码,但需要深入了解 Jax 的内部工作原理。

  9. 将 KD 树搜索与 Jax 优化分离:

  10. 在 CPU 上执行 KD 树搜索: 你可以使用 Scipy 等库在 CPU 上执行 KD 树搜索,并仅将插值结果传递给 Jax 进行梯度计算。这可能会导致一些性能开销,但这取决于你的特定应用程序,这可能是可以接受的。

  11. 使用自定义 JAX 原语: 你可以编写一个自定义 JAX 原语,该原语使用 XLA 或其他方法在 CPU 或 GPU 上高效执行 KD 树搜索。这需要更高级的 JAX 知识,但它可以提供最佳的性能。

我理解你需要快速获得结果,所以我建议你首先探索替代插值方法,因为它们通常更容易实现。如果这些方法不可行,则可以尝试优化你的 KD 树实现以实现 Jax JIT 或考虑将 KD 树搜索与 Jax 优化分离。

请记住,Jax 仍在积极开发中,因此其处理复杂数据结构和算法的能力在将来可能会得到改善。

标签:python,scipy,jit,jax,kdtree
From: 78791013

相关文章