JAX Chinese Documentation (Part 9)
Control autodiff's saved values with jax.checkpoint (aka jax.remat)
Original: jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html
import jax
import jax.numpy as jnp
TL;DR
Use the jax.checkpoint decorator (aliased as jax.remat) with jax.grad to control which intermediates are saved on the forward pass and which are recomputed on the backward pass, trading off memory against FLOPs.
Don't miss the practical notes on how jax.checkpoint interacts with jax.jit.
Without jax.checkpoint, the forward pass of jax.grad(f)(x) saves the values of Jacobian coefficients and other intermediates for use during the backward pass. We call these saved values residuals:
def g(W, x):
    y = jnp.dot(W, x)
    return jnp.sin(y)

def f(W1, W2, W3, x):
    x = g(W1, x)
    x = g(W2, x)
    x = g(W3, x)
    return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
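By applying jax.checkpoint to sub-functions, either as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a jax.checkpoint-decorated function might be saved, and any residuals consumed on the backward pass are recomputed from those inputs as needed:
def f2(W1, W2, W3, x):
    x = jax.checkpoint(g)(W1, x)
    x = jax.checkpoint(g)(W2, x)
    x = jax.checkpoint(g)(W3, x)
    return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x)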
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
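Here the values of two sin applications are saved because they are arguments to subsequent applications of the jax.checkpoint-decorated g function, and the inputs to a jax.checkpoint-decorated function may be saved. But no values of cos applications are saved.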
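As a sanity check on the behavior described above, the following sketch compares gradients computed with and without jax.checkpoint. It is an illustration rather than part of the original notebook: the helper name scalar_loss is hypothetical, and it sums the vector output because jax.grad expects a scalar-valued function.

```python
import jax
import jax.numpy as jnp

def g(W, x):
    return jnp.sin(jnp.dot(W, x))

def f(W1, W2, W3, x):
    # Plain version: autodiff saves residuals for every stage.
    return g(W3, g(W2, g(W1, x)))

def f2(W1, W2, W3, x):
    # Checkpointed version: residuals of each g are recomputed on the backward pass.
    x = jax.checkpoint(g)(W1, x)
    x = jax.checkpoint(g)(W2, x)
    x = jax.checkpoint(g)(W3, x)
    return x

def scalar_loss(fun):
    # Hypothetical helper: reduce the vector output to a scalar for jax.grad.
    return lambda *args: jnp.sum(fun(*args))

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

grads_plain = jax.grad(scalar_loss(f), argnums=(0, 1, 2))(W1, W2, W3, x)
grads_remat = jax.grad(scalar_loss(f2), argnums=(0, 1, 2))(W1, W2, W3, x)
same = all(bool(jnp.allclose(a, b, atol=1e-6))
           for a, b in zip(grads_plain, grads_remat))
print("gradients agree:", same)
```

The point of the comparison is that jax.checkpoint changes what is stored and recomputed, not the values of the derivatives.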
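To control which values are saveable without editing the definition of the function to be differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):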
f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
You can also use policies to refer to intermediate values that you name with jax.ad_checkpoint.checkpoint_name:
from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
    x = checkpoint_name(g(W1, x), name='a')
    x = checkpoint_name(g(W2, x), name='b')
    x = checkpoint_name(g(W3, x), name='c')
    return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x)
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4)
When playing around with these toy examples, we can get a closer look at what's going on using the print_fwd_bwd utility defined in this notebook:
from jax.tree_util import tree_flatten, tree_unflatten
from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
    args, in_tree = tree_flatten((args, kwargs))

    def f_(*args):
        args, kwargs = tree_unflatten(in_tree, args)
        return f(*args, **kwargs)
    fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

    y, f_vjp = jax.vjp(f_, *args)
    res, in_tree = tree_flatten(f_vjp)

    def g_(*args):
        *res, y = args
        f_vjp = tree_unflatten(in_tree, res)
        return f_vjp(y)
    bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

    table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
    table.add_row("[bold green]forward computation:",
                  "[bold green]backward computation:")
    table.add_row(rich.text.Text.from_ansi(str(fwd)),
                  rich.text.Text.from_ansi(str(bwd)))
    console = Console(width=240, force_jupyter=True)
    console.print(table)

def _renderable_repr(self):
    return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x)
forward computation: backward computation:
{ lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]
e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[7]. let
f:f32[5] = sin e k:f32[7] = mul j a
g:f32[5] = cos e l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c
h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b
i:f32[6] = sin h n:f32[6] = mul l d
j:f32[6] = cos h o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f
k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e
l:f32[7] = sin k q:f32[5] = mul o g
m:f32[7] = cos k r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i
in (l, m, i, c, j, f, b, g, d, a) } s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h
in (s, p, m, r) }
# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x)
forward computation: backward computation:
{ lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let
e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[
f:f32[5] = sin e differentiated=True
g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]
h:f32[6] = sin g s:f32[4] t:f32[7]. let
i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h u:f32[5] = sin m
j:f32[7] = sin i v:f32[5] = cos m
in (j, e, g, i, a, b, c, d) } w:f32[6] = sin n
x:f32[6] = cos n
y:f32[7] = cos o
z:f32[7] = mul t y
ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r
bb:f32[6] = mul ba x
bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q
bd:f32[5] = mul bc v
be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p
bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s
bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u
bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w
in (bf, bg, bh, be) }
policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>
prevent_cse=True
] a b c d e f g h
in (i, j, k, l) }
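Let's think step by step
You might want to first (re)read the Autodiff Cookbook Part 1.
Fundamentals of jax.checkpoint
In both jax.linearize and jax.vjp there is flexibility in how and when some values are computed. Different choices can trade off memory use against FLOPs. JAX provides control over these choices with jax.checkpoint.
One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp:
def sin_vjp(x):
    y = jnp.sin(x)
    cos_x = jnp.cos(x)
    return y, lambda y_bar: cos_x * y_bar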
Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:
def sin_vjp2(x):
    y = jnp.sin(x)
    return y, lambda y_bar: jnp.cos(x) * y_bar
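For this particular function, the amount of memory used by the two versions is the same, though we've reduced the FLOPs for the primal computation (the forward pass) and increased the FLOPs for the cotangent computation (the backward pass).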
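To make the equivalence concrete, here is a small sketch that evaluates both hand-written VJPs on the same input. The test input and cotangent are arbitrary choices made for illustration.

```python
import jax.numpy as jnp

def sin_vjp(x):
    # Jacobian coefficient cos(x) is computed eagerly on the forward pass.
    y = jnp.sin(x)
    cos_x = jnp.cos(x)
    return y, lambda y_bar: cos_x * y_bar

def sin_vjp2(x):
    # cos(x) is deferred to the backward pass; only x is captured.
    y = jnp.sin(x)
    return y, lambda y_bar: jnp.cos(x) * y_bar

x = jnp.array([0.0, 1.0, 2.0])
y_bar = jnp.ones_like(x)

y1, bwd1 = sin_vjp(x)
y2, bwd2 = sin_vjp2(x)
x_bar1 = bwd1(y_bar)
x_bar2 = bwd2(y_bar)
print(bool(jnp.allclose(x_bar1, x_bar2)))
```

Both closures hold one array of the same size as x (cos_x in the first, x in the second), which is why memory use is identical here.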
There's another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:
def f(x):
    y = g(x)
    z = h(y)
    return z

def f_vjp(x):
    y, g_vjp = jax.vjp(g, x)
    z, h_vjp = jax.vjp(h, y)
    def f_bwd(z_bar):
        y_bar, = h_vjp(z_bar)
        x_bar, = g_vjp(y_bar)
        return x_bar
    return z, f_bwd
An alternative is:
def f_vjp_checkpoint(x):
    y = g(x)
    z, h_vjp = jax.vjp(h, y)
    def f_bwd2(z_bar):
        y_bar, = h_vjp(z_bar)
        _, g_vjp = jax.vjp(g, x)
        x_bar, = g_vjp(y_bar)
        return x_bar
    return z, f_bwd2
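In words, this alternative implementation doesn't compute g_vjp, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each required similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires only half the memory of the one produced by f_vjp(x)!
The cost we pay is redundant work: in f_bwd2 we must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)).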
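The two hand-written VJP rules can be compared directly. The sketch below assumes concrete stand-ins for the two stages (g is sin and h is squaring, both chosen only for illustration) and checks that the eager and recomputing versions return the same cotangent.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the two stages of f.
def g(x):
    return jnp.sin(x)

def h(y):
    return y ** 2

def f_vjp(x):
    # Saves the residuals of both g and h on the forward pass.
    y, g_vjp = jax.vjp(g, x)
    z, h_vjp = jax.vjp(h, y)
    def f_bwd(z_bar):
        y_bar, = h_vjp(z_bar)
        x_bar, = g_vjp(y_bar)
        return x_bar
    return z, f_bwd

def f_vjp_checkpoint(x):
    # Saves only h's residuals; g's are rebuilt from x on the backward pass.
    y = g(x)
    z, h_vjp = jax.vjp(h, y)
    def f_bwd2(z_bar):
        y_bar, = h_vjp(z_bar)
        _, g_vjp = jax.vjp(g, x)
        x_bar, = g_vjp(y_bar)
        return x_bar
    return z, f_bwd2

x = jnp.array(0.5)
z1, bwd1 = f_vjp(x)
z2, bwd2 = f_vjp_checkpoint(x)
grad_saved = bwd1(jnp.ones_like(z1))
grad_recomputed = bwd2(jnp.ones_like(z2))
print(bool(jnp.allclose(grad_saved, grad_recomputed)))
```

For these stand-ins the analytic derivative is 2 sin(x) cos(x), which both versions should reproduce.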
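We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint in an alternative definition of the original function f:
def f_checkpoint(x):
    y = jax.checkpoint(g)(x)
    z = h(y)
    return z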
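In other words, we apply jax.checkpoint to g, the first stage of f, rather than to f itself. This way, when we evaluate jax.grad(f_checkpoint)(x), we'd get a computation like:
1. run the forward pass of g, discarding residual values;
2. run the forward pass of h, saving residuals;
3. run the backward pass of h, consuming the residuals from step 2;
4. re-run the forward pass of g, saving residuals;
5. run the backward pass of g, consuming the residuals from step 4.
That is, by evaluating jax.grad(f_checkpoint)(x) we'd get the same computation as:
def f_checkpoint_grad(x):
    y = g(x)                  # step 1
    _, h_vjp = jax.vjp(h, y)  # step 2
    y_bar, = h_vjp(1.0)       # step 3
    _, g_vjp = jax.vjp(g, x)  # step 4
    x_bar, = g_vjp(y_bar)     # step 5
    return x_bar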
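The five steps can be checked against jax.grad itself. This sketch assumes concrete stages (g is sin, h is exp, chosen only for illustration) and passes an explicit array of ones as the output cotangent rather than the literal 1.0, to sidestep any dtype questions.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the two stages.
def g(x):
    return jnp.sin(x)

def h(y):
    return jnp.exp(y)

def f_checkpoint(x):
    y = jax.checkpoint(g)(x)
    return h(y)

def f_checkpoint_grad(x):
    y = g(x)                            # step 1: forward g, no residuals kept
    z, h_vjp = jax.vjp(h, y)            # step 2: forward h, residuals saved
    y_bar, = h_vjp(jnp.ones_like(z))    # step 3: backward h
    _, g_vjp = jax.vjp(g, x)            # step 4: re-run forward g
    x_bar, = g_vjp(y_bar)               # step 5: backward g
    return x_bar

x = jnp.array(0.3)
auto = jax.grad(f_checkpoint)(x)
manual = f_checkpoint_grad(x)
print(bool(jnp.allclose(auto, manual)))
```

The manual function mirrors the schedule, not JAX's internal implementation; it only shows that the schedule yields the same derivative.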
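In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize and jax.vjp (and their wrappers, like jax.grad) but not jax.jvp. When differentiated, only the input to a jax.checkpoint-decorated function is stored on the forward pass; on the backward pass, residuals (i.e. intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.
Notice that if f = lambda x: h(g(x)) is the function we want to differentiate, i.e. if we want to apply jax.grad(f), we don't get any memory savings by applying jax.checkpoint to f itself. That's because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation like:
1. run the forward pass, discarding all residuals;
2. immediately re-run the forward pass, saving residuals;
3. run the backward pass, consuming the residuals from step 2.
That is, in code we'd have something like:
def f_grad_bad(x):
    _ = f(x)                  # step 1
    _, f_vjp = jax.vjp(f, x)  # step 2
    x_bar, = f_vjp(1.0)       # step 3
    return x_bar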
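We also wouldn't get any memory savings by applying jax.checkpoint to h, the second stage of f. That's because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation like:
1. run the forward pass of g, saving residuals;
2. run the forward pass of h, discarding residuals;
3. immediately re-run the forward pass of h, saving residuals;
4. run the backward pass of h, consuming the residuals from step 3;
5. run the backward pass of g, consuming the residuals from step 1.
That is, in code we'd have something like:
def f_grad_bad2(x):
    y, g_vjp = jax.vjp(g, x)  # step 1
    z = h(y)                  # step 2
    _, h_vjp = jax.vjp(h, y)  # step 3
    y_bar, = h_vjp(1.0)       # step 4
    x_bar, = g_vjp(y_bar)     # step 5
    return x_bar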
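Slightly more generally, if we had a chain composition of functions, like f = lambda x: f3(f2(f1(x))), and we were interested in evaluating jax.grad(f), we could say that:
- we shouldn't apply jax.checkpoint to the whole function f, since that wouldn't save any memory (and would perform wasteful recomputation);
- we shouldn't apply jax.checkpoint to the last sub-function f3, since that wouldn't save any memory (and would perform wasteful recomputation);
- we could apply jax.checkpoint to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs.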
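The placements listed above can be written out side by side. In this sketch the three stages are arbitrary elementwise functions picked for illustration, and the dictionary keys are hypothetical labels; it only confirms that every placement computes the same gradient, leaving memory behavior to tools such as print_saved_residuals.

```python
import jax
import jax.numpy as jnp

# Hypothetical three-stage chain.
def f1(x):
    return jnp.sin(x)

def f2(x):
    return jnp.cos(x)

def f3(x):
    return jnp.tanh(x)

variants = {
    "no_checkpoint": lambda x: f3(f2(f1(x))),
    "checkpoint_f1": lambda x: f3(f2(jax.checkpoint(f1)(x))),
    "checkpoint_f2": lambda x: f3(jax.checkpoint(f2)(f1(x))),
    "checkpoint_f1_f2": lambda x: f3(jax.checkpoint(lambda t: f2(f1(t)))(x)),
}

grads = {name: jax.grad(fn)(0.7) for name, fn in variants.items()}
for name, value in grads.items():
    print(name, float(value))
```

Which placement is best depends on the relative sizes of the residuals of f1 and f2, so measuring on the real model is the safer guide.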
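Custom policies for what's saveable
As shown so far, using jax.checkpoint switches from one extreme to another:
- without jax.checkpoint, JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass;
- with a jax.checkpoint decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.
To operate between these two extremes, saving some things and not others, we can carefully place jax.checkpoint decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.
So an alternative is to use the policy argument to jax.checkpoint. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or instead must be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes on jax.checkpoint_policies, like jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.
For example, consider this function to be differentiated: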
def loss(params, x, y):
    return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
    *Ws, Wlast = params
    for W in Ws:
        x = layer(W, x)
    x = jnp.dot(Wlast, x)
    return x

def layer(W, x):
    return jnp.sin(jnp.dot(W, x))

W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
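Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP-bound rather than memory-bound). We can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable: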
loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict)
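Notice also that by providing a policy, we didn't need to edit the code defining loss, predict, or layer. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. the neural network library).
Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name:
from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
    *Ws, Wlast = params
    for i, W in enumerate(Ws):
        x = layer(W, x)
        x = checkpoint_name(x, name=f'layer{i}_output')
    x = jnp.dot(Wlast, x)
    return x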
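By itself, checkpoint_name is just an identity function. But because some policy functions know to look for these names, we can use them to control whether certain values output by checkpoint_name are considered saveable: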
print_saved_residuals(loss, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss)
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y)
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
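Another policy which refers to names is jax.checkpoint_policies.save_only_these_names.
Some of the policies are:
- everything_saveable (the default strategy, as if jax.checkpoint were not being used at all)
- nothing_saveable (i.e. rematerialize everything, as if a custom policy were not being used at all)
- dots_saveable or its alias checkpoint_dots
- dots_with_no_batch_dims_saveable or its alias checkpoint_dots_with_no_batch_dims
- save_anything_except_these_names (save any values except for the output of checkpoint_name with any of the names given)
- save_any_names_but_these (save only named values, i.e. any outputs of checkpoint_name, except for those with the names given)
- save_only_these_names (save only named values, and only among the names given)
Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass.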
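Because policies only affect what is stored, swapping one for another should leave gradients unchanged. The sketch below tries three policies from jax.checkpoint_policies on a small named-layer model; the parameter values and the dictionary labels are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(W, x):
    return jnp.sin(jnp.dot(W, x))

def predict(params, x):
    *Ws, Wlast = params
    for i, W in enumerate(Ws):
        # Name each hidden activation so that name-based policies can find it.
        x = checkpoint_name(layer(W, x), name=f'layer{i}_output')
    return jnp.dot(Wlast, x)

def loss(params, x, y):
    return jnp.sum((predict(params, x) - y) ** 2)

params = [jnp.full((4, 4), 0.1 * (i + 1)) for i in range(3)]
x = jnp.ones(4)
y = jnp.ones(4)

policies = {
    "everything_saveable": jax.checkpoint_policies.everything_saveable,
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
    "save_only_layer0": jax.checkpoint_policies.save_only_these_names('layer0_output'),
}

reference = jax.grad(loss)(params, x, y)
results = {
    name: jax.grad(jax.checkpoint(loss, policy=policy))(params, x, y)
    for name, policy in policies.items()
}
for name, grads in results.items():
    ok = all(bool(jnp.allclose(a, b, atol=1e-5)) for a, b in zip(grads, reference))
    print(name, "matches reference:", ok)
```

Since the policy is passed at the call site, experiments like this need no change to predict or layer beyond the names.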
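Advanced: recursive jax.checkpoint
By applying jax.checkpoint in the right way, there are many tradeoffs between memory usage and (re)computation that can be expressed. One surprising example is recursive checkpointing, where we apply jax.checkpoint to a function which itself calls jax.checkpoint-decorated functions, in such a way that memory usage from the chain composition of \(D\) functions scales like \(\mathcal{O}(\log_2 D)\) rather than \(\mathcal{O}(D)\).
As a toy example, consider the chain composition of multiple jnp.sin functions:
def chain_compose(funs):
    def f(x):
        for fun in funs:
            x = fun(x)
        return x
    return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.)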
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
In general, the number of stored residuals scales linearly with the length of the chain:
f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
But we can apply jax.checkpoint recursively to improve the scaling:
def recursive_checkpoint(funs):
    if len(funs) == 1:
        return funs[0]
    elif len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    else:
        f1 = recursive_checkpoint(funs[:len(funs)//2])
        f2 = recursive_checkpoint(funs[len(funs)//2:])
        return lambda x: f1(jax.checkpoint(f2)(x))

f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.)
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
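The recursive construction should leave the derivative itself unchanged. The sketch below reuses the two definitions from this section and compares gradients for a chain of eight sin stages. Note that recursive_checkpoint as written applies the second half of the list first; that makes no difference here because every stage is the same function, but it would matter for a chain of distinct functions.

```python
import jax
import jax.numpy as jnp

def chain_compose(funs):
    def f(x):
        for fun in funs:
            x = fun(x)
        return x
    return f

def recursive_checkpoint(funs):
    if len(funs) == 1:
        return funs[0]
    elif len(funs) == 2:
        f1, f2 = funs
        return lambda x: f1(f2(x))
    else:
        # Checkpoint one half; each half is built recursively the same way.
        f1 = recursive_checkpoint(funs[:len(funs) // 2])
        f2 = recursive_checkpoint(funs[len(funs) // 2:])
        return lambda x: f1(jax.checkpoint(f2)(x))

funs = [jnp.sin] * 8
grad_plain = jax.grad(chain_compose(funs))(3.)
grad_recursive = jax.grad(recursive_checkpoint(funs))(3.)
print(float(grad_plain), float(grad_recursive))
```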
The cost here, as usual, is recomputation: in particular, we end up performing \(\mathcal{O}(\log_2 D)\) times as many FLOPs:
f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation:
{ lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let
b:f32[] = sin a j:f32[] = mul i a
c:f32[] = cos a k:f32[] = mul j b
d:f32[] = sin b l:f32[] = mul k c
e:f32[] = cos b m:f32[] = mul l d
f:f32[] = sin d n:f32[] = mul m e
g:f32[] = cos d o:f32[] = mul n f
h:f32[] = sin f p:f32[] = mul o g
i:f32[] = cos f q:f32[] = mul p h
j:f32[] = sin h in (q,) }
k:f32[] = cos h
l:f32[] = sin j
m:f32[] = cos j
n:f32[] = sin l
o:f32[] = cos l
p:f32[] = sin n
q:f32[] = cos n
in (p, q, o, m, k, i, g, e, c) }
f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.)
forward computation: backward computation:
{ lambda ; a:f32[]. let { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let
b:f32[] = remat2[ e:f32[] = mul d a
differentiated=False f:f32[] = mul e b
jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) } g:f32[] = remat2[
policy=None differentiated=True
prevent_cse=True jaxpr={ lambda ; h:f32[] i:f32[]. let
] a j:f32[] = sin h
f:f32[] = sin b k:f32[] = cos h
g:f32[] = sin f l:f32[] = cos j
h:f32[] = sin g m:f32[] = mul i l
i:f32[] = sin h n:f32[] = mul m k
j:f32[] = sin i in (n,) }
k:f32[] = cos i policy=None
l:f32[] = sin j prevent_cse=True
m:f32[] = cos j ] c f
in (l, m, k, g, a) } o:f32[] = remat2[
differentiated=True
jaxpr={ lambda ; p:f32[] q:f32[]. let
r:f32[] = sin p
s:f32[] = sin r
t:f32[] = sin s
u:f32[] = cos s
v:f32[] = cos t
w:f32[] = mul q v
x:f32[] = mul w u
y:f32[] = remat2[
differentiated=True
jaxpr={ lambda ; z:f32[] ba:f32[]. let
bb:f32[] = sin z
bc:f32[] = cos z
bd:f32[] = cos bb
be:f32[] = mul ba bd
bf:f32[] = mul be bc
in (bf,) }
policy=None
prevent_cse=True
] p x
in (y,) }
policy=None
prevent_cse=True
] 3.0 g
in (o,) }
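Practical notes
When differentiated functions are staged out to XLA for compilation, for example by applying jax.jit to a function which contains a jax.grad call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint often isn't needed for differentiated functions under a jax.jit. XLA will optimize things for you.
One exception is when using staged-out control flow, like jax.lax.scan. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass scan and the corresponding backward-pass scan, typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint on the body function passed to jax.lax.scan.
For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this:
LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
    for W, b in params:
        x = jnp.maximum(jnp.dot(x, W) + b, 0.)
    return x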
We would instead iterate over the layer application with jax.lax.scan:
StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
    W, b = W_b_pair
    out = jnp.maximum(jnp.dot(x, W) + b, 0.)
    return out, None

def net(all_weights, all_biases, x):
    x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
    return x
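This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we would use jax.checkpoint on the scanned function:
from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
    W, b = W_b_pair
    out = jnp.maximum(jnp.dot(x, W) + b, 0.)
    return out, None
By using jax.checkpoint in this way, we're manually controlling which values JAX's autodiff saves between the forward and backward passes, and hence not relying on XLA optimizations to choose for us.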
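A short end-to-end sketch of the scan-over-layers pattern follows. The layer sizes, the random initialization, and the helper name make_net are assumptions made for illustration; the output is summed so that jax.grad can be applied. It checks only that the checkpointed body yields the same gradients as the plain one, not the memory savings.

```python
import jax
import jax.numpy as jnp

def layer(x, W_b_pair):
    W, b = W_b_pair
    out = jnp.maximum(jnp.dot(x, W) + b, 0.)
    return out, None

# Same body, but residuals other than dot outputs are recomputed per step.
layer_remat = jax.checkpoint(
    layer, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)

def make_net(body):
    # Hypothetical helper: build a scalar-valued network around a scan body.
    def net(all_weights, all_biases, x):
        x, _ = jax.lax.scan(body, x, (all_weights, all_biases))
        return jnp.sum(x)
    return net

key = jax.random.PRNGKey(0)
all_weights = 0.5 * jax.random.normal(key, (3, 4, 4))  # 3 stacked 4x4 layers
all_biases = jnp.full((3, 4), 0.1)
x = jnp.ones(4)

grads_plain = jax.grad(make_net(layer), argnums=(0, 1))(all_weights, all_biases, x)
grads_remat = jax.grad(make_net(layer_remat), argnums=(0, 1))(all_weights, all_biases, x)
print(bool(jnp.allclose(grads_plain[0], grads_remat[0], atol=1e-5)))
```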
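How JAX primitives work
Original: jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
necula@google.com, October 2019.
JAX implements certain transformations of Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and on JAX abstract values.
The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).
There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.
The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on three identically-shaped tensors of floating point values and performs the operations pointwise.
Using existing primitives
The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax module: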
from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
    """Implementation of multiply-add using the jax.lax primitives."""
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    """A square-add function using the newly defined multiply-add."""
    return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
square_add_lax = 14.0
grad(square_add_lax) = 4.0
In order to understand how JAX is internally using the primitives, we add some helpers for tracing function calls.
#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0

def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print(" " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
    """A decorator for functions to trace arguments and results."""

    def trace_func(func):  # pylint: disable=missing-docstring
        def pp(v):
            """Print certain values more succinctly"""
            vtype = str(type(v))
            if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
                return "<JaxComputationBuilder>"
            elif "jaxlib.xla_extension.XlaOp" in vtype:
                return "<XlaOp at 0x{:x}>".format(id(v))
            elif ("partial_eval.JaxprTracer" in vtype or
                  "batching.BatchTracer" in vtype or
                  "ad.JVPTracer" in vtype):
                return "Traced<{}>".format(v.aval)
            elif isinstance(v, tuple):
                return "({})".format(pp_values(v))
            else:
                return str(v)

        def pp_values(args):
            return ", ".join([pp(arg) for arg in args])

        @functools.wraps(func)
        def func_wrapper(*args):
            _trace_indent("call {}({})".format(name, pp_values(args)))
            res = func(*args)
            _trace_unindent("|<- {} = {}".format(name, pp(res)))
            return res

        return func_wrapper

    return trace_func

class expectNotImplementedError(object):
    """Context manager to check for NotImplementedError."""
    def __enter__(self): pass
    def __exit__(self, type, value, tb):
        global _indentation
        _indentation = 0
        if type is NotImplementedError:
            print("\nFound expected exception:")
            traceback.print_exc(limit=3)
            return True
        elif type is None:  # No exception
            assert False, "Expected NotImplementedError"
        else:
            return False
Instead of using jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:
import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
Normal evaluation:
call square_add_numpy(2.0, 10.0)
call multiply_add_numpy(2.0, 2.0, 10.0)
|<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy = 14.0
Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
|<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) = 4.0
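Notice that in the process of computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below in this colab). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.
The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.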
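One way to see which primitives a traceable function reduces to is jax.make_jaxpr, which traces the function with abstract arguments and returns the recorded primitive sequence. The sketch below is an illustration added here, not part of the original notebook; the printed form of the jaxpr can vary between JAX versions.

```python
import jax
from jax import lax

def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    return multiply_add_lax(a, a, b)

# Trace with abstract values; no numerical result is computed here.
closed_jaxpr = jax.make_jaxpr(square_add_lax)(2., 10.)
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(closed_jaxpr)
print(primitive_names)
```

The list of primitive names is what JAX transformations such as jit and grad operate on.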
Defining new JAX primitives
The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, in order to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.
from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
    """The JAX-traceable way to use the JAX primitive.

    Note that the traced arguments must be passed as positional arguments
    to `bind`.
    """
    return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
    """A square-add function implemented using the new JAX-primitive."""
    return multiply_add_prim(a, a, b)
If we try to call the newly defined functions we get an error, because we have not yet told JAX anything about the semantics of the new primitive.
with expectNotImplementedError():
    square_add_prim(2., 10.)
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/2844449444.py", line 2, in <module>
square_add_prim(2., 10.)
File "/tmp/ipykernel_1319/1393342955.py", line 48, in func_wrapper
res = func(*args)
File "/tmp/ipykernel_1319/1308506715.py", line 16, in square_add_prim
return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented
Primal evaluation rules
@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
    """Concrete implementation of the primitive.

    This function does not need to be JAX traceable.
    Args:
      x, y, z: the concrete arguments of the primitive. Will only be called with
        concrete values.
    Returns:
      the concrete result of the primitive.
    """
    # Note that we can use the original numpy, which is not JAX traceable
    return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl)
<function __main__.multiply_add_impl(x, y, z)>
assert square_add_prim(2., 10.) == 14.
call square_add_prim(2.0, 10.0)
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0
JIT
If we now try to use jit we get a NotImplementedError:
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented
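Abstract evaluation rules
In order to JIT the function, and for other transformations as well, JAX first evaluates it abstractly using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:
- It gets the sequence of JAX primitives that are used in the computation. This sequence will be compiled.
- It computes the shape and type of all vectors and operations used in the computation.
For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]), or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.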
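From the user side, a similar shape-and-dtype-only evaluation is available through jax.eval_shape. The sketch below is an illustration added here and uses a plain jax.numpy multiply-add rather than the new primitive; it shows that the result type can be determined without any concrete data.

```python
import jax
import jax.numpy as jnp

def multiply_add(x, y, z):
    return x * y + z

# Describe the arguments by shape and dtype only; no array data is allocated.
arg = jax.ShapeDtypeStruct((3,), jnp.float32)
out = jax.eval_shape(multiply_add, arg, arg, arg)
print(out.shape, out.dtype)
```

An abstract evaluation rule for a custom primitive plays the same role for that one primitive: given abstract inputs, it reports the abstract output.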
from jax import core

@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
    """Abstract evaluation of the primitive.

    This function does not need to be JAX traceable. It will be invoked with
    abstractions of the actual arguments.
    Args:
      xs, ys, zs: abstractions of the arguments.
    Result:
      a ShapedArray for the result of the primitive.
    """
    assert xs.shape == ys.shape
    assert xs.shape == zs.shape
    return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
<function __main__.multiply_add_abstract_eval(xs, ys, zs)>
If we re-attempt to JIT, we see how the abstract evaluation proceeds, but we get another error, about the missing XLA compilation rule:
with expectNotImplementedError():
    api.jit(square_add_prim)(2., 10.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
api.jit(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu
XLA Compilation rules
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined using C++.
from jax._src.lib.mlir.dialects import hlo

@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
    """The compilation to XLA of the primitive.

    Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
    the results of the function.
    Does not need to be a JAX-traceable function.
    """
    return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)>
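Now JIT succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At this point JAX invokes multiply_add_lowering.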
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc664db0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc688cf0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc688d70>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc682b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8fd060>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a45a0d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 
0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d4ea0, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object <module> at 0x7f0afd8d6b80, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/1570919344.py': '/tmp/ipykernel_1319/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afd8fe8f0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6835b0>]
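Below is another use of jit, where we compile only with respect to the first argument. Notice that the second argument to square_add_prim is concrete (the Python value 10.0). This can lead to the third argument of multiply_add_abstract_eval being a ConcreteArray, although in the trace below it shows up as ShapedArray(float32[], weak_type=True). multiply_add_abstract_eval may be used with both ShapedArray and ConcreteArray.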
assert api.jit(lambda x, y: square_add_prim(x, y),
static_argnums=1)(2., 10.) == 14.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc666480>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc690530>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6905b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc690570>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8ffc40>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 
0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d5b00, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object <module> at 0x7f0b2cd8b3c0, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/4165789807.py': '/tmp/ipykernel_1319/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69c250>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dcff0>]
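The effect of `static_argnums` can also be seen without a custom primitive. This is a minimal sketch, assuming a hypothetical function `square_add` and a hypothetical list `seen_types` used only to record what the function receives while JAX traces it.

```python
import jax

seen_types = []  # hypothetical helper: records the type of `b` at trace time

def square_add(a, b):
    # `a` is traced; `b` is marked static below, so it stays a Python value.
    seen_types.append(type(b))
    return a * a + b

result = jax.jit(square_add, static_argnums=1)(2.0, 10.0)
print(result, seen_types)
```

Because the second argument is static, the function body sees an ordinary Python `float` for `b` rather than a tracer. A different value of `b` would trigger a fresh trace and compilation.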
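Forward differentiation
JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX Autodiff Cookbook).
If we attempt to compute the jvp function now, we get an error, because we have not yet told JAX how to differentiate the multiply_add primitive.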
# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.))
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/800067577.py", line 5, in <module>
api.jvp(square_add_prim, (2., 10.), (1., 1.))
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1901, in jvp
return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1930, in _jvp
out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented
from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

  Given values of the arguments and perturbation of the arguments (tangents),
  compute the output of the primitive and the perturbation of the output.

  This method must be JAX-traceable. JAX may invoke it with abstract values
  for the arguments and tangents.

  Args:
    arg_values: a tuple of arguments
    arg_tangents: a tuple with the tangents of the arguments. The tuple has
      the same length as the arg_values. Some of the tangents may also be the
      special value ad.Zero to specify a zero tangent.
  Returns:
    a pair of the primal output and the tangent.
  """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output.
  # Normally, we can use the ma primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a
  # proper tensor of 0s (of the same shape as 'x').
  # An alternative would be to check for Zero and perform algebraic
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, 1.0, 1.0)
call multiply_add_impl(2.0, 1.0, 1.0)
|<- multiply_add_impl = 3.0
|<- multiply_add_prim = 3.0
call multiply_add_prim(1.0, 2.0, 3.0)
call multiply_add_impl(1.0, 2.0, 3.0)
|<- multiply_add_impl = 5.0
|<- multiply_add_prim = 5.0
|<- multiply_add_value_and_jvp = (14.0, 5.0)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
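Open questions, still to be explained and documented:
- Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
- Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call multiply_add_abstract_eval.
- I think it would be useful to show the jaxpr here.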
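As a rough illustration of what a jaxpr for this computation looks like, the sketch below applies `jax.jvp` and `jax.make_jaxpr` to a plain-Python version of the function. The name `square_add` is a hypothetical stand-in that uses built-in primitives, so the printed jaxpr shows `mul` and `add` equations rather than the custom `multiply_add` primitive from this notebook.

```python
import jax

def square_add(a, b):
    # Hypothetical stand-in for square_add_prim: a * a + b.
    return a * a + b

# Primal output and tangent at (2., 10.) in direction (1., 1.).
# Tangent: at * a + a * at + bt = 1*2 + 2*1 + 1 = 5.
primal_out, tangent_out = jax.jvp(square_add, (2.0, 10.0), (1.0, 1.0))

# Trace the same JVP computation and show it as a jaxpr.
jvp_jaxpr = jax.make_jaxpr(
    lambda primals, tangents: jax.jvp(square_add, primals, tangents)
)((2.0, 10.0), (1.0, 1.0))
print(primal_out, tangent_out)
print(jvp_jaxpr)
```

The jaxpr lists the primal and tangent equations side by side, which makes it easier to see which primitive applications the JVP rule introduces.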
JIT of forward differentiation
We can apply JIT to the forward differentiation function:
assert api.jit(lambda arg_values, arg_tangents:
api.jvp(square_add_prim, arg_values, arg_tangents))(
(2., 10.), (1., 1.)) == (14., 5.)
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file 
"/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cc10>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc68ae70>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a598430>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at 
callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, 
is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cca0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), 
Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dc2b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc6e5580>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc6f4430>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6f4470>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc6f4230>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a598430>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at 
callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a3459c0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0) at "<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 36): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":27:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afc66b520, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 10): 
loc("<lambda>"("/tmp/ipykernel_1319/2145028508.py":2:0)), (<code object <module> at 0x7f0afc66b5d0, file "/tmp/ipykernel_1319/2145028508.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/2145028508.py":1:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/2145028508.py': '/tmp/ipykernel_1319/2145028508.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/tmp/ipykernel_1319/2145028508.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69cc10>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8b3270>]
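Notice that we first evaluate multiply_add_value_and_jvp abstractly, which in turn abstractly evaluates both the primal and the tangent computation (3 invocations of the ma primitive in total). The 3 occurrences of the primitive are then compiled.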
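Reverse differentiation
If we now attempt reverse differentiation, we see that JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values, but then runs into a NotImplementedError.
When computing the reverse differentiation, JAX first performs an abstract evaluation of the forward differentiation code multiply_add_value_and_jvp to obtain a trace of primitives that compute the output tangent. Observe that JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents. Observe also that JAX uses the special abstract tangent value Zero for the tangent of the third argument of ma. This reflects the fact that we do not differentiate with respect to the second argument of square_add_prim, which flows to the third argument of multiply_add_prim.
Observe also that during the abstract evaluation of the tangent we pass the value 0.0 as the tangent for the third argument. This is due to the use of the make_zero function in the definition of multiply_add_value_and_jvp.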
# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.)
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
Found expected exception:
Traceback (most recent call last):
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 284, in get_primitive_transpose
return primitive_transposes[p]
KeyError: multiply_add
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/339076514.py", line 3, in <module>
api.grad(square_add_prim)(2., 10.)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
_, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented
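The above error occurs because one piece is still missing before JAX can use the forward differentiation code to compute reverse differentiation.
Transposition
As explained above, when computing reverse differentiation JAX obtains a trace of primitives that compute the tangent using forward differentiation. JAX then interprets this trace abstractly, backwards, and applies a transposition rule to each primitive.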
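The two-step recipe described above (forward-mode tangent computation, then transposition) can be reproduced with public JAX APIs. The following is a minimal sketch, not the internal implementation: square_add is a hypothetical plain-JAX stand-in for square_add_prim, assumed to compute a * a + b (consistent with the trace, where multiply_add_prim(2.0, 2.0, 10.0) returns 14.0), and tangent_of_a is an illustrative helper name.

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
    # Hypothetical stand-in for square_add_prim (assumed: a * a + b).
    return a * a + b

a0, b0 = jnp.float32(2.0), jnp.float32(10.0)

def tangent_of_a(at):
    # Step 1: forward mode at the fixed point (a0, b0).
    # b0 is closed over, so we differentiate only w.r.t. the first argument.
    _, out_tangent = jax.jvp(lambda a: square_add(a, b0), (a0,), (at,))
    return out_tangent  # linear in `at`

# Step 2: transpose the linear tangent map to get a cotangent map.
transpose_fn = jax.linear_transpose(tangent_of_a, a0)
(a_ct,) = transpose_fn(jnp.float32(1.0))

print(a_ct)  # gradient of square_add w.r.t. a at (2., 10.)
```

The value agrees with jax.grad(square_add)(2., 10.), which is the quantity the custom primitive is being taught to compute.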
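To understand what is going on, consider for now a simpler example: the function f(x, y) = x * y + y. Assume we need to differentiate at the point (2., 4.). From the input tangents xt and yt, JAX will produce the following JVP calculation of the output tangent ft: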
a = xt * 4.
b = 2. * yt
c = a + b
ft = c + yt
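By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that may arise in the tangent calculation is multiplication, but then one of the operands is a constant.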
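The linearity claim can be checked numerically with jax.jvp. This is a small sketch; the helper name ft_of and the test tangents are illustrative choices, not part of the original notebook.

```python
import jax

def f(x, y):
    return x * y + y

def ft_of(xt, yt):
    # Output tangent of f at the point (2., 4.) for input tangents (xt, yt).
    _, ft = jax.jvp(f, (2.0, 4.0), (xt, yt))
    return ft

# Unit tangents recover the coefficients of the linear map.
coeff_x = ft_of(1.0, 0.0)   # a = xt * 4.
coeff_y = ft_of(0.0, 1.0)   # b + yt = 2. * yt + yt

# Linearity: the tangent for (3., 5.) is the same combination of coefficients.
combined = ft_of(3.0, 5.0)
assert combined == 3.0 * coeff_x + 5.0 * coeff_y
```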
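JAX produces the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it accumulates the cotangents of the variables used by the operation, using the cotangent of the result of the operation: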
# Initialize cotangents of inputs and intermediate vars
xct = yct = act = bct = cct = 0.
# Initialize cotangent of the output
fct = 1.
# Process "ft = c + yt"
cct += fct
yct += fct
# Process "c = a + b"
act += cct
bct += cct
# Process "b = 2. * yt"
yct += 2. * bct
# Process "a = xt * 4."
xct += act * 4.
One can verify that this computation produces xct = 4. and yct = 3., which are the partial derivatives of the function f at (2., 4.).
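JAX knows how to transpose each primitive that may appear in a JVP calculation. Conceptually, if the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz, then the transposition of the primitive is: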
p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
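Notice that p_transpose takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined _ value; for the other arguments, it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with None returned for the constant arguments.
In particular,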
add_transpose(out_ct, _, _) = (out_ct, out_ct)
mult_transpose(out_ct, x, _) = (None, x * out_ct)
mult_transpose(out_ct, _, y) = (out_ct * y, None)
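The conceptual rule for p can be tried out with jax.linear_transpose, which transposes a function that is linear in its arguments. In this sketch the constants cy = 3. and cz = 5. are arbitrary illustrative values, and the constant argument x of p is simply left out, since only the linear arguments are passed to jax.linear_transpose.

```python
import jax
import jax.numpy as jnp

cy, cz = 3.0, 5.0  # illustrative constants

def p(y, z):
    # Linear in both y and z.
    return y * cy + z * cz

# Only shape and dtype of the example arguments matter here.
example = jnp.float32(0.0)
p_transpose = jax.linear_transpose(p, example, example)

# One cotangent per linear argument: (out_ct * cy, out_ct * cz).
y_ct, z_ct = p_transpose(jnp.float32(1.0))
print(y_ct, z_ct)
```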
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.
  This method is only used when computing the backward gradient following
  value_and_jvp, and is only needed for primitives that are used in the JVP
  calculation for some other primitive. We need transposition for multiply_add_prim,
  because we have used multiply_add_prim in the computation of the output_tangent in
  multiply_add_value_and_jvp.
  In our case, multiply_add is not a linear primitive. However, it is used linearly
  w.r.t. tangents in multiply_add_value_and_jvp:
       output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
  Always one of the first two multiplicative arguments is a constant.
  Args:
      ct: the cotangent of the output of the primitive.
      x, y, z: values of the arguments. The arguments that are used linearly
        get an ad.UndefinedPrimal value. The other arguments get a constant
        value.
  Returns:
      a tuple with the cotangent of the inputs, with the value None
      corresponding to the constant arguments.
  """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert ad.is_undefined_primal(x)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res
ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
Now we can complete the run of grad:
assert api.grad(square_add_prim)(2., 10.) == 4.
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(2.0, 2.0, 10.0)
call multiply_add_impl(2.0, 2.0, 10.0)
|<- multiply_add_impl = 14.0
|<- multiply_add_prim = 14.0
Tangent evaluation:
call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(1.0, 2.0, 0.0)
call multiply_add_impl(1.0, 2.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
call multiply_add_prim(2.0, 1.0, 0.0)
call multiply_add_impl(2.0, 1.0, 0.0)
|<- multiply_add_impl = 2.0
|<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0)
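Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of the output_tangent in multiply_add_value_and_jvp. Because the trace is processed backwards, the first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.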
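The numbers in the trace above can be compared against a plain-JAX version of the same function using jax.vjp, which exposes the primal output and the backward function separately. This is a sketch under the assumption that square_add_prim(a, b) computes a * a + b; square_add below is a hypothetical stand-in, not the custom primitive.

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
    # Hypothetical stand-in for square_add_prim (assumed: a * a + b).
    return a * a + b

# Forward pass: primal output plus a pullback for cotangents.
out, pullback = jax.vjp(square_add, 2.0, 10.0)

# Backward pass: feed an output cotangent of 1 and
# receive one cotangent per input.
a_ct, b_ct = pullback(jnp.ones_like(out))
print(out, a_ct, b_ct)
```

The cotangent for a matches the value asserted for api.grad(square_add_prim)(2., 10.), and the cotangent for b is the pass-through value that appears as the third entry of the tuples returned by multiply_add_transpose.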
JIT of reverse differentiation
Notice that under JIT the abstract evaluation of multiply_add_value_and_jvp uses only abstract values, whereas without JIT we used ConcreteArray.
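One observable consequence of this difference, shown here with ordinary JAX functions rather than the custom primitive: Python control flow that depends on the input value works under grad alone, where the differentiation point is concrete, but fails once jit makes that value abstract. The function square_if_positive is a hypothetical example; the names of internal classes such as ConcreteArray depend on the JAX version that produced the traces, so this sketch relies only on public behavior.

```python
import jax

def square_if_positive(x):
    # This Python-level branch needs a concrete value for x.
    if x > 0:
        return x * x
    return 0.0 * x

# Without jit, the differentiation point is concrete, so the branch resolves.
eager_grad = jax.grad(square_if_positive)(2.0)

# Under jit, x is abstract (shape and dtype only), so the branch cannot resolve.
try:
    jax.jit(jax.grad(square_if_positive))(2.0)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

print(eager_grad, jit_failed)
```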
assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
Primal evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
Tangent evaluation:
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
|<- multiply_add_abstract_eval = ShapedArray(float32[])
|<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): 
loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, 
'/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d600>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6d1b30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a611410>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at 
callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, 
canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d930>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8e6f70>]
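Batching
The batching transformation takes a point-wise computation and turns it into a computation on vectors. If we try it right now, we get a NotImplementedError: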
# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                                   np.array([10., 20.]))
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
Found expected exception:
Traceback (most recent call last):
File "/tmp/ipykernel_1319/2641678767.py", line 3, in <module>
api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1214, in vmap_f
out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented
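We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on input vectors of any dimension, so the batched version can reuse the same multiply_add_prim implementation.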
from jax.interpreters import batching
@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.
  This must be a JAX-traceable function.
  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.
  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
  """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]
batching.primitive_batchers[multiply_add_p] = multiply_add_batch
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
|<- multiply_add_impl = [14. 29.]
|<- multiply_add_prim = [14. 29.]
|<- multiply_add_batch = ([14. 29.], 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
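As a sanity check on what a batching rule is expected to produce, the sketch below compares jax.vmap against an explicit Python loop over the batch axis. It uses plain arithmetic in place of the custom primitive: square_add here is a hypothetical stand-in, not the notebook's square_add_prim.

```python
import jax
import jax.numpy as jnp
import numpy as np

def square_add(a, b):
    # Hypothetical stand-in for square_add_prim: a * a + b with built-in ops.
    return a * a + b

xs = jnp.array([2., 3.])
ys = jnp.array([10., 20.])

# vmap maps the scalar function over axis 0 of both arguments.
batched = jax.vmap(square_add, in_axes=0, out_axes=0)(xs, ys)

# Reference: apply the scalar function to each pair of elements in turn.
looped = jnp.stack([square_add(x, y) for x, y in zip(xs, ys)])

print(batched, looped)
```

If the two results agree, the batched evaluation matches applying the function element by element, which is the property a batching rule has to preserve.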
JIT of batching
assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
(np.array([2., 3.]),
np.array([10., 20.])),
[14., 29.])
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
Using multiply_add to compute the batch:
call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
|<- multiply_add_abstract_eval = ShapedArray(float32[2])
|<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
|<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51cd10>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc68bdb0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc68aaf0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc689eb0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a884960>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): 
loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7f0afc6687c0, file "/tmp/ipykernel_1319/184469370.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc668a80, file "/tmp/ipykernel_1319/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/184469370.py': '/tmp/ipykernel_1319/184469370.py', '/tmp/ipykernel_1319/1392464762.py': '/tmp/ipykernel_1319/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, 
'/tmp/ipykernel_1319/184469370.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69e860>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8920f0>]
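Writing custom Jaxpr interpreters in JAX
Original:
jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
JAX offers several composable function transformations (jit, grad, vmap, etc.) that enable writing concise, accelerated code.
Here we show how to add your own function transformations to the system by writing a custom Jaxpr interpreter. We also get composability with all the other transformations for free.
This example uses internal JAX APIs, which may break at any time. Anything not in the API documentation should be assumed internal.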
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random
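What is JAX doing?
JAX provides a NumPy-like API for numerical computing that can be used as is, but JAX's true power comes from composable function transformations. Take the jit function transformation, which takes in a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.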
x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f)
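What happens when we call fast_f? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly: they first trace the function and then handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the "How it works" section in the README.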
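One way to observe the trace-then-compile behavior is to put a Python side effect inside the function. The sketch below records each time the Python body runs; trace_log and the small array shapes are illustrative choices, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

trace_log = []

def f(w, b, x):
    # This Python side effect runs only while JAX traces the function.
    trace_log.append("traced")
    return jnp.tanh(jnp.dot(x, w) + b)

fast_f = jax.jit(f)

w = jnp.ones((3, 2))
b = jnp.zeros(2)
x = jnp.ones((4, 3))

first = fast_f(w, b, x)   # traces, compiles, and runs
second = fast_f(w, b, x)  # same shapes and dtypes: reuses the compiled code

print(len(trace_log))
```

The Python body is executed during tracing only, so a second call with the same shapes and dtypes does not append to trace_log again.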
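Jaxpr tracer
A tracer of special importance in JAX is the Jaxpr tracer, which records ops into a Jaxpr (JAX expression). A Jaxpr is a data structure that can be evaluated like a mini functional programming language, which makes Jaxprs a useful intermediate representation for function transformations.
To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a "pretty-printing" transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. make_jaxpr is useful for debugging and introspection. Let's use it to look at how some example Jaxprs are structured.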
def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)
def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))
print()
def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
foo
=====
invars: [Var(id=140117887103104):int32[]]
outvars: [Var(id=140117887103296):int32[]]
constvars: []
equation: [Var(id=140117887103104):int32[], 1] add [Var(id=140117887103296):int32[]] {}
jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }
bar
=====
invars: [Var(id=140117843771968):float32[5,10], Var(id=140117843772032):float32[5], Var(id=140117843772096):float32[10]]
outvars: [Var(id=140117843772352):float32[5], Var(id=140117843772096):float32[10]]
constvars: []
equation: [Var(id=140117843771968):float32[5,10], Var(id=140117843772096):float32[10]] dot_general [Var(id=140117843772160):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=140117843772160):float32[5], Var(id=140117843772032):float32[5]] add [Var(id=140117843772224):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140117843772288):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=140117843772224):float32[5], Var(id=140117843772288):float32[5]] add [Var(id=140117843772352):float32[5]] {}
jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
d:f32[5] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] a c
e:f32[5] = add d b
f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
g:f32[5] = add e f
in (g, c) }
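- jaxpr.invars - the invars of a Jaxpr are a list of its input variables, analogous to the arguments of a Python function.
- jaxpr.outvars - the outvars of a Jaxpr are the variables it returns. They are always held in a list, so a Jaxpr can return multiple outputs.
- jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
- jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation holds a list of input variables, a list of output variables, and a primitive, which is used to evaluate the inputs to produce the outputs. Each equation also has params, a dictionary of parameters.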
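The fields listed above can also be read programmatically. The following minimal sketch collects the primitive names of a small function; h is a hypothetical example function, and the exact equations may differ between JAX versions.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Hypothetical example function: one sin followed by one multiplication.
    return jnp.sin(x) * 2.0

closed = jax.make_jaxpr(h)(1.0)
jaxpr = closed.jaxpr

# Names of the primitives applied, in program order.
primitive_names = [eqn.primitive.name for eqn in jaxpr.eqns]

print(primitive_names)
print(len(jaxpr.invars), len(jaxpr.outvars), len(jaxpr.constvars))
```

Walking jaxpr.eqns in this way is the building block of every interpreter in the rest of this guide.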
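Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce outputs. We'll go over exactly how to do this later. The important thing to note for now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
Why are Jaxprs useful?
Jaxprs are simple program representations that are easy to transform. Because JAX lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.
Your first interpreter: invert
Let's try to implement a simple function "inverter", which takes in the output of the original function and returns the inputs that produced those outputs. For now, let's focus on simple unary functions that are composed of other invertible unary functions.
Goal:
def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)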
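We will implement this by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we look up the primitive's inverse in a table and apply it.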
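Before building the interpreter, it can help to do the backward walk by hand for this one function. The sketch below is written manually (f_inv_by_hand is a hypothetical name) and only illustrates the order in which the inverses have to be applied.

```python
import jax.numpy as jnp

def f(x):
    return jnp.exp(jnp.tanh(x))

def f_inv_by_hand(y):
    # f applies tanh and then exp, so undo exp first (log), then tanh (arctanh).
    return jnp.arctanh(jnp.log(y))

x = 1.0
roundtrip = f_inv_by_hand(f(x))
print(roundtrip)
```

The interpreter developed in the following steps automates exactly this: it visits the equations in reverse order and applies the registered inverse of each primitive.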
1. Tracing a function
Let's use make_jaxpr to trace a function into a Jaxpr.
# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps
from jax import core
from jax import lax
from jax._src.util import safe_map
jax.make_jaxpr returns a closed Jaxpr, which is a Jaxpr that has been bundled with the constants (literals) from the trace.
def f(x):
  return jnp.exp(jnp.tanh(x))
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals)
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[]
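2. Evaluating a Jaxpr
Before we write a custom Jaxpr interpreter, let's first implement the "default" interpreter, eval_jaxpr, which evaluates the Jaxpr as is, computing the same values that the original, untransformed Python function would.
To do this, we first create an environment to store the values of each variable, and update the environment with each equation we evaluate in the Jaxpr.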
def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}
  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]
  def write(var, val):
    env[var] = val
  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)
  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results:
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals)
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars)
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)]
Notice that eval_jaxpr always returns a flat list, even if the original function does not.
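The flat list comes from the Jaxpr itself: its outvars hold one entry per array leaf of the function's output. A minimal sketch, using a hypothetical function nested that returns a dictionary, shows the correspondence and how jax.tree_util can restore the structure.

```python
import jax
import jax.numpy as jnp

def nested(x):
    # Hypothetical function returning a nested container with three leaves.
    return {"a": x + 1.0, "b": (x * 2.0, x * 3.0)}

x = jnp.array(1.0)
closed = jax.make_jaxpr(nested)(x)
num_outvars = len(closed.jaxpr.outvars)

# Flattening the real output yields the same number of leaves, plus a treedef
# that can rebuild the original structure from a flat list.
leaves, treedef = jax.tree_util.tree_flatten(nested(x))
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)

print(num_outvars, len(leaves))
```

A fuller interpreter would keep the treedef around and unflatten its result; the version in this guide skips that step because it only targets unary functions.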
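Furthermore, this interpreter does not handle higher-order primitives (like jit and pmap), which are not covered in this guide. You can refer to core.eval_jaxpr to see the edge cases that this interpreter does not cover.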
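Custom inverse Jaxpr interpreter
An inverse interpreter doesn't look too different from eval_jaxpr. We first set up a registry that maps primitives to their inverses. We then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter also looks similar to the "transpose" interpreter used in reverse-mode automatic differentiation.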
inverse_registry = {}
We now register inverses for some of the primitives. By convention, primitives in JAX end in _p, and many of the popular ones live in lax.
inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh
inverse will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton.
def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about flattening and
    # unflattening arguments.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped
Now we just need to define inverse_jaxpr, which walks through the Jaxpr backward and inverts primitives when it can.
def inverse_jaxpr(jaxpr, consts, *args):
  env = {}
  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]
  def write(var, val):
    env[var] = val
  # Args now correspond to Jaxpr outvars
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)
  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    # outvars are now invars
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars)
That's it!
def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0)
Importantly, you can trace through a Jaxpr interpreter.
jax.make_jaxpr(inverse(f))(f(1.))
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) }
That's all it takes to add a new transformation to the system, and you get composition with all the others for free! For example, we can use jit, vmap, and grad with inverse!
jit(vmap(grad(inverse(f))))((jnp.arange(5) + 1.) / 5.)
Array([-3.1440797, 15.584931 , 2.2551253, 1.3155028, 1.       ], dtype=float32, weak_type=True)
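A gradient of an inverse can be cross-checked with the inverse function theorem: the derivative of the inverse at f(x) times the derivative of f at x should be 1. The sketch below uses a hand-written inverse (f_inv, an assumption standing in for inverse(f)) so that it stays self-contained and avoids internal APIs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(jnp.tanh(x))

def f_inv(y):
    # Hand-written inverse of f, standing in for inverse(f).
    return jnp.arctanh(jnp.log(y))

x = 0.5
# Inverse function theorem: (f_inv)'(f(x)) * f'(x) == 1.
product = jax.grad(f_inv)(f(x)) * jax.grad(f)(x)

# The transformations nest just as they do for any other JAX function.
xs = jnp.array([0.1, 0.3, 0.5])
batched = jax.jit(jax.vmap(jax.grad(f_inv)))(f(xs))

print(product, batched)
```

The same check can be applied to inverse(f) from this guide, since it is an ordinary JAX-traceable function.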
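Exercises for the reader
- Handle primitives with multiple arguments where the inputs are partially known, for example lax.add_p and lax.mul_p.
- Handle the xla_call and xla_pmap primitives, which will not work with eval_jaxpr or inverse_jaxpr as written.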
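Custom operations for GPUs with C++ and CUDA
Original:
jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that JAX does not support.
To accommodate such scenarios, JAX allows users to define custom operations. This tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from Extending JAX with custom C++ and CUDA code and assumes familiarity with JAX primitives.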
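RMS normalization
For this tutorial, we add RMS normalization as a custom operation in JAX. Note that RMS normalization can be expressed with jax.numpy directly. However, we use it as an example to show the process of creating a custom operation for GPUs. The CUDA code in gpu_ops/rms_norm_kernels.cu for this operation has been borrowed from Apex and adapted to eliminate any dependency on PyTorch.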
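For reference, a pure jax.numpy version might look like the sketch below. It is an assumption about the semantics, not the tutorial's kernel: rms_norm_reference is a hypothetical name, and normalizing over the trailing axes covered by weight with float32 accumulation is inferred from how n1, n2 and invvar are used later in this tutorial.

```python
import jax
import jax.numpy as jnp

def rms_norm_reference(x, weight, eps=1e-05):
    # Normalize over the trailing axes that `weight` covers.
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    # Accumulate in float32, a common choice for half-precision inputs.
    x32 = x.astype(jnp.float32)
    mean_sq = jnp.mean(jnp.square(x32), axis=axes, keepdims=True)
    inv_rms = jax.lax.rsqrt(mean_sq + eps)
    return (x32 * inv_rms * weight.astype(jnp.float32)).astype(weight.dtype)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 16, 8), dtype=jnp.float32)
weight = jnp.ones((16, 8), dtype=jnp.float32)
y = rms_norm_reference(x, weight)

print(y.shape)
```

A reference like this is also handy for checking the output of the custom operation numerically once it is wired up.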
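High-level steps
This tutorial shows how to write both a custom operation and its gradient.
In C: you need to follow these steps for each new JAX primitive.
- Have the CUDA kernel(s).
- Create a C function that dispatches the CUDA kernel and that will be called by XLA.
- Create a descriptor to convey the information needed for the computation.
  - The types, the shapes, and other attributes.
- Bind the C functions to Python
  - To create the descriptor and to call the primitive during execution.
In Python: you need to follow these steps.
- Define a new JAX primitive (instruction/operation).
- Write Python functions to build the graph nodes with the primitive.
- Define its abstract evaluation.
- Define its lowering to MLIR.
- [Optional] Define the gradient.
- [Optional] Use custom_partitioning or shard_map functions for fast multi-GPU support.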
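C code
See the gpu_ops code listing for a complete listing of the C++ and CUDA files. gpu_ops/rms_norm_kernels.cu defines the following functions, which launch the RMS normalization kernels with the given buffers on the specified stream.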
namespace gpu_ops {
void rms_forward_affine_mixed_dtypes(cudaStream_t stream, void **buffers,
const char *opaque,
std::size_t opaque_len);
void rms_backward_affine(cudaStream_t stream, void **buffers,
const char *opaque, std::size_t opaque_len);
} // namespace gpu_ops
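- stream is the CUDA stream used to execute any kernel on the GPU.
- buffers holds all pointers to the input buffers, followed by all pointers to the output buffers.
- opaque is a buffer for any extra information that is passed to the custom functions, and opaque_len is the length of opaque.
For this tutorial, an RMSNormDescriptor object is passed to these functions as opaque.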
namespace gpu_ops {
enum ElementType { BF16, F16, F32, F64 };
struct RMSNormDescriptor {
int n1;
int n2;
double eps;
ElementType x_type;
ElementType w_type;
int part_grad_size;
};
} // namespace gpu_ops
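To make the idea of an opaque byte buffer concrete, the sketch below packs the same six fields into bytes and recovers them on the other side using Python's struct module. It is only an illustration: the format string and the helper names are assumptions, and the resulting bytes are not guaranteed to match the in-memory layout that the C++ side uses.

```python
import struct

# Field order follows RMSNormDescriptor:
# n1, n2, eps, x_type, w_type, part_grad_size.
# The format string is an assumption made for illustration only.
DESCRIPTOR_FORMAT = "@iidiii"

def pack_descriptor(n1, n2, eps, x_type, w_type, part_grad_size):
    # Serialize the fields into a single bytes object (the "opaque" payload).
    return struct.pack(DESCRIPTOR_FORMAT, n1, n2, eps, x_type, w_type,
                       part_grad_size)

def unpack_descriptor(opaque):
    # The callee recovers the fields from the bytes it receives.
    return struct.unpack(DESCRIPTOR_FORMAT, opaque)

opaque = pack_descriptor(4, 262144, 1e-5, 0, 0, 16)
fields = unpack_descriptor(opaque)

print(len(opaque), fields)
```

The point is that the caller and the kernel launcher only need to agree on the byte layout; XLA passes the bytes through without interpreting them.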
Now we need to expose these functions, as well as ElementType and RMSNormDescriptor, as a Python module, gpu_ops, through pybind11.
pybind11::dict RMSNormRegistrations() {
pybind11::dict dict;
dict["rms_forward_affine_mixed_dtype"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
dict["rms_backward_affine"] =
gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
return dict;
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
n1, n2, eps, x_type, w_type, part_grad_size});
});
pybind11::enum_<gpu_ops::ElementType>(m, "ElementType")
.value("BF16", gpu_ops::ElementType::BF16)
.value("F16", gpu_ops::ElementType::F16)
.value("F32", gpu_ops::ElementType::F32)
.value("F64", gpu_ops::ElementType::F64);
}
Build the gpu_ops extension module
We build the gpu_ops Python extension module from the code above. (See the gpu_ops code listing for a complete listing of the C++ and CUDA files.)
python -m pip install pybind11==2.10.1
mkdir -p build
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
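Add RMS normalization to JAX as a custom call
gpu_ops is just a Python extension module, and more work is needed to plug it into JAX.
Create primitives
We first create the primitives _rms_norm_fwd_p and _rms_norm_bwd_p, to which the custom functions can be mapped. We set the multiple_results attribute to True for these operations, which means that the operation produces multiple outputs as a tuple. When it is set to False, the operation produces a single output without a tuple. For more details, see How JAX primitives work.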
from functools import partial
import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client
# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output
# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
def rms_norm_bwd(g, invvar, x, weight, eps):
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
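Lowering to MLIR custom call
To map the custom functions to the new primitives, _rms_norm_fwd_p and _rms_norm_bwd_p, we need to:
- Register the custom functions as custom call targets with xla_client.register_custom_call_target, and
- Register lowering functions that lower the primitives to MLIR custom calls using the registered custom call targets.
The functions _rms_norm_fwd_cuda_lowering and _rms_norm_bwd_cuda_lowering below lower the primitives to MLIR custom call operations with the custom targets from gpu_ops. These functions are registered with jax.interpreters.mlir.register_lowering.
Note that an RMSNormDescriptor object is created in the lowering function and passed to the custom call as opaque.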
from functools import reduce
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")
def element_type_to_descriptor_type_mapping(element_type):
    _element_type_to_descriptor_type_mapping = {
        ir.BF16Type.get(): gpu_ops.ElementType.BF16,
        ir.F16Type.get(): gpu_ops.ElementType.F16,
        ir.F32Type.get(): gpu_ops.ElementType.F32,
        ir.F64Type.get(): gpu_ops.ElementType.F64,
    }
    return _element_type_to_descriptor_type_mapping.get(element_type)
def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]
def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_element_type = (
        ir.F32Type.get()
        if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
        else x_type.element_type
    )
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        0,  # unused
    )
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_fwd_p,
    _rms_norm_fwd_cuda_lowering,
    platform="gpu",
)
def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_type = ir.RankedTensorType(invvar.type)
    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2
    part_grad_shape = ctx.avals_out[-1].shape
    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        part_grad_shape[0],
    )
    out = custom_call(
        b"rms_backward_affine",
        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
        ],
        operands=[grad_output, invvar, x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
        result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
    ).results
    return out
mlir.register_lowering(
    _rms_norm_bwd_p,
    _rms_norm_bwd_cuda_lowering,
    platform="gpu",
)
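The quantities n1 and n2 and the layouts used in the lowering rules are easier to follow with concrete shapes. The sketch below recomputes them in plain Python; split_sizes is a hypothetical helper, and the shapes assume one local device with the batch size used in the test that follows.

```python
from math import prod

def split_sizes(x_shape, w_shape):
    # n2: number of elements normalized together (the size of weight).
    n2 = prod(w_shape)
    # n1: number of independent rows, each with its own statistic (invvar).
    n1 = prod(x_shape) // n2
    return n1, n2

def default_layouts(*shapes):
    # Same rule as in the lowering code, wrapped in list() for printing:
    # dimensions listed from minor to major, i.e. row-major order.
    return [list(range(len(shape) - 1, -1, -1)) for shape in shapes]

x_shape = (4, 512, 512)
w_shape = (512, 512)
n1, n2 = split_sizes(x_shape, w_shape)
layouts = default_layouts(x_shape, w_shape)

print(n1, n2, layouts)
```

With these shapes every one of the n1 batch entries is normalized over all n2 elements covered by weight, which is why invvar has shape (n1,).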
Let's test it
per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
Test the forward function
out = rms_norm_fwd(x, weight)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)
...
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
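Abstract evaluation
The test above failed with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did the test fail? What does it mean?
As part of the execution, JAX performs abstract evaluation. Because JAX has no knowledge of the new primitives, it doesn't know how to compute the output shapes and output data types, and so it cannot evaluate these operations abstractly.
We need to provide a function for the abstract evaluation of each primitive. These abstract evaluation functions compute the shape and the data type of the outputs, but don't compute actual values for the operations.
These functions are passed to the .def_abstract_eval method to be registered with the corresponding primitives.
See How JAX primitives work for more information on abstract evaluation.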
from functools import reduce
from operator import mul
from jax.core import ShapedArray
def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )
_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)
    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    # invvar is float32 for half-precision inputs, otherwise it matches x.
    assert iv_dtype == (
        jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    # Fall back to the named shape of x when weight has none.
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
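Abstract evaluation can be observed on ordinary JAX functions through jax.eval_shape, which computes output shapes and dtypes without running any arithmetic. The sketch below is a generic illustration; scale is a hypothetical function and is unrelated to the RMS normalization primitives.

```python
import jax
import jax.numpy as jnp

def scale(x, weight):
    # Any JAX-traceable function works here; this one is purely illustrative.
    return (x * weight).astype(jnp.float32)

# Shape and dtype descriptions stand in for real arrays.
x_spec = jax.ShapeDtypeStruct((4, 8), jnp.bfloat16)
w_spec = jax.ShapeDtypeStruct((8,), jnp.bfloat16)

# No arrays are allocated and no arithmetic is performed.
out_spec = jax.eval_shape(scale, x_spec, w_spec)

print(out_spec.shape, out_spec.dtype)
```

For built-in operations JAX already knows these rules; the functions registered above with def_abstract_eval supply the same information for the new primitives.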
Let's test it again
Test the forward function
out = rms_norm_fwd(x, weight)
Test the backward function
Now let's test the backward operation using jax.grad and jtu.check_grads.
def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In [8], line 7
3 return -jnp.mean(predictions**2)
6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)
...
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
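Differentiation rule
The backward operation failed with the error NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. This means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX doesn't know the relationship between them.
We can teach JAX that rms_norm_bwd is the backward operation for rms_norm_fwd by using jax.custom_vjp and its convention. As the first step, we need to refine the definitions of rms_norm_fwd and rms_norm_bwd.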
# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)
# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight
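rms_norm_fwd now returns an extra output, (invvar, x, weight), as the residual data, and rms_norm_bwd takes eps, res, and g as its parameters.
Once the relationship between rms_norm_fwd and rms_norm_bwd is established through jax.custom_vjp, JAX ensures that the residual data returned by rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation. For non-differentiable parameters such as eps, JAX ensures that they are passed to the backward operation before the residual data. That is why eps precedes res in the parameter list of rms_norm_bwd.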
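The argument-order convention described above can be seen in a toy function that needs no custom kernel. This is a minimal sketch: scaled_sin and scale are hypothetical names chosen for illustration and are not part of the RMS norm example.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical toy function; `scale` plays the role that `eps` plays above.
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def scaled_sin(x, scale):
    return scale * jnp.sin(x)


def scaled_sin_fwd(x, scale):
    # Forward rule: return the primal output plus the residuals needed later.
    return scale * jnp.sin(x), (x,)


def scaled_sin_bwd(scale, res, g):
    # Non-differentiable arguments come first, then residuals, then the cotangent.
    (x,) = res
    # Return one cotangent per differentiable argument (only x here).
    return (g * scale * jnp.cos(x),)


scaled_sin.defvjp(scaled_sin_fwd, scaled_sin_bwd)

# d/dx [scale * sin(x)] = scale * cos(x), which equals scale at x = 0.
grad_at_zero = jax.grad(scaled_sin)(0.0, 2.0)
```

The backward rule returns a tuple with a single entry because only x is differentiable. In the RMS norm case the tuple has two entries, one for x and one for weight.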
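Now that rms_norm_fwd returns residual data, which is not needed for a plain RMS normalization operation, we define a wrapper around it, rms_norm, that simply calls rms_norm_fwd and returns only output. Note that rms_norm is annotated with @partial(jax.custom_vjp, nondiff_argnums=(2,)) and that we pass rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX that, when rms_norm is differentiated, rms_norm_fwd is used for the forward operation and rms_norm_bwd is used for the backward operation.
For more information on using jax.custom_vjp to define custom derivative rules for JAX-transformable Python functions, see Custom derivative rules.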
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
With the refinement we have made, the backward operation test works with one modification: loss now calls rms_norm instead of rms_norm_fwd.
def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)
loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
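For readers without the CUDA kernel at hand, the same kind of gradient check can be tried on a pure-JAX stand-in. The sketch below assumes that the kernel normalizes each batch entry over its last two dimensions. rms_norm_reference and the small shapes are hypothetical, and float64 is enabled only to keep the finite-difference comparison well conditioned.

```python
import jax
import jax.numpy as jnp
from jax.test_util import check_grads

# Assumption: float64 is used here only to make the numerical check robust.
jax.config.update("jax_enable_x64", True)


def rms_norm_reference(x, weight, eps=1e-05):
    # Assumed semantics: normalize each batch entry over its last two dimensions.
    invvar = jax.lax.rsqrt(jnp.mean(x**2, axis=(1, 2), keepdims=True) + eps)
    return x * invvar * weight


def reference_loss(x, weight):
    predictions = rms_norm_reference(x, weight)
    return -jnp.mean(predictions**2)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x_small = jax.random.normal(key_x, (2, 4, 4), dtype=jnp.float64)
weight_small = jax.random.normal(key_w, (4, 4), dtype=jnp.float64)

# Raises an AssertionError if reverse-mode gradients disagree with finite differences.
check_grads(reference_loss, (x_small, weight_small), modes=["rev"], order=1)
grads = jax.grad(reference_loss, argnums=(0, 1))(x_small, weight_small)
```

check_grads returns nothing on success, so reaching the last line means the reverse-mode gradients matched the numerical estimate.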
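Let's test it on multiple devices
We use jax.experimental.pjit.pjit for parallel execution on multiple devices, and we produce the reference values with sequential execution on a single device.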
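The compare-against-a-reference pattern can be tried without the custom kernel. The sketch below is an assumption-laden illustration: it uses jax.jit with NamedSharding instead of pjit, and scale_rows is a hypothetical stand-in for the custom op. It also runs on a single device, where the mesh has size one.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def scale_rows(x, weight):
    # Hypothetical stand-in for the custom op: independent work per batch entry.
    return jnp.tanh(x) * weight


devices = np.array(jax.local_devices())
mesh = Mesh(devices, ("x",))

# The batch dimension must be divisible by the number of devices in the mesh.
batch = 2 * devices.size
x_demo = jnp.arange(batch * 4 * 4, dtype=jnp.float32).reshape(batch, 4, 4) / 100.0
weight_demo = jnp.full((4, 4), 0.5, dtype=jnp.float32)

# Shard x along the batch dimension and replicate weight on every device.
x_sharding = NamedSharding(mesh, PartitionSpec("x", None, None))
replicated = NamedSharding(mesh, PartitionSpec(None, None))

sharded_fn = jax.jit(
    scale_rows,
    in_shardings=(x_sharding, replicated),
    out_shardings=x_sharding,
)

reference = scale_rows(x_demo, weight_demo)  # unsharded execution
sharded_out = sharded_fn(x_demo, weight_demo)  # execution with batch sharding
matches = bool(jnp.allclose(reference, sharded_out, atol=1e-5, rtol=1e-5))
```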
Test the forward function
Let's first test the forward operation on multiple devices. We create a simple 1D mesh and shard x across all devices.
from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit
mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)
with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)
jnp.allclose(ref, out, atol=1e-5, rtol=1e-5)
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}
%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
%param_1 = bf16[32,512,512]{2,1,0} parameter(0)
%param_1.3 = u32[] parameter(1)
%convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
%param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
%all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
%custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
%get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
%partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}
True
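The values are computed correctly for the forward operation. However, the generated HLO module shows an all-gather operation that replicates x on all devices, which incurs a large communication overhead.
Because XLA does not know enough about the custom function to shard its input tensors, it decides to replicate them before making the custom call so that the values come out correct.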
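One way to spot such collectives programmatically is to search the compiled HLO text, much like reading the dump above by eye. This is a minimal sketch: collectives_in_hlo, the list of op names, and square_sum are hypothetical choices made for illustration, and the function compiled here is unsharded, so no collective is expected.

```python
import jax
import jax.numpy as jnp

# Hypothetical list of collective op names to look for in the HLO text.
COLLECTIVES = ("all-gather", "all-reduce", "all-to-all", "collective-permute")


def collectives_in_hlo(fn, *args):
    # Compile the function for the given arguments and fetch the HLO as text.
    hlo_text = jax.jit(fn).lower(*args).compile().as_text()
    return [name for name in COLLECTIVES if name in hlo_text]


def square_sum(x):
    return jnp.sum(x**2)


# No sharding is requested here, so the compiled module needs no collectives.
found = collectives_in_hlo(square_sum, jnp.ones((4, 8)))
```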
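To avoid this replication, we can:
- Use custom_partitioning, which makes the operation behave like any native JAX operation (but is more complex).
- Use manual sharding.
This example demonstrates the use of custom_partitioning.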
Shard the forward function with custom_partitioning
We first create a helper function that takes care of all the required JAX/XLA callback registration.
def register_primitive(cls):
    """
    register jax primitive
    The order of calls. Each operation is composed of two primitives: Inner and Outer.
    Inner, only the basic to wrap the custom_call itself.
    - impl to XLA custom_call in C.
    - abstract to know the static shapes
    - lower to StableHLO XLA custom_call.
    Outer, mostly all the rest:
    - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind.
    - abstract: same
    - lower to StableHLO custom_p. (XLA will call the python callback from it)
    - custom_p
    - vmap: could be added here.
    VJP is based on Outer, but not handled in this function.
    """
    def name_of_wrapper_p():
        return cls.name + "_wrapper"
    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p
    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
...
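We define two JAX primitives. The inner primitive maps to the real kernel that we want to wrap in JAX. The outer primitive is used with the custom_partitioning registration and for the gradient. (If you implement the interface that supports vmap, it also belongs on the outer primitive.)
The JAX custom_partitioning implementation consists of callbacks from XLA into Python during XLA's sharding logic. XLA sharding happens in two phases: a sharding propagation phase and a partition phase. In the propagation phase, XLA plans the shardings to be created. The partition phase creates the sharded graph. For XLA to be able to shard our custom operation, we need to define two extra functions: infer_sharding_from_operands() and partition(). They are used in the first and second phase, respectively.
The infer_sharding_from_operands() function must do what its name says: infer the output sharding from the input sharding.
The partition() function does several things:
- It tells XLA which input shardings are expected. XLA reshards if needed.
- It tells XLA the final version of the output shardings.
- It provides a function that creates the new instruction from the sharded inputs.
See the code comments for more explanation: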
class RmsNormFwdClass:
    name = "rms_forward_affine_mixed_dtype"
    multiple_results = True
    impl_static_args = (2,)  # eps
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims of all inputs to be replicated except the
        # first dim of x that will be kept as is.
        # This is because the implementation can only be sharded on the batch dimensions.
        x_spec = arg_infos[0].sharding.spec
        # None means that we replicate on that dimension.
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        return (output_sharding, invvar_sharding)
    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        x_spec = arg_infos[0].sharding.spec
        # We only support sharding on the batch dimensions.
        # Force sharding on all other dimensions with None.
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        output_shardings = (arg_shardings[0], invvar_sharding)
        # Sharded_impl only accepts positional arguments
        # And they should be JAX traceable variables
        impl = partial(RmsNormFwdClass.impl, eps=eps)
        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormFwdClass)
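The sharding decision made in partition() above can be isolated as a small pure function: keep whatever axis shards the batch dimension of x, and replicate everything else. This is a sketch under stated assumptions. batch_only_shardings, the two-axis mesh, and the axis name "y" are hypothetical and are not part of the tutorial's code.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def batch_only_shardings(mesh, x_spec):
    """Keep only the batch-axis sharding of x and replicate everything else."""
    batch_axis = x_spec[0]
    x_sharding = NamedSharding(mesh, PartitionSpec(batch_axis, None, None))
    weight_sharding = NamedSharding(mesh, PartitionSpec(None, None))
    invvar_sharding = NamedSharding(mesh, PartitionSpec(batch_axis))
    arg_shardings = (x_sharding, weight_sharding)
    output_shardings = (x_sharding, invvar_sharding)
    return arg_shardings, output_shardings


# Hypothetical 2D mesh; the "y" axis has size 1 so this also runs on one device.
devices = np.array(jax.local_devices()).reshape(-1, 1)
mesh = Mesh(devices, ("x", "y"))

# Suppose x arrives sharded on both its first and second dimensions.
proposed = PartitionSpec("x", "y", None)
arg_shardings, output_shardings = batch_only_shardings(mesh, proposed)
```

Here the sharding of the second dimension on "y" is dropped, which corresponds to the case where XLA has to reshard the input before calling the kernel.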
Next we define the primitive for the backward pass of RMSNorm.
Shard the backward function with custom_partitioning
class RmsNormBwdClass:
    name = "rms_norm_bwd"
    multiple_results = True
    impl_static_args = (4,)  # eps
    inner_primitive = None
    outer_primitive = None
    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims to be replicated except the batch dimension.
        x_spec = x_info.sharding.spec
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        # Despite its name, this replicated 2D sharding is used for grad_weight and part_grad.
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        # Match the output shardings returned by partition() below.
        return (output_sharding, invvar_sharding, invvar_sharding)
    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # We only support sharding on the batch dimensions.
        # Force sharding on all other dimensions with None.
        # Also force gx, x and invvar to have the same batch sharding/replication.
        x_spec = x_info.sharding.spec
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0],)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        output_shardings = (output_sharding, invvar_sharding, invvar_sharding)
        # Sharded_impl only accepts positional arguments
        # And they should be JAX traceable variables
        def impl(g, invvar, x, weight):
            grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
                g, invvar, x, weight, eps=eps
            )
            # We need to sum the weight gradient from all partitions.
            global_weight = grad_weight
            if x_spec[0]:
                global_weight = jax.lax.psum(grad_weight, x_spec[0])
            return grad_input, global_weight, part_grad
        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormBwdClass)
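The psum in the backward impl is needed because each partition only sees its own slice of the batch, so its weight gradient is a partial sum. The sketch below illustrates this with a hypothetical toy loss, using jax.vmap with an axis name to stand in for the device axis. It is not the tutorial's kernel.

```python
import jax
import jax.numpy as jnp


def loss(x, w):
    # A sum over the batch, so the gradient with respect to w adds up across batch shards.
    return jnp.sum((x * w) ** 2)


x_full = jnp.arange(8.0).reshape(4, 2)
w = jnp.array([0.5, -1.0])

# Reference: weight gradient computed on the whole batch at once.
full_grad_w = jax.grad(loss, argnums=1)(x_full, w)


def per_shard_grad(x_shard):
    # Each shard only sees part of the batch, so its gradient is partial.
    local_grad = jax.grad(loss, argnums=1)(x_shard, w)
    # Summing over the named axis recovers the gradient for the whole batch.
    return jax.lax.psum(local_grad, "x")


# Two shards of two rows each; vmap's axis_name plays the role of the mesh axis.
x_shards = x_full.reshape(2, 2, 2)
summed_grads = jax.vmap(per_shard_grad, axis_name="x")(x_shards)
```

Every row of summed_grads holds the same reduced gradient, mirroring how the replicated weight gradient ends up identical on every device after the all-reduce.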
The plumbing that connects the forward and backward primitives uses a custom_vjp rule, as before:
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
    output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
    return output
def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)
def custom_p_rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
        g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight
custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd)
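With that, we have completely defined our custom RMS norm primitive with custom_partitioning. To check for correctness, we define the following loss functions: ref_loss produces the reference values to compare against, while custom_p_loss uses our new primitive that implements custom_partitioning.
def ref_loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)
ref = jax.grad(ref_loss, argnums=(0, 1))(x, weight)
def custom_p_loss(x, weight):
    predictions = custom_p_rms_norm(x, weight)
    return -jnp.mean(predictions**2)
Check for correctness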
with Mesh(jax.local_devices(), ("x",)):
    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out
    custom_p_out = run_and_verify(custom_p_loss)
# Compare against the reference gradients stored in `ref` above.
for r, o in zip(ref, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6))
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}
%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
%param_0 = f16[4,512,512]{2,1,0} parameter(0)
%constant_4_1 = f16[] constant(-4.7684e-07)
%broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}
%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
%Arg_1.11.0 = f16[] parameter(1)
%Arg_0.10.0 = f16[] parameter(0)
ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}
ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
%param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
%param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
%custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
%get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
%get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
%custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
%get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
%all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
%get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
}
True
True
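Now there is no all-gather in the HLO: the sharding is respected, and only the gradients are accumulated through an all-reduce.
Let's put it together
The complete definition of the primitives using custom_partitioning can be found in Custom_Operation_for_GPUs.py, and the corresponding C++ code that defines the Python bindings can be found below:
gpu_ops code listing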
gpu_ops/kernel_helpers.h
gpu_ops/kernels.h
gpu_ops/pybind11_kernel_helpers.h
gpu_ops/gpu_ops.cpp
gpu_ops/rms_norm_kernels.cu