
JAX Chinese Documentation (Part 9)

Posted: 2024-06-21 14:25:03

Original: jax.readthedocs.io/en/latest/

Control autodiff's saved values with jax.checkpoint (aka jax.remat)

Original: jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html

import jax
import jax.numpy as jnp 

TL;DR

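Use the jax.checkpoint decorator (aliased as jax.remat) with jax.grad to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory against FLOPs.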

Don't miss the practical notes for a discussion of how jax.checkpoint interacts with jax.jit.

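Without jax.checkpoint, the forward pass of jax.grad(f)(x) saves, for use on the backward pass, the values of Jacobian coefficients and other intermediates. We call these saved values residuals: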

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

# Inspect the 'residual' values to be saved on the forward pass
# if we were to evaluate `jax.grad(f)(W1, W2, W3, x)`
from jax.ad_checkpoint import print_saved_residuals
jax.ad_checkpoint.print_saved_residuals(f, W1, W2, W3, x) 
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[5] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of cos from <ipython-input-4-f510dde58e22>:3 (g)
f32[7] output of cos from <ipython-input-4-f510dde58e22>:3 (g) 

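By applying jax.checkpoint to sub-functions, either as a decorator or at specific application sites, we force JAX not to save any of that sub-function's residuals. Instead, only the inputs of a jax.checkpoint-decorated function might be saved, and any residuals consumed on the backward pass are recomputed from those inputs as needed: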

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

jax.ad_checkpoint.print_saved_residuals(f2, W1, W2, W3, x) 
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of sin from <ipython-input-4-f510dde58e22>:3 (g)
f32[6] output of sin from <ipython-input-4-f510dde58e22>:3 (g) 

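Here the values of two sin applications are saved because they are arguments to subsequent applications of the jax.checkpoint-decorated g function, and inputs to a jax.checkpoint-decorated function may be saved. But no values of cos applications are saved.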
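As a sanity check, jax.checkpoint changes only which values are saved, not the gradient itself. The sketch below redefines g, f, and f2 from above so it runs on its own; the summed wrapper is a hypothetical helper added only because jax.grad needs a scalar output:

```python
import jax
import jax.numpy as jnp

def g(W, x):
  y = jnp.dot(W, x)
  return jnp.sin(y)

def f(W1, W2, W3, x):
  x = g(W1, x)
  x = g(W2, x)
  x = g(W3, x)
  return x

def f2(W1, W2, W3, x):
  x = jax.checkpoint(g)(W1, x)
  x = jax.checkpoint(g)(W2, x)
  x = jax.checkpoint(g)(W3, x)
  return x

def summed(fn):
  # Hypothetical helper: jax.grad needs a scalar, so sum the vector output.
  return lambda *args: jnp.sum(fn(*args))

W1 = jnp.ones((5, 4))
W2 = jnp.ones((6, 5))
W3 = jnp.ones((7, 6))
x = jnp.ones(4)

grads_plain = jax.grad(summed(f), argnums=(0, 1, 2))(W1, W2, W3, x)
grads_remat = jax.grad(summed(f2), argnums=(0, 1, 2))(W1, W2, W3, x)

# Same gradients; only the forward/backward memory-vs-FLOP split differs.
print([bool(jnp.allclose(a, b)) for a, b in zip(grads_plain, grads_remat)])
```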

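To control which values are saveable without editing the definition of the function being differentiated, you can use a rematerialization policy. Here is an example that saves only the results of dot operations with no batch dimensions (since they are often FLOP-bound, and hence worth saving rather than recomputing):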

f3 = jax.checkpoint(f, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
jax.ad_checkpoint.print_saved_residuals(f3, W1, W2, W3, x) 
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[6] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g)
f32[7] output of dot_general from <ipython-input-4-f510dde58e22>:2 (g) 

You can also use policies to refer to intermediate values you name with jax.ad_checkpoint.checkpoint_name:

from jax.ad_checkpoint import checkpoint_name

def f4(W1, W2, W3, x):
  x = checkpoint_name(g(W1, x), name='a')
  x = checkpoint_name(g(W2, x), name='b')
  x = checkpoint_name(g(W3, x), name='c')
  return x

f4 = jax.checkpoint(f4, policy=jax.checkpoint_policies.save_only_these_names('a'))
jax.ad_checkpoint.print_saved_residuals(f4, W1, W2, W3, x) 
f32[5,4] from the argument 'W1'
f32[6,5] from the argument 'W2'
f32[7,6] from the argument 'W3'
f32[4] from the argument 'x'
f32[5] named 'a' from <ipython-input-7-fc0ed1c14b8d>:4 (f4) 

When playing around with these toy examples, we can get a closer look at what's going on using the print_fwd_bwd utility defined in this notebook:

from jax.tree_util import tree_flatten, tree_unflatten

from rich.console import Console
from rich.table import Table
import rich.text

def print_fwd_bwd(f, *args, **kwargs) -> None:
  args, in_tree = tree_flatten((args, kwargs))

  def f_(*args):
    args, kwargs = tree_unflatten(in_tree, args)
    return f(*args, **kwargs)

  fwd = jax.make_jaxpr(lambda *args: jax.vjp(f_, *args))(*args).jaxpr

  y, f_vjp = jax.vjp(f_, *args)
  res, in_tree = tree_flatten(f_vjp)

  def g_(*args):
    *res, y = args
    f_vjp = tree_unflatten(in_tree, res)
    return f_vjp(y)

  bwd = jax.make_jaxpr(g_)(*res, y).jaxpr

  table = Table(show_header=False, show_lines=True, padding=(1, 2, 0, 2), box=None)
  table.add_row("[bold green]forward computation:",
                "[bold green]backward computation:")
  table.add_row(rich.text.Text.from_ansi(str(fwd)),
                rich.text.Text.from_ansi(str(bwd)))
  console = Console(width=240, force_jupyter=True)
  console.print(table)

def _renderable_repr(self):
  return self.html
rich.jupyter.JupyterRenderable._repr_html_ = _renderable_repr 
# no use of jax.checkpoint:
print_fwd_bwd(f, W1, W2, W3, x) 

  forward computation:

  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d
      f:f32[5] = sin e
      g:f32[5] = cos e
      h:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f
      i:f32[6] = sin h
      j:f32[6] = cos h
      k:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c i
      l:f32[7] = sin k
      m:f32[7] = cos k
    in (l, m, i, c, j, f, b, g, d, a) }

  backward computation:

  { lambda ; a:f32[7] b:f32[6] c:f32[7,6] d:f32[6] e:f32[5] f:f32[6,5] g:f32[5] h:f32[4]
      i:f32[5,4] j:f32[7]. let
      k:f32[7] = mul j a
      l:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] k c
      m:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] k b
      n:f32[6] = mul l d
      o:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] n f
      p:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] n e
      q:f32[5] = mul o g
      r:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] q i
      s:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] q h
    in (s, p, m, r) }

# using jax.checkpoint with policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable:
print_fwd_bwd(f3, W1, W2, W3, x) 

  forward computation:

  { lambda ; a:f32[5,4] b:f32[6,5] c:f32[7,6] d:f32[4]. let
      e:f32[5] = dot_general[dimension_numbers=(([1], [0]), ([], []))] a d
      f:f32[5] = sin e
      g:f32[6] = dot_general[dimension_numbers=(([1], [0]), ([], []))] b f
      h:f32[6] = sin g
      i:f32[7] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c h
      j:f32[7] = sin i
    in (j, e, g, i, a, b, c, d) }

  backward computation:

  { lambda ; a:f32[5] b:f32[6] c:f32[7] d:f32[5,4] e:f32[6,5] f:f32[7,6] g:f32[4] h:f32[7]. let
      i:f32[5,4] j:f32[6,5] k:f32[7,6] l:f32[4] = remat2[
        differentiated=True
        jaxpr={ lambda ; m:f32[5] n:f32[6] o:f32[7] p:f32[5,4] q:f32[6,5] r:f32[7,6]
            s:f32[4] t:f32[7]. let
            u:f32[5] = sin m
            v:f32[5] = cos m
            w:f32[6] = sin n
            x:f32[6] = cos n
            y:f32[7] = cos o
            z:f32[7] = mul t y
            ba:f32[6] = dot_general[dimension_numbers=(([0], [0]), ([], []))] z r
            bb:f32[6] = mul ba x
            bc:f32[5] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bb q
            bd:f32[5] = mul bc v
            be:f32[4] = dot_general[dimension_numbers=(([0], [0]), ([], []))] bd p
            bf:f32[5,4] = dot_general[dimension_numbers=(([], []), ([], []))] bd s
            bg:f32[6,5] = dot_general[dimension_numbers=(([], []), ([], []))] bb u
            bh:f32[7,6] = dot_general[dimension_numbers=(([], []), ([], []))] z w
          in (bf, bg, bh, be) }
        policy=<function dot_with_no_batch_dims at 0x7f5e469b1700>
        prevent_cse=True
      ] a b c d e f g h
    in (i, j, k, l) }

Let's think step by step

You might want to first (re)read Part 1 of the Autodiff Cookbook.

jax.checkpoint fundamentals

jax.linearizejax.vjp中,如何以及何时计算某些值有灵活性。不同的选择可以在内存使用和 FLOP 之间进行权衡。JAX 通过jax.checkpoint提供了对这些选择的控制。

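One such choice is whether to perform Jacobian coefficient computations on the forward pass, as soon as the inputs are available, or on the backward pass, just before the coefficients are needed. Consider the example of sin_vjp: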

def sin_vjp(x):
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar 

Another valid implementation would compute the value of jnp.cos(x) on the backward pass rather than on the forward pass:

def sin_vjp2(x):
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar 

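For this particular function, the amount of memory used by the two versions is the same (sin_vjp keeps cos_x alive while sin_vjp2 keeps x alive, and both have the same shape), though we've reduced the FLOPs for the primal computation and increased the FLOPs for the cotangent computation.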
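A quick way to convince yourself that both versions are valid VJPs is to compare them with jax.vjp applied to jnp.sin directly. This sketch only checks that the cotangents agree; it does not measure memory or FLOPs:

```python
import jax
import jax.numpy as jnp

def sin_vjp(x):
  # Compute the Jacobian coefficient cos(x) on the forward pass and close over it.
  y = jnp.sin(x)
  cos_x = jnp.cos(x)
  return y, lambda y_bar: cos_x * y_bar

def sin_vjp2(x):
  # Close over x instead; cos(x) is computed only when the backward pass runs.
  y = jnp.sin(x)
  return y, lambda y_bar: jnp.cos(x) * y_bar

x = jnp.linspace(-2.0, 2.0, 5)
y_bar = jnp.ones_like(x)

y_ref, vjp_ref = jax.vjp(jnp.sin, x)
(x_bar_ref,) = vjp_ref(y_bar)  # jax.vjp's pullback returns a tuple of cotangents

y1, bwd1 = sin_vjp(x)
y2, bwd2 = sin_vjp2(x)
print(bool(jnp.allclose(bwd1(y_bar), x_bar_ref)),
      bool(jnp.allclose(bwd2(y_bar), x_bar_ref)))
```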

There's another choice when it comes to function composition. Recall our VJP rule for a composition of two functions:

def f(x):
  y = g(x)
  z = h(y)
  return z

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd 

An alternative is:

def f_vjp_checkpoint(x):
  y = g(x)
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2 

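In words, this alternative implementation doesn't compute g_vjp, or the residual values in its closure, on the forward pass. Instead it only computes them in the backward pass f_bwd2. That means f_vjp_checkpoint requires less memory: if g and h each required similar amounts of memory for their residuals, each much larger than x, then the function produced by f_vjp_checkpoint(x) requires half the memory of the one produced by f_vjp(x)!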

The cost we pay is redundant work: in f_bwd2 we must re-evaluate g(x) as part of jax.vjp(g, x) just to discard its value (in the underscore variable on the line _, g_vjp = jax.vjp(g, x)).
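To see that the two VJP strategies compute the same thing, here is a runnable sketch. The text leaves g and h abstract, so the concrete choices below (g = jnp.sin and an h that sums exponentials to a scalar) are hypothetical and only for illustration:

```python
import jax
import jax.numpy as jnp

# Hypothetical stages: g maps a vector to a vector, h reduces it to a scalar.
g = jnp.sin
def h(y):
  return jnp.sum(jnp.exp(y))

def f_vjp(x):
  y, g_vjp = jax.vjp(g, x)      # g's residuals stay alive inside g_vjp
  z, h_vjp = jax.vjp(h, y)
  def f_bwd(z_bar):
    y_bar, = h_vjp(z_bar)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd

def f_vjp_checkpoint(x):
  y = g(x)                      # g's forward pass; no residuals kept
  z, h_vjp = jax.vjp(h, y)
  def f_bwd2(z_bar):
    y_bar, = h_vjp(z_bar)
    _, g_vjp = jax.vjp(g, x)    # redundant re-evaluation of g(x)
    x_bar, = g_vjp(y_bar)
    return x_bar
  return z, f_bwd2

x = jnp.arange(3.0)
z1, bwd1 = f_vjp(x)
z2, bwd2 = f_vjp_checkpoint(x)
x_bar1 = bwd1(jnp.ones_like(z1))
x_bar2 = bwd2(jnp.ones_like(z2))
x_bar_ref = jax.grad(lambda x: h(g(x)))(x)
print(bool(jnp.allclose(x_bar1, x_bar2)), bool(jnp.allclose(x_bar1, x_bar_ref)))
```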

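We can get this VJP behavior in autodiff, without having to write VJP functions directly, by instead using jax.checkpoint in an alternative definition of the original function f: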

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z 

In other words, we apply jax.checkpoint to g, the first stage of f, rather than to f itself. This way, when we evaluate jax.grad(f_checkpoint)(x), we'd get a computation like:

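  1. Run the forward pass of g, discarding residual values;

  2. Run the forward pass of h, saving residuals;

  3. Run the backward pass of h, consuming residuals from step 2;

  4. Re-run the forward pass of g, saving residuals;

  5. Run the backward pass of g, consuming residuals from step 4.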

That is, evaluating jax.grad(f_checkpoint)(x) gives the same computation as:

def f_checkpoint_grad(x):
  y = g(x)                  # step 1
  _, h_vjp = jax.vjp(h, y)  # step 2
  y_bar, = h_vjp(1.0)       # step 3
  _, g_vjp = jax.vjp(g, x)  # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar 
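The five steps can be checked end to end against jax.grad(f_checkpoint). In this sketch g and h are again hypothetical stand-ins, and the seed cotangent is built with jnp.ones_like(z) so its dtype matches h's scalar output:

```python
import jax
import jax.numpy as jnp

# Hypothetical stages for illustration only.
g = jnp.sin
def h(y):
  return jnp.sum(jnp.exp(y))

def f_checkpoint(x):
  y = jax.checkpoint(g)(x)
  z = h(y)
  return z

def f_checkpoint_grad(x):
  y = g(x)                          # step 1: g forward, residuals discarded
  z, h_vjp = jax.vjp(h, y)          # step 2: h forward, residuals saved
  y_bar, = h_vjp(jnp.ones_like(z))  # step 3: h backward, seeded with 1.0
  _, g_vjp = jax.vjp(g, x)          # step 4: g forward again, residuals saved
  x_bar, = g_vjp(y_bar)             # step 5: g backward
  return x_bar

x = jnp.arange(3.0)
print(bool(jnp.allclose(jax.grad(f_checkpoint)(x), f_checkpoint_grad(x))))
```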

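In general, jax.checkpoint(foo) is a new function which has the same input-output behavior as foo, but behaves differently under autodiff, particularly under jax.linearize and jax.vjp (and their wrappers, like jax.grad) but not under jax.jvp. When differentiated, only the inputs to a jax.checkpoint-decorated function are stored on the forward pass; on the backward pass, residuals (i.e. intermediates from foo and its Jacobian coefficient values needed for the backward pass) are recomputed.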
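A small sketch of the "same input-output behavior" claim, with a hypothetical foo chosen so that its derivative has a closed form (sin(x)·cos(x) = sin(2x)/2, whose derivative is cos(2x)). Values, jax.jvp results, and gradients all agree; what jax.checkpoint changes is only which residuals jax.vjp keeps:

```python
import jax
import jax.numpy as jnp

def foo(x):
  # Hypothetical example function: sin(x) * cos(x) == sin(2x) / 2.
  return jnp.sin(x) * jnp.cos(x)

foo_remat = jax.checkpoint(foo)

x = jnp.array(0.5)
t = jnp.array(1.0)

same_value = bool(jnp.allclose(foo(x), foo_remat(x)))

# Forward-mode (jax.jvp) results are unaffected by jax.checkpoint.
p1, t1 = jax.jvp(foo, (x,), (t,))
p2, t2 = jax.jvp(foo_remat, (x,), (t,))

# Reverse mode gives the same gradient, cos(2x), either way.
g_remat = jax.grad(foo_remat)(x)
print(same_value, bool(jnp.allclose(g_remat, jnp.cos(2 * x))))
```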

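Notice that if f = lambda x: h(g(x)) is the function we want to differentiate, i.e. if we want to apply jax.grad(f), we don't get any memory savings by applying jax.checkpoint to f itself. That's because evaluating jax.grad(jax.checkpoint(f))(x) would lead to a computation like: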

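  1. Run the forward pass, discarding all residuals;

  2. Immediately re-run the forward pass, saving residuals;

  3. Run the backward pass, consuming residuals from step 2.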

In code, we'd have something like:

def f_grad_bad(x):
  _ = f(x)                  # step 1
  _, f_vjp = jax.vjp(f, x)  # step 2
  x_bar, = f_vjp(1.0)       # step 3
  return x_bar 

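We also wouldn't get any memory savings by applying jax.checkpoint to h, the second stage of f. That's because evaluating jax.grad(lambda x: jax.checkpoint(h)(g(x))) would lead to a computation like: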

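  1. Run the forward pass of g, saving residuals;

  2. Run the forward pass of h, discarding residuals;

  3. Immediately re-run the forward pass of h, saving residuals;

  4. Run the backward pass of h, consuming residuals from step 3;

  5. Run the backward pass of g, consuming residuals from step 1.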

In code, we'd have something like:

def f_grad_bad2(x):
  y, g_vjp = jax.vjp(g, x)  # step 1
  z = h(y)                  # step 2
  _, h_vjp = jax.vjp(h, y)  # step 3
  y_bar, = h_vjp(1.0)       # step 4
  x_bar, = g_vjp(y_bar)     # step 5
  return x_bar 

Slightly more generally, if we had a chain composition of functions, like f = lambda x: f3(f2(f1(x))), and we were interested in evaluating jax.grad(f), we could say that:

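  • We shouldn't apply jax.checkpoint to the whole function f, since that wouldn't save any memory (and will perform wasteful recomputation);

  • We shouldn't apply jax.checkpoint to the last sub-function f3, since that wouldn't save any memory (and will perform wasteful recomputation);

  • We could apply jax.checkpoint to f1, f2, or their composition lambda x: f2(f1(x)), since any of those might save memory and would express different memory/recompute tradeoffs.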
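The three useful placements from the last bullet can be written out directly. In this sketch f1, f2, and f3 are hypothetical stand-ins; each placement expresses a different memory/recompute tradeoff, but all of them produce the same gradient as the un-checkpointed f:

```python
import jax
import jax.numpy as jnp

# Hypothetical stages for illustration only.
f1 = jnp.sin
f2 = jnp.tanh
def f3(x):
  return jnp.sum(x ** 2)

f = lambda x: f3(f2(f1(x)))

# Placements that can save memory: f1, f2, or their composition.
f_ckpt_1 = lambda x: f3(f2(jax.checkpoint(f1)(x)))
f_ckpt_2 = lambda x: f3(jax.checkpoint(f2)(f1(x)))
f_ckpt_12 = lambda x: f3(jax.checkpoint(lambda v: f2(f1(v)))(x))

x = jnp.linspace(0.0, 1.0, 4)
g_ref = jax.grad(f)(x)
checks = [bool(jnp.allclose(jax.grad(fn)(x), g_ref))
          for fn in (f_ckpt_1, f_ckpt_2, f_ckpt_12)]
print(checks)
```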

Custom policies for what's saveable

As shown so far, using jax.checkpoint switches from one extreme to the other:

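  • Without jax.checkpoint, JAX's autodiff tends to compute everything possible on the forward pass and store it for the backward pass;

  • With a jax.checkpoint decorator, we instead compute as little as possible on the forward pass and recompute values as needed on the backward pass.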

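To operate between these two extremes, saving some things and not others, we can carefully place jax.checkpoint decorators on sub-functions. But that requires editing the function to be differentiated, e.g. model code, which may be inconvenient. It can also be hard to experiment with variations.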

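So an alternative is to use the policy argument to jax.checkpoint. A policy is a callable (i.e. a function) which takes as input a type-level specification of a first-order primitive application and returns a boolean indicating whether the corresponding output value(s) are allowed to be saved as residuals (or must instead be recomputed in the (co)tangent computation as needed). To write robust code, a policy should be selected from the attributes of jax.checkpoint_policies, like jax.checkpoint_policies.dots_with_no_batch_dims_saveable, since the API for writing custom policy callables is considered internal.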

For example, consider this function to be differentiated:

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x)) 
W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4) 
print_saved_residuals(loss, params, x, y) 
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of sin from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss) 

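Instead of saving so many values on the forward pass, perhaps we only want to save the results of matrix multiplications with no batch dimension (since they may be FLOP-bound rather than memory-bound). We can do that using the policy jax.checkpoint_policies.dots_with_no_batch_dims_saveable: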

loss_checkpoint = jax.checkpoint(loss, policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
print_saved_residuals(loss_checkpoint, params, x, y) 
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y'
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] output of dot_general from <ipython-input-18-3808b5023c3d>:8 (predict) 

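Notice also that by providing a policy, we didn't need to edit the code defining loss, predict, or layer. That is particularly convenient if we want to experiment with policies in calling code (e.g. a training script) without changing library code (e.g. a neural network library).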
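For instance, a training script could sweep over a few policies without touching the model code. This sketch repeats the loss, predict, and layer definitions from above so it runs on its own; the dictionary of policies is just an illustrative choice, and every entry should reproduce the plain gradient:

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

def predict(params, x):
  *Ws, Wlast = params
  for W in Ws:
    x = layer(W, x)
  x = jnp.dot(Wlast, x)
  return x

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))

W1 = W2 = W3 = jnp.ones((4, 4))
params = [W1, W2, W3]
x = jnp.ones(4)
y = jnp.ones(4)

# Policies are picked at the call site; the model code above is untouched.
policies = {
    "everything_saveable": jax.checkpoint_policies.everything_saveable,
    "nothing_saveable": jax.checkpoint_policies.nothing_saveable,
    "dots_with_no_batch_dims_saveable":
        jax.checkpoint_policies.dots_with_no_batch_dims_saveable,
}

reference = jax.grad(loss)(params, x, y)
results = {}
for name, policy in policies.items():
  grads = jax.grad(jax.checkpoint(loss, policy=policy))(params, x, y)
  results[name] = all(bool(jnp.allclose(a, b)) for a, b in zip(grads, reference))
print(results)
```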

Some policies can refer to values named with jax.ad_checkpoint.checkpoint_name:

from jax.ad_checkpoint import checkpoint_name

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  x = jnp.dot(Wlast, x)
  return x 

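By itself, checkpoint_name is just an identity function. But because some policy functions know to look for it, we can use the names to control whether certain values output by checkpoint_name are considered saveable: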

print_saved_residuals(loss, params, x, y) 
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer0_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of cos from <ipython-input-18-3808b5023c3d>:12 (layer)
f32[4] named 'layer1_output' from <ipython-input-22-e48aedf368ad>:7 (predict)
f32[4] output of mul from <ipython-input-18-3808b5023c3d>:2 (loss) 
loss_checkpoint2 = jax.checkpoint(loss, policy=jax.checkpoint_policies.save_any_names_but_these('layer1_output'))
print_saved_residuals(loss_checkpoint2, params, x, y) 
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4,4] from the argument 'params'
f32[4] from the argument 'x'
f32[4] from the argument 'y' 
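The other name-based policy mentioned below, save_only_these_names, can be tried the same way. This sketch (definitions repeated so it runs standalone) checks that checkpoint_name behaves as an identity outside of autodiff, and that keeping only 'layer0_output' saveable leaves the gradient unchanged; print_saved_residuals is the same utility used above, and its exact output is not reproduced here:

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(W, x):
  return jnp.sin(jnp.dot(W, x))

def predict(params, x):
  *Ws, Wlast = params
  for i, W in enumerate(Ws):
    x = layer(W, x)
    x = checkpoint_name(x, name=f'layer{i}_output')
  return jnp.dot(Wlast, x)

def loss(params, x, y):
  return jnp.sum((predict(params, x) - y)**2)

params = [jnp.ones((4, 4))] * 3
x = jnp.ones(4)
y = jnp.ones(4)

# On its own, checkpoint_name returns its input unchanged.
v = jnp.arange(3.0)
identity_ok = bool(jnp.array_equal(checkpoint_name(v, name='tmp'), v))

# Only the value named 'layer0_output' may be saved; the rest is rematerialized.
loss_ckpt3 = jax.checkpoint(
    loss, policy=jax.checkpoint_policies.save_only_these_names('layer0_output'))
jax.ad_checkpoint.print_saved_residuals(loss_ckpt3, params, x, y)

grads = jax.grad(loss_ckpt3)(params, x, y)
reference = jax.grad(loss)(params, x, y)
```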

Another policy that refers to names is jax.checkpoint_policies.save_only_these_names.

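Some of the policies are:

  • everything_saveable (the default autodiff behavior, as if jax.checkpoint were not being used at all)

  • nothing_saveable (i.e. rematerialize everything, which is what jax.checkpoint does when no custom policy is given)

  • dots_saveable or its alias checkpoint_dots

  • dots_with_no_batch_dims_saveable or its alias checkpoint_dots_with_no_batch_dims

  • save_anything_but_these_names (save any values except the outputs of checkpoint_name with any of the given names)

  • save_any_names_but_these (save only named values, i.e. any outputs of checkpoint_name, except those with the given names)

  • save_only_these_names (save only named values, and only those with the given names)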

Policies only indicate what is saveable; a value is only saved if it's actually needed by the backward pass.

Advanced: recursive jax.checkpoint

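By applying jax.checkpoint in the right way, many tradeoffs between memory usage and (re)computation can be expressed. One surprising example is recursive checkpointing, where we apply jax.checkpoint to a function which itself calls jax.checkpoint-decorated functions, in such a way that memory usage from the chain composition of $D$ functions scales like $\mathcal{O}(\log_2 D)$ rather than $\mathcal{O}(D)$.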

As a toy example, consider the chain composition of multiple jnp.sin functions:

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

f = chain_compose([jnp.sin] * 8)
print_saved_residuals(f, 3.) 
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) 

In general, the number of stored residuals scales linearly with the length of the chain:

f = chain_compose([jnp.sin] * 16)
print_saved_residuals(f, 3.) 
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f)
f32[] output of cos from <ipython-input-25-46b5594773cb>:4 (f) 

But we can apply jax.checkpoint recursively to improve the scaling:

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x)) 
f = recursive_checkpoint([jnp.sin] * 8)
print_saved_residuals(f, 3.) 
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>) 
f = recursive_checkpoint([jnp.sin] * 16)
print_saved_residuals(f, 3.) 
f32[] from the argument 'x'
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of sin from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>)
f32[] output of cos from <ipython-input-27-86f83c871e81>:6 (<lambda>) 
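The residual counts printed above follow a simple recurrence. As a back-of-the-envelope sketch that counts only the scalar residuals print_saved_residuals lists for this sin chain (not peak memory in general): for $D > 2$, recursive_checkpoint saves the single input of the outer jax.checkpoint-decorated half, plus whatever the other half, itself a recursively checkpointed chain of $D/2$ functions, saves; for $D = 2$ the two un-checkpointed sin applications save two cos values. By contrast, chain_compose saves $R(D) = D$ residuals.

```latex
R(2) = 2, \qquad R(D) = R(D/2) + 1 \quad \text{for } D = 2^k,\ k \ge 2
\quad\Longrightarrow\quad R(D) = \log_2 D + 1 = \mathcal{O}(\log_2 D)
```

This gives $R(8) = 4$ and $R(16) = 5$, matching the two listings above.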

The cost here, as usual, is recomputation: in particular, we end up performing $\mathcal{O}(\log_2 D)$ times as many FLOPs:

f = chain_compose([jnp.sin] * 8)
print_fwd_bwd(f, 3.) 

  forward computation:

  { lambda ; a:f32[]. let
      b:f32[] = sin a
      c:f32[] = cos a
      d:f32[] = sin b
      e:f32[] = cos b
      f:f32[] = sin d
      g:f32[] = cos d
      h:f32[] = sin f
      i:f32[] = cos f
      j:f32[] = sin h
      k:f32[] = cos h
      l:f32[] = sin j
      m:f32[] = cos j
      n:f32[] = sin l
      o:f32[] = cos l
      p:f32[] = sin n
      q:f32[] = cos n
    in (p, q, o, m, k, i, g, e, c) }

  backward computation:

  { lambda ; a:f32[] b:f32[] c:f32[] d:f32[] e:f32[] f:f32[] g:f32[] h:f32[] i:f32[]. let
      j:f32[] = mul i a
      k:f32[] = mul j b
      l:f32[] = mul k c
      m:f32[] = mul l d
      n:f32[] = mul m e
      o:f32[] = mul n f
      p:f32[] = mul o g
      q:f32[] = mul p h
    in (q,) }

f = recursive_checkpoint([jnp.sin] * 8)
print_fwd_bwd(f, 3.) 

  forward computation:

  { lambda ; a:f32[]. let
      b:f32[] = remat2[
        differentiated=False
        jaxpr={ lambda ; c:f32[]. let d:f32[] = sin c; e:f32[] = sin d in (e,) }
        policy=None
        prevent_cse=True
      ] a
      f:f32[] = sin b
      g:f32[] = sin f
      h:f32[] = sin g
      i:f32[] = sin h
      j:f32[] = sin i
      k:f32[] = cos i
      l:f32[] = sin j
      m:f32[] = cos j
    in (l, m, k, g, a) }

  backward computation:

  { lambda ; a:f32[] b:f32[] c:f32[] d:f32[]. let
      e:f32[] = mul d a
      f:f32[] = mul e b
      g:f32[] = remat2[
        differentiated=True
        jaxpr={ lambda ; h:f32[] i:f32[]. let
            j:f32[] = sin h
            k:f32[] = cos h
            l:f32[] = cos j
            m:f32[] = mul i l
            n:f32[] = mul m k
          in (n,) }
        policy=None
        prevent_cse=True
      ] c f
      o:f32[] = remat2[
        differentiated=True
        jaxpr={ lambda ; p:f32[] q:f32[]. let
            r:f32[] = sin p
            s:f32[] = sin r
            t:f32[] = sin s
            u:f32[] = cos s
            v:f32[] = cos t
            w:f32[] = mul q v
            x:f32[] = mul w u
            y:f32[] = remat2[
              differentiated=True
              jaxpr={ lambda ; z:f32[] ba:f32[]. let
                  bb:f32[] = sin z
                  bc:f32[] = cos z
                  bd:f32[] = cos bb
                  be:f32[] = mul ba bd
                  bf:f32[] = mul be bc
                in (bf,) }
              policy=None
              prevent_cse=True
            ] p x
          in (y,) }
        policy=None
        prevent_cse=True
      ] 3.0 g
    in (o,) }
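Despite the restructuring, the recursively checkpointed chain computes the same derivative as the plain chain. This sketch repeats chain_compose and recursive_checkpoint from above and compares their gradients for a 16-step sin chain:

```python
import jax
import jax.numpy as jnp

def chain_compose(funs):
  def f(x):
    for fun in funs:
      x = fun(x)
    return x
  return f

def recursive_checkpoint(funs):
  if len(funs) == 1:
    return funs[0]
  elif len(funs) == 2:
    f1, f2 = funs
    return lambda x: f1(f2(x))
  else:
    f1 = recursive_checkpoint(funs[:len(funs)//2])
    f2 = recursive_checkpoint(funs[len(funs)//2:])
    return lambda x: f1(jax.checkpoint(f2)(x))

funs = [jnp.sin] * 16
x = 3.0
g_plain = jax.grad(chain_compose(funs))(x)
g_recursive = jax.grad(recursive_checkpoint(funs))(x)
print(bool(jnp.allclose(g_plain, g_recursive)))
```

Note that recursive_checkpoint composes as funs[0](funs[1](...)), the reverse of chain_compose's application order; the two agree here only because every stage is jnp.sin.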

Practical notes

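When differentiated functions are staged out to XLA for compilation, for example by applying jax.jit to a function which contains a jax.grad call, XLA will automatically optimize the computation, including decisions about when to compute or rematerialize values. As a result, jax.checkpoint often isn't needed for differentiated functions under jax.jit. XLA will optimize things for you.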

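One exception is when using staged-out control flow, like jax.lax.scan. Automatic compiler optimizations across multiple control flow primitives, e.g. across a forward-pass scan and the corresponding backward-pass scan, typically aren't as thorough. As a result, it's often a good idea to use jax.checkpoint on the body function passed to jax.lax.scan.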

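For example, one common pattern in large Transformer models is to express the architecture as a jax.lax.scan over layers so as to reduce compilation times. That is, using a simple fully-connected network as an analogy, instead of writing something like this: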

LayerParam = tuple[jnp.ndarray, jnp.ndarray]  # weights, bias pair for a layer
ParamsList = list[LayerParam]

def net(params: ParamsList, x: jnp.ndarray):
  for W, b in params:
    x = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return x 

we would instead iterate over the layer application with jax.lax.scan:

StackedWeights = jnp.ndarray  # all weight matrices stacked together
StackedBiases = jnp.ndarray   # all bias vectors stacked together

all_weights = jnp.stack([W for W, _ in params])
all_biases = jnp.stack([b for _, b in params])

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

def net(all_weights, all_biases, x):
  x, _ = jax.lax.scan(layer, x, (all_weights, all_biases))
  return x 

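This scan-over-layers version reduces compile times, but by foiling some compiler optimizations it can lead to inefficient computation of gradients. To mitigate the issue, we can use jax.checkpoint on the scanned function: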

from functools import partial

@partial(jax.checkpoint,
         policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)
def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None 

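By using jax.checkpoint this way, we're manually controlling which values JAX's autodiff saves between the forward and backward passes, and therefore not relying on XLA optimizations to choose for us.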
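To try the scan-over-layers pattern end to end, here is a self-contained sketch. The layer count, width, random weights, and the make_loss helper are all hypothetical choices for illustration; it compiles the gradient with and without the checkpointed body and checks that the two agree (the memory difference itself is not measured here):

```python
from functools import partial

import jax
import jax.numpy as jnp

num_layers, width = 3, 4
keys = jax.random.split(jax.random.PRNGKey(0), num_layers)
# Hypothetical parameters, stacked along a leading layer axis.
all_weights = jnp.stack([jax.random.normal(k, (width, width)) for k in keys])
all_biases = jnp.zeros((num_layers, width))

def layer(x, W_b_pair):
  W, b = W_b_pair
  out = jnp.maximum(jnp.dot(x, W) + b, 0.)
  return out, None

# Same body; autodiff may only save dot results that have no batch dimensions.
layer_remat = partial(jax.checkpoint,
                      policy=jax.checkpoint_policies.dots_with_no_batch_dims_saveable)(layer)

def make_loss(body):
  # Hypothetical helper: scan `body` over the layers and reduce to a scalar.
  def loss(all_weights, all_biases, x):
    x, _ = jax.lax.scan(body, x, (all_weights, all_biases))
    return jnp.sum(x)
  return loss

x = jnp.ones(width)
grads_plain = jax.jit(jax.grad(make_loss(layer), argnums=(0, 1)))(
    all_weights, all_biases, x)
grads_remat = jax.jit(jax.grad(make_loss(layer_remat), argnums=(0, 1)))(
    all_weights, all_biases, x)
print([bool(jnp.allclose(a, b)) for a, b in zip(grads_plain, grads_remat)])
```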

How JAX primitives work

Original: jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html


necula@google.com, October 2019.

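JAX implements certain transformations of Python functions, e.g., jit, grad, vmap, or pmap. The Python functions to be transformed must be JAX-traceable, which means that as the Python function executes, the only operations it applies to the data are either inspections of data attributes such as shape or type, or special operations called JAX primitives. In particular, a JAX-traceable function is sometimes invoked by JAX with abstract arguments. An example of a JAX abstract value is ShapedArray(float32[2,2]), which captures the type and the shape of values, but not the concrete data values. JAX primitives know how to operate on both concrete data values and JAX abstract values.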
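One way to observe tracing with abstract values is jax.eval_shape, which calls a function on jax.ShapeDtypeStruct arguments (shape and dtype only, playing a role similar to the ShapedArray above) and returns the abstract result; jax.make_jaxpr shows the primitives recorded during tracing. A minimal sketch, where f is an arbitrary illustrative function:

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) @ x.T

# An abstract argument: shape and dtype, but no concrete data.
abstract_x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
out = jax.eval_shape(f, abstract_x)
print(out.shape, out.dtype)  # (2, 2) float32

# make_jaxpr traces f and lists the JAX primitives it applies.
print(jax.make_jaxpr(f)(jnp.ones((2, 2), jnp.float32)))
```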

The JAX-transformed functions must themselves be JAX-traceable functions, to ensure that these transformations can be composed, e.g., jit(jacfwd(grad(f))).
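As a concrete instance of such a composition, this sketch builds jit(jacfwd(grad(f))) for a hypothetical f(x) = sum(x**3). Since grad(f)(x) = 3x**2, its Jacobian is the diagonal matrix diag(6x):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sum(x ** 3)

# grad(f)(x) = 3 * x**2, so jacfwd(grad(f))(x) = diag(6 * x).
hessian_fn = jax.jit(jax.jacfwd(jax.grad(f)))

x = jnp.array([1.0, 2.0, 3.0])
H = hessian_fn(x)
print(H)
```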

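There are pre-defined JAX primitives corresponding to most XLA operations, e.g., add, matmul, sin, cos, and indexing. JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs using JAX's implementation of numpy are JAX-traceable and therefore transformable. Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.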

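The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, one can define a new primitive that encapsulates the behavior of the function.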

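The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.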

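Suppose we want to add to JAX support for a multiply-add function with three arguments, defined mathematically as "multiply_add(x, y, z) = x * y + z". This function operates on three identically-shaped tensors of floating-point values and performs the operations pointwise.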

Using existing primitives

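The easiest way to define a new function is to write it in terms of JAX primitives, or in terms of other functions that are themselves written using JAX primitives, e.g., those defined in the jax.lax module: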

from jax import lax
from jax._src import api

def multiply_add_lax(x, y, z):
  """Implementation of multiply-add using the jax.lax primitives."""
  return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
  """A square-add function using the newly defined multiply-add."""
  return multiply_add_lax(a, a, b)

print("square_add_lax = ", square_add_lax(2., 10.))
# Differentiate w.r.t. the first argument
print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) 
square_add_lax =  14.0
grad(square_add_lax) =  4.0 

To understand how JAX uses the primitives internally, we add some helper functions for tracing function calls.

#@title Helper functions (execute this cell)
import functools
import traceback

_indentation = 0
def _trace(msg=None):
    """Print a message at current indentation."""
    if msg is not None:
        print("  " * _indentation + msg)

def _trace_indent(msg=None):
    """Print a message and then indent the rest."""
    global _indentation
    _trace(msg)
    _indentation = 1 + _indentation

def _trace_unindent(msg=None):
    """Unindent then print a message."""
    global _indentation
    _indentation = _indentation - 1
    _trace(msg)

def trace(name):
  """A decorator for functions to trace arguments and results."""

  def trace_func(func):  # pylint: disable=missing-docstring
    def pp(v):
        """Print certain values more succinctly"""
        vtype = str(type(v))
        if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
            return "<JaxComputationBuilder>"
        elif "jaxlib.xla_extension.XlaOp" in vtype:
            return "<XlaOp at 0x{:x}>".format(id(v))
        elif ("partial_eval.JaxprTracer" in vtype or
              "batching.BatchTracer" in vtype or
              "ad.JVPTracer" in vtype):
            return "Traced<{}>".format(v.aval)
        elif isinstance(v, tuple):
            return "({})".format(pp_values(v))
        else:
            return str(v)
    def pp_values(args):
        return ", ".join([pp(arg) for arg in args])

    @functools.wraps(func)
    def func_wrapper(*args):
      _trace_indent("call {}({})".format(name, pp_values(args)))
      res = func(*args)
      _trace_unindent("|<- {} = {}".format(name, pp(res)))
      return res

    return func_wrapper

  return trace_func

class expectNotImplementedError(object):
  """Context manager to check for NotImplementedError."""
  def __enter__(self): pass
  def __exit__(self, type, value, tb):
    global _indentation
    _indentation = 0
    if type is NotImplementedError:
      print("\nFound expected exception:")
      traceback.print_exc(limit=3)
      return True
    elif type is None:  # No exception
      assert False, "Expected NotImplementedError"
    else:
      return False 

Instead of using jax.lax primitives directly, we can use other functions that are already written in terms of those primitives, such as those in jax.numpy:

import jax.numpy as jnp
import numpy as np

@trace("multiply_add_numpy")
def multiply_add_numpy(x, y, z):
    return jnp.add(jnp.multiply(x, y), z)

@trace("square_add_numpy")
def square_add_numpy(a, b):
    return multiply_add_numpy(a, a, b)

print("\nNormal evaluation:")  
print("square_add_numpy = ", square_add_numpy(2., 10.))
print("\nGradient evaluation:")
print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) 
Normal evaluation:
call square_add_numpy(2.0, 10.0)
  call multiply_add_numpy(2.0, 2.0, 10.0)
  |<- multiply_add_numpy = 14.0
|<- square_add_numpy = 14.0
square_add_numpy =  14.0

Gradient evaluation:
call square_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_numpy(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  |<- multiply_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
|<- square_add_numpy = Traced<ConcreteArray(14.0, dtype=float32, weak_type=True)>
grad(square_add_numpy) =  4.0 

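Notice that while computing grad, JAX invokes square_add_numpy and multiply_add_numpy with special arguments ConcreteArray(...) (described further below). It is important to remember that a JAX-traceable function must be able to operate not only on concrete arguments but also on the special abstract arguments that JAX may use to abstract the function execution.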

The JAX traceability property is satisfied as long as the function is written in terms of JAX primitives.
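The contrast is easy to reproduce. The same square-add written with original NumPy calls cannot handle the tracer arguments that jax.grad passes in. This is a minimal sketch. The exact exception type may vary between JAX and NumPy versions, so the example catches Exception and reports the exception's name.

```python
import jax
import jax.numpy as jnp
import numpy as np

def square_add_jnp(a, b):
    return jnp.add(jnp.multiply(a, a), b)  # built from JAX primitives

def square_add_np(a, b):
    return np.add(np.multiply(a, a), b)    # original NumPy needs concrete data

print(jax.grad(square_add_jnp)(2., 10.))   # 4.0

try:
    jax.grad(square_add_np)(2., 10.)
    error_name = None
except Exception as err:  # typically jax.errors.TracerArrayConversionError
    error_name = type(err).__name__
print(error_name)
```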

Defining new JAX primitives

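The right way to add support for multiply-add is in terms of existing JAX primitives, as shown above. However, to demonstrate how JAX primitives work, let us pretend that we want to add a new primitive to JAX for the multiply-add functionality.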

from jax import core
multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

@trace("multiply_add_prim")
def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

 Note that the traced arguments must be passed as positional arguments
 to `bind`. 
 """
  return multiply_add_p.bind(x, y, z)

@trace("square_add_prim")
def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b) 

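If we try to call the newly defined functions, we get an error, because we have not yet told JAX anything about the semantics of the new primitive.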

with expectNotImplementedError():
  square_add_prim(2., 10.) 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)

Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2844449444.py", line 2, in <module>
    square_add_prim(2., 10.)
  File "/tmp/ipykernel_1319/1393342955.py", line 48, in func_wrapper
    res = func(*args)
  File "/tmp/ipykernel_1319/1308506715.py", line 16, in square_add_prim
    return multiply_add_prim(a, a, b)
NotImplementedError: Evaluation rule for 'multiply_add' not implemented 

Primal evaluation rules

@trace("multiply_add_impl")
def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

 This function does not need to be JAX traceable.
 Args:
 x, y, z: the concrete arguments of the primitive. Will only be called with 
 concrete values.
 Returns:
 the concrete result of the primitive.
 """
  # Note that we can use the original numpy, which is not JAX traceable
  return np.add(np.multiply(x, y), z)

# Now we register the primal implementation with JAX
multiply_add_p.def_impl(multiply_add_impl) 
<function __main__.multiply_add_impl(x, y, z)> 
assert square_add_prim(2., 10.) == 14. 
call square_add_prim(2.0, 10.0)
  call multiply_add_prim(2.0, 2.0, 10.0)
    call multiply_add_impl(2.0, 2.0, 10.0)
    |<- multiply_add_impl = 14.0
  |<- multiply_add_prim = 14.0
|<- square_add_prim = 14.0 

JIT

If we now try to use jit, we get a NotImplementedError:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)

Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: Abstract evaluation for 'multiply_add' not implemented 

Abstract evaluation rules

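To JIT the function, and for other transformations as well, JAX first evaluates it abstractly, using only the shape and type of the arguments. This abstract evaluation serves multiple purposes:

  • It obtains the sequence of JAX primitives used in the computation. This sequence will be compiled.

  • It computes the shapes and types of all vectors and operations used in the computation.

For example, the abstraction of a vector with 3 elements may be ShapedArray(float32[3]) or ConcreteArray([1., 2., 3.]). In the latter case, JAX uses the actual concrete value wrapped as an abstract value.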
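The following compact sketch defines a primitive that has only the two rules discussed so far: a concrete implementation and an abstract evaluation rule. jax.eval_shape exercises only the abstract rule, while an eager call uses only the concrete one. The primitive name demo_multiply_add is hypothetical. The import fallback for Primitive is an assumption intended to cover both newer (jax.extend.core) and older (jax.core) JAX releases. No lowering rule is registered, so jax.jit would still fail on this primitive.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.core import ShapedArray

try:
    from jax.extend.core import Primitive  # newer JAX releases (assumed location)
except ImportError:
    from jax.core import Primitive         # older JAX releases

demo_ma_p = Primitive("demo_multiply_add")  # hypothetical name for this sketch

def demo_ma_impl(x, y, z):
    # Concrete rule: receives real data, so original NumPy is fine here.
    return np.add(np.multiply(x, y), z)

def demo_ma_abstract_eval(xs, ys, zs):
    # Abstract rule: receives avals; only shape and dtype are available.
    assert xs.shape == ys.shape == zs.shape
    return ShapedArray(xs.shape, xs.dtype)

demo_ma_p.def_impl(demo_ma_impl)
demo_ma_p.def_abstract_eval(demo_ma_abstract_eval)

def demo_multiply_add(x, y, z):
    # Traced arguments are passed positionally to bind.
    return demo_ma_p.bind(x, y, z)

spec = jax.ShapeDtypeStruct((3,), jnp.float32)
out_spec = jax.eval_shape(demo_multiply_add, spec, spec, spec)
print(out_spec.shape, out_spec.dtype)  # (3,) float32
print(demo_multiply_add(2., 2., 10.))  # 14.0
```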

from jax import core
@trace("multiply_add_abstract_eval")
def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

 This function does not need to be JAX traceable. It will be invoked with
 abstractions of the actual arguments. 
 Args:
 xs, ys, zs: abstractions of the arguments.
 Result:
 a ShapedArray for the result of the primitive.
 """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape
  return core.ShapedArray(xs.shape, xs.dtype)

# Now we register the abstract evaluation with JAX
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) 
<function __main__.multiply_add_abstract_eval(xs, ys, zs)> 

If we retry the JIT compilation, we can see the abstract evaluation proceed, but we get another error, this time about the missing XLA compilation rule:

with expectNotImplementedError():
  api.jit(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>

Found expected exception: 
Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1319/1813425700.py", line 2, in <module>
    api.jit(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py", line 326, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
NotImplementedError: MLIR translation rule for primitive 'multiply_add' not found for platform cpu 

XLA compilation rules

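JAX compilation works by compiling each primitive into a graph of XLA operations.

This is the biggest hurdle to adding new functionality to JAX, because the set of XLA operations is limited, and JAX already has pre-defined primitives for most of them. However, XLA includes a CustomCall operation that can be used to encapsulate arbitrary functionality defined in C++.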
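To see the XLA-level graph that existing primitives lower to, you can lower a jitted function without running it and print the result. This sketch assumes a JAX version whose Lowered.as_text() emits StableHLO, as in the traces shown in this section. In that output, lax.mul and lax.add should appear as stablehlo.multiply and stablehlo.add operations.

```python
import jax
import jax.numpy as jnp
from jax import lax

def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

x = jnp.ones(3, dtype=jnp.float32)
# Lower (without compiling or executing) and print the module text.
hlo_text = jax.jit(multiply_add_lax).lower(x, x, x).as_text()
print(hlo_text)
```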

from jax._src.lib.mlir.dialects import hlo
@trace("multiply_add_lowering")
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

 Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
 the results of the function.

 Does not need to be a JAX-traceable function.
 """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now we register the lowering rule with JAX
# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
# TODO: TPU?
from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') 
<function __main__.multiply_add_lowering(ctx, xc, yc, zc)> 

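Now JIT compilation succeeds. Notice below that JAX first evaluates the function abstractly, which triggers the multiply_add_abstract_eval function, and then compiles the set of primitives it has encountered, including multiply_add. At that point JAX invokes the lowering rule registered above, multiply_add_lowering.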

assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc664db0>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc688cf0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc688d70>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc682b30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8fd060>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a45a0d0>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 
0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d4ea0, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object <module> at 0x7f0afd8d6b80, file "/tmp/ipykernel_1319/1570919344.py", line 1>, 16): loc("<module>"("/tmp/ipykernel_1319/1570919344.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/1570919344.py': '/tmp/ipykernel_1319/1570919344.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/1570919344.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afd8fe8f0>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6835b0>] 

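Below is another use of jit, where we compile only with respect to the first argument. Notice that the second argument to square_add_prim is concrete (static_argnums=1). In the trace below, multiply_add_abstract_eval still receives a ShapedArray for its third argument, and the constant 10.0 appears in the lowering as a stablehlo.constant. Because multiply_add_abstract_eval reads only .shape and .dtype, it works with both ShapedArray and ConcreteArray arguments.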
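A static argument is baked into the compiled program as a constant, so jit must retrace for each new static value. The following sketch counts traces with a Python-level counter. square_add and trace_count are illustrative names. The counter increments only while JAX is tracing, not when cached compiled code runs.

```python
import jax

trace_count = 0

def square_add(a, b):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing
    return a * a + b

f = jax.jit(square_add, static_argnums=1)
r1 = f(2., 10.)  # traces: b=10.0 is embedded as a constant
r2 = f(3., 10.)  # cache hit: same static value, same abstract type for a
r3 = f(2., 11.)  # new static value, so JAX traces again
print(float(r1), float(r2), float(r3), trace_count)  # 14.0 19.0 15.0 2
```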

assert api.jit(lambda x, y: square_add_prim(x, y), 
               static_argnums=1)(2., 10.) == 14. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, 10.0)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- square_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc666480>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc690530>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc6905b0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc690570>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8ffc40>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a58e100>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at "_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0)))))))))))}, location_cache={(<code object multiply_add_prim at 
0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <lambda> at 0x7f0afd8d5b00, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 6): loc("<lambda>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object <module> at 0x7f0b2cd8b3c0, file "/tmp/ipykernel_1319/4165789807.py", line 1>, 20): loc("<module>"("/tmp/ipykernel_1319/4165789807.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object run_ast_nodes at 0x7f0b3686e3f0, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3418>, 500): loc("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0)), (<code object run_cell_async at 0x7f0b3686e080, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3183>, 828): loc("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0)), (<code object _pseudo_sync_runner at 0x7f0b36740c90, file 
"/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py", line 120>, 8): loc("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/4165789807.py': '/tmp/ipykernel_1319/4165789807.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/4165789807.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69c250>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+01> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dcff0>] 

Forward differentiation

JAX implements forward differentiation in the form of a Jacobian-vector product (see the JAX autodiff cookbook).
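For f(x, y, z) = x * y + z, the Jacobian-vector product is y·xt + x·yt + zt, where xt, yt, and zt are the input tangents. The sketch below checks this with jax.jvp on the lax-based square_add_lax from earlier, redefined so that the block runs on its own. It reproduces the (14., 5.) pair that the primitive-based version is expected to return.

```python
import jax
from jax import lax

def multiply_add_lax(x, y, z):
    return lax.add(lax.mul(x, y), z)

def square_add_lax(a, b):
    return multiply_add_lax(a, a, b)

# jvp(f, primals, tangents) returns (f(primals), J(primals) @ tangents).
primal_out, tangent_out = jax.jvp(square_add_lax, (2., 10.), (1., 1.))
# By hand, with x = y = a = 2 and all tangents 1: y*xt + x*yt + zt = 2 + 2 + 1.
print(float(primal_out), float(tangent_out))  # 14.0 5.0
```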

If we now try to compute a jvp, we get an error, because we have not yet told JAX how to differentiate the multiply_add primitive.

# The second argument `(2., 10.)` are the argument values
# where we evaluate the Jacobian, and the third `(1., 1.)`
# are the values of the tangents for the arguments.
with expectNotImplementedError():
  api.jvp(square_add_prim, (2., 10.), (1., 1.)) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)

Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/800067577.py", line 5, in <module>
    api.jvp(square_add_prim, (2., 10.), (1., 1.))
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1901, in jvp
    return _jvp(lu.wrap_init(fun), primals, tangents, has_aux=has_aux)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1930, in _jvp
    out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)
NotImplementedError: Differentiation rule for 'multiply_add' not implemented 
from jax.interpreters import ad

@trace("multiply_add_value_and_jvp")
def multiply_add_value_and_jvp(arg_values, arg_tangents):
  """Evaluates the primal output and the tangents (Jacobian-vector product).

 Given values of the arguments and perturbation of the arguments (tangents), 
 compute the output of the primitive and the perturbation of the output.

 This method must be JAX-traceable. JAX may invoke it with abstract values 
 for the arguments and tangents.

 Args:
 arg_values: a tuple of arguments
 arg_tangents: a tuple with the tangents of the arguments. The tuple has 
 the same length as the arg_values. Some of the tangents may also be the 
 special value ad.Zero to specify a zero tangent.
 Returns:
 a pair of the primal output and the tangent.
 """
  x, y, z = arg_values
  xt, yt, zt = arg_tangents
  _trace("Primal evaluation:")
  # Now we have a JAX-traceable computation of the output. 
  # Normally, we can use the multiply_add primitive itself to compute the primal output.
  primal_out = multiply_add_prim(x, y, z)

  _trace("Tangent evaluation:")
  # We must use a JAX-traceable way to compute the tangent. It turns out that 
  # the output tangent can be computed as (xt * y + x * yt + zt),
  # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.

  # We do need to deal specially with Zero. Here we just turn it into a 
  # proper tensor of 0s (of the same shape as 'x'). 
  # An alternative would be to check for Zero and perform algebraic 
  # simplification of the output tangent computation.
  def make_zero(tan):
    return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan  

  output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
  return (primal_out, output_tangent)

# Register the forward differentiation rule with JAX 
ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp 
# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(10.0, dtype=float32, weak_type=True)>)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, 1.0, 1.0)
        call multiply_add_impl(2.0, 1.0, 1.0)
        |<- multiply_add_impl = 3.0
      |<- multiply_add_prim = 3.0
      call multiply_add_prim(1.0, 2.0, 3.0)
        call multiply_add_impl(1.0, 2.0, 3.0)
        |<- multiply_add_impl = 5.0
      |<- multiply_add_prim = 5.0
    |<- multiply_add_value_and_jvp = (14.0, 5.0)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)> 
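This is a quick cross-check of the tangent rule xt*y + x*yt + zt used above, and it does not depend on the custom primitive. It is a minimal sketch that assumes a hypothetical pure-JAX stand-in `multiply_add(x, y, z) = x * y + z`. `jax.jvp` differentiates the stand-in, and the result is compared with the same formula written by hand.

```python
import jax


def multiply_add(x, y, z):
  # Hypothetical pure-JAX stand-in for multiply_add_prim: x * y + z.
  return x * y + z


def manual_jvp(primals, tangents):
  # Hand-written JVP using the rule from multiply_add_value_and_jvp:
  # tangent_out = xt * y + x * yt + zt
  x, y, z = primals
  xt, yt, zt = tangents
  return multiply_add(x, y, z), xt * y + x * yt + zt


primals = (2., 2., 10.)   # square_add_prim(2., 10.) calls multiply_add(2., 2., 10.)
tangents = (1., 1., 1.)
auto_out, auto_tan = jax.jvp(multiply_add, primals, tangents)
man_out, man_tan = manual_jvp(primals, tangents)
print(auto_out, auto_tan)  # expected values: 14.0 and 5.0
print(man_out, man_tan)
```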

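To explain:

  • Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.

  • Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet we do not call multiply_add_abstract_eval.

  • I think it would be useful to show the jaxpr here.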
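On the last point, the sketch below shows how one could print the jaxpr of a forward-mode computation with `jax.make_jaxpr`. It uses a hypothetical pure-JAX stand-in `square_add(a, b) = a * a + b` instead of `square_add_prim`, so the jaxpr shows `mul`/`add` equations. With the custom primitive registered, one would presumably see `multiply_add` equations in their place.

```python
import jax


def square_add(a, b):
  # Hypothetical pure-JAX stand-in for square_add_prim: a * a + b.
  return a * a + b


def jvp_of_square_add(a, b, at, bt):
  # Forward-mode AD: returns (primal output, output tangent).
  return jax.jvp(square_add, (a, b), (at, bt))


# make_jaxpr traces with abstract values and records the primitives used.
closed_jaxpr = jax.make_jaxpr(jvp_of_square_add)(2., 10., 1., 1.)
print(closed_jaxpr)
```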

JIT of forward differentiation

We can apply JIT to the forward differentiation function:

assert api.jit(lambda arg_values, arg_tangents: 
                   api.jvp(square_add_prim, arg_values, arg_tangents))(
         (2., 10.), (1., 1.)) == (14., 5.) 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc68ae70>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 3))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8dc2b0>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(...), name_stack=NameStack(stack=(Scope(name='jit(<lambda>)'), Scope(name='jit(main)'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])], avals_out=[ShapedArray(float32[])], ...), Value(<block argument> of type 'tensor<f32>' at index: 2), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%3 = "stablehlo.add"(%2, %arg3) : (tensor<f32>, tensor<f32>) -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8b3270>]

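Notice that we first evaluate multiply_add_value_and_jvp abstractly. It in turn abstractly evaluates both the primal and the tangent computation, for a total of 3 invocations of the ma primitive. Then the 3 occurrences of the primitive are compiled, which is why multiply_add_lowering is called three times.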
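To see what is actually compiled, one can lower the jitted function and print the resulting StableHLO. The sketch below does this with `jax.jit(...).lower(...).as_text()` on a hypothetical pure-JAX stand-in for `square_add_prim`. As a result, it shows `stablehlo.multiply`/`stablehlo.add` ops rather than calls to `multiply_add_lowering`. The exact text may vary between JAX versions.

```python
import jax


def square_add(a, b):
  # Hypothetical pure-JAX stand-in for square_add_prim: a * a + b.
  return a * a + b


def fwd(arg_values, arg_tangents):
  # Same structure as the jitted lambda above: a JVP under jit.
  return jax.jvp(square_add, arg_values, arg_tangents)


# lower() traces with abstract values and emits a StableHLO module;
# compile() then builds the executable from it.
lowered = jax.jit(fwd).lower((2., 10.), (1., 1.))
hlo_text = lowered.as_text()
print(hlo_text)

compiled = lowered.compile()
print(compiled((2., 10.), (1., 1.)))  # (primal output, output tangent)
```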

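Reverse differentiation

If we now try to use reverse differentiation, JAX starts by using multiply_add_value_and_jvp to compute the forward differentiation for abstract values. It then runs into a NotImplementedError.

When computing the reverse differentiation, JAX first abstractly evaluates the forward differentiation code multiply_add_value_and_jvp. This yields a trace of the primitives that compute the output tangent. JAX performs this abstract evaluation with concrete values for the differentiation point and abstract values for the tangents. JAX also uses the special abstract tangent value Zero for the tangent of the third argument of multiply_add_prim. This reflects the fact that we do not differentiate with respect to the second argument of square_add_prim, which flows into the third argument of multiply_add_prim.

During the abstract evaluation of the tangent, we also pass the value 0.0 as the tangent for the third argument. This comes from the make_zero function in the definition of multiply_add_value_and_jvp.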
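`jax.linearize` exposes the same split directly: the primal is computed on concrete values, and the tangent computation is linear. This is a minimal sketch that again uses a hypothetical pure-JAX stand-in for `square_add_prim`. Passing a zero tangent for the second argument corresponds to differentiating only with respect to the first argument.

```python
import jax


def square_add(a, b):
  # Hypothetical pure-JAX stand-in for square_add_prim: a * a + b.
  return a * a + b


# Linearize at the point (2., 10.): the primal is evaluated once, and
# tangent_fn evaluates only the (linear) tangent computation.
primal_out, tangent_fn = jax.linearize(square_add, 2., 10.)

# Tangent along the first argument only (second tangent is zero).
t_a = tangent_fn(1., 0.)
# Tangent along the second argument only.
t_b = tangent_fn(0., 1.)
print(primal_out, t_a, t_b)
```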

# This is reverse differentiation w.r.t. the first argument of square_add_prim
with expectNotImplementedError():
  api.grad(square_add_prim)(2., 10.) 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>

Found expected exception: 
Traceback (most recent call last):
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py", line 284, in get_primitive_transpose
    return primitive_transposes[p]
KeyError: multiply_add

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/docs/.asdf/installs/python/3.10.14/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
jax._src.source_info_util.JaxStackTraceBeforeTransformation: NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/tmp/ipykernel_1319/339076514.py", line 3, in <module>
    api.grad(square_add_prim)(2., 10.)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 621, in grad_f
    _, g = value_and_grad_f(*args, **kwargs)
NotImplementedError: Transpose rule (for reverse-mode differentiation) for 'multiply_add' not implemented 

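The above error occurs because JAX is missing one piece it needs to compute reverse differentiation from the forward differentiation code: a transposition rule for multiply_add.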

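Transposition

As explained above, when computing reverse differentiation, JAX uses forward differentiation to obtain a trace of the primitives that compute the tangent. JAX then interprets this trace abstractly in reverse, applying a transposition rule to each primitive.

To understand what is going on, consider a simpler example for now: the function "f(x, y) = x * y + y". Suppose we need to differentiate at the point (2., 4.). From the input tangents xt and yt, JAX will produce the following JVP calculation of the output tangent ft: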

 a = xt * 4.
 b = 2. * yt
 c = a + b
 ft = c + yt 
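The tangent computation above can be checked directly against `jax.jvp`. This sketch writes the four lines out as a Python function, where `tangent_by_hand` is an illustrative name and not a JAX API. It then compares that function with forward-mode AD of `f` at (2., 4.) for a few input tangents.

```python
import jax


def f(x, y):
  return x * y + y


def tangent_by_hand(xt, yt):
  # The JVP tangent computation of f at (x, y) = (2., 4.), written by hand.
  a = xt * 4.
  b = 2. * yt
  c = a + b
  ft = c + yt
  return ft


for xt, yt in [(1., 0.), (0., 1.), (0.5, -2.)]:
  _, ft_auto = jax.jvp(f, (2., 4.), (xt, yt))
  print(xt, yt, float(ft_auto), tangent_by_hand(xt, yt))
```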

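By construction, the tangent calculation is always linear in the input tangents. The only non-linear operator that can appear in the tangent calculation is multiplication, and when it does, one of its operands is a constant.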
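Linearity can be checked numerically. For the tangent map T of f at (2., 4.), T(alpha*u + beta*v) should equal alpha*T(u) + beta*T(v). This minimal sketch uses `jax.linearize` to obtain T. The vectors u, v and scalars alpha, beta are arbitrary illustrative values, chosen to be exactly representable in float32 so the comparison is exact.

```python
import jax


def f(x, y):
  return x * y + y


# Tangent map of f at the point (2., 4.).
_, f_tangent = jax.linearize(f, 2., 4.)

u = (1., 3.)          # illustrative input tangents
v = (-2., 0.5)
alpha, beta = 2., -3.

lhs = f_tangent(alpha * u[0] + beta * v[0], alpha * u[1] + beta * v[1])
rhs = alpha * f_tangent(*u) + beta * f_tangent(*v)
print(float(lhs), float(rhs))  # linearity: both sides agree
```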

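JAX produces the reverse differentiation computation by processing the JVP computation backwards. For each operation in the tangent computation, it uses the cotangent of the operation's result to accumulate the cotangents of the variables the operation uses: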

  # Initialize cotangents of inputs and intermediate vars
  xct = yct = act = bct = cct = 0.
  # Initialize cotangent of the output
  fct = 1.
  # Process "ft = c + yt"
  cct += fct
  yct += fct
  # Process "c = a + b"
  act += cct
  bct += cct
  # Process "b = 2. * yt"
  yct += 2. * bct
  # Process "a = xt * 4."
  xct += act * 4. 

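One can verify that this computation produces xct = 4. and yct = 3. These are the partial derivatives of the function f at the point (2., 4.): ∂f/∂x = y = 4. and ∂f/∂y = x + 1 = 3.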
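The backward pass above can be run as ordinary Python and compared with JAX's own reverse mode. In this sketch, `backward_by_hand` is a hypothetical helper that transcribes the accumulation steps directly. `jax.vjp` supplies the reference cotangents for the output cotangent 1.0.

```python
import jax


def f(x, y):
  return x * y + y


def backward_by_hand(fct=1.):
  # Walk the tangent computation of f at (2., 4.) backwards,
  # accumulating cotangents of inputs and intermediates.
  xct = yct = act = bct = cct = 0.
  # ft = c + yt
  cct += fct
  yct += fct
  # c = a + b
  act += cct
  bct += cct
  # b = 2. * yt
  yct += 2. * bct
  # a = xt * 4.
  xct += act * 4.
  return xct, yct


_, f_vjp = jax.vjp(f, 2., 4.)
xct_auto, yct_auto = f_vjp(1.)
print(backward_by_hand(), (float(xct_auto), float(yct_auto)))
```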

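For each primitive that may appear in a JVP calculation, JAX knows how to transpose it. Conceptually, suppose the primitive p(x, y, z) is linear in the arguments y and z for a constant value of x, e.g., p(x, y, z) = y*cy + z*cz. Then the transposition of the primitive is: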

p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) 

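Notice that p_transpose takes the cotangent of the primitive's output, plus a value for each argument of the primitive. For the linear arguments, the transposition receives an undefined value _. For the other arguments, it receives the actual constants. The transposition returns a cotangent value for each argument of the primitive, with None returned for the constant arguments.

In particular,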

 add_transpose(out_ct, _, _) = (out_ct, out_ct)
 mult_transpose(out_ct, x, _) = (None, x * out_ct)
 mult_transpose(out_ct, _, y) = (out_ct * y, None) 
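JAX exposes transposition of linear functions through `jax.linear_transpose`, which makes the rules above easy to observe. This minimal sketch assumes illustrative constants cy = 3. and cz = 5., and treats the constant x = 2. of the multiplication as closed over. The example inputs passed to `jax.linear_transpose` only fix shapes and dtypes.

```python
import jax
import jax.numpy as jnp

cy, cz = 3., 5.  # illustrative constants


def p(y, z):
  # Linear in y and z (the constant x is omitted): p = y*cy + z*cz.
  return y * cy + z * cz


def add(a, b):
  return a + b


one = jnp.float32(1.0)  # example input; only its shape/dtype matter

p_t = jax.linear_transpose(p, one, one)
add_t = jax.linear_transpose(add, one, one)
mult_by_x_t = jax.linear_transpose(lambda y: 2. * y, one)  # x = 2. constant

out_ct = jnp.float32(2.0)
print(p_t(out_ct))          # (out_ct*cy, out_ct*cz)
print(add_t(out_ct))        # (out_ct, out_ct)
print(mult_by_x_t(out_ct))  # (x*out_ct,)
```

Note that the constant argument x does not appear as an input here. `jax.linear_transpose` therefore returns cotangents only for the linear inputs, rather than a None entry.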
@trace("multiply_add_transpose")
def multiply_add_transpose(ct, x, y, z):
  """Evaluates the transpose of a linear primitive.

 This method is only used when computing the backward gradient following 
 value_and_jvp, and is only needed for primitives that are used in the JVP 
 calculation for some other primitive. We need transposition for multiply_add_prim, 
 because we have used multiply_add_prim in the computation of the output_tangent in 
 multiply_add_value_and_jvp.

 In our case, multiply_add is not a linear primitive. However, it is used linearly 
 w.r.t. tangents in multiply_add_value_and_jvp:
 output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

 Always one of the first two multiplicative arguments is a constant.

 Args:
 ct: the cotangent of the output of the primitive.
 x, y, z: values of the arguments. The arguments that are used linearly
 get an ad.UndefinedPrimal value. The other arguments get a constant
 value.
 Returns:
 a tuple with the cotangent of the inputs, with the value None
 corresponding to the constant arguments.
 """
  if not ad.is_undefined_primal(x):
    # This use of multiply_add is with a constant "x"
    assert ad.is_undefined_primal(y)
    ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
    res = None, ct_y, ct
  else:
    # This use of multiply_add is with a constant "y"
    assert not ad.is_undefined_primal(y)
    ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
    res = ct_x, None, ct
  return res

ad.primitive_transposes[multiply_add_p] = multiply_add_transpose 
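To see what this rule should compute, one can transpose the whole output-tangent computation of multiply_add at once. This sketch writes that linear function with plain JAX operations at x = y = 2., standing in for the two `multiply_add_prim` calls. It then transposes the function with `jax.linear_transpose`. The resulting cotangents should match the values returned by the two `multiply_add_transpose` calls in the grad run that follows.

```python
import jax
import jax.numpy as jnp

# Primal point used by the tangent computation; x and y are constants
# from the point of view of transposition (assumed values x = y = 2.).
x, y = jnp.float32(2.0), jnp.float32(2.0)


def output_tangent(xt, yt, zt):
  # Same linear structure as in multiply_add_value_and_jvp:
  # multiply_add(xt, y, multiply_add(x, yt, zt)) = xt*y + (x*yt + zt)
  return xt * y + (x * yt + zt)


zero = jnp.float32(0.0)  # example tangents; only shape/dtype matter
output_tangent_t = jax.linear_transpose(output_tangent, zero, zero, zero)
xct, yct, zct = output_tangent_t(jnp.float32(1.0))
print(xct, yct, zct)  # y*ct, x*ct, ct for ct = 1.0
```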

Now we can complete the run of grad:

assert api.grad(square_add_prim)(2., 10.) == 4. 
call square_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
  call multiply_add_prim(Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, Traced<ConcreteArray(2.0, dtype=float32, weak_type=True)>, 10.0)
    call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(2.0, 2.0, 10.0)
        call multiply_add_impl(2.0, 2.0, 10.0)
        |<- multiply_add_impl = 14.0
      |<- multiply_add_prim = 14.0
      Tangent evaluation:
      call multiply_add_prim(2.0, Traced<ShapedArray(float32[], weak_type=True)>, 0.0)
        call multiply_add_abstract_eval(ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[], weak_type=True), ConcreteArray(0.0, dtype=float32, weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, 2.0, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ConcreteArray(2.0, dtype=float32, weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (14.0, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
|<- square_add_prim = Traced<ConcreteArray(14.0, dtype=float32)>
call multiply_add_transpose(1.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 2.0, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(1.0, 2.0, 0.0)
    call multiply_add_impl(1.0, 2.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (2.0, None, 1.0)
call multiply_add_transpose(1.0, 2.0, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), 0.0)
  call multiply_add_prim(2.0, 1.0, 0.0)
    call multiply_add_impl(2.0, 1.0, 0.0)
    |<- multiply_add_impl = 2.0
  |<- multiply_add_prim = 2.0
|<- multiply_add_transpose = (None, 2.0, 1.0) 

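Notice the two calls to multiply_add_transpose. They correspond to the two uses of multiply_add_prim in the computation of output_tangent in multiply_add_value_and_jvp. The first call to transpose corresponds to the last use of multiply_add_prim: multiply_add_prim(xt, y, ...), where y is the constant 2.0.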
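square_add_prim passes the same input a as both x and y of multiply_add_prim. As a result, the cotangents from the two transpose calls accumulate into a single gradient: 2.0 + 2.0 = 4.0. This minimal sketch of the bookkeeping is checked against `jax.grad` of a hypothetical pure-JAX stand-in.

```python
import jax


def square_add(a, b):
  # Hypothetical pure-JAX stand-in: multiply_add(a, a, b) = a * a + b.
  return a * a + b


# Cotangents returned by the two transpose calls in the trace above,
# at (a, b) = (2., 10.) with output cotangent 1.0:
ct_from_first_use = 1.0 * 2.0   # x-cotangent, with y = 2. constant
ct_from_second_use = 2.0 * 1.0  # y-cotangent, with x = 2. constant
# x and y are both the input a, so their cotangents are summed.
grad_a_by_hand = ct_from_first_use + ct_from_second_use

grad_a = jax.grad(square_add)(2., 10.)
print(grad_a_by_hand, float(grad_a))
```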

JIT of reverse differentiation

Notice that the abstract evaluation of multiply_add_value_and_jvp now uses only abstract values, whereas without JIT it used ConcreteArray.

assert api.jit(api.grad(square_add_prim))(2., 10.) == 4. 
call square_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_value_and_jvp((Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>), (Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>, Zero(ShapedArray(float32[], weak_type=True))))
      Primal evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
      Tangent evaluation:
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
      call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
        call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True), ShapedArray(float32[]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[])
      |<- multiply_add_prim = Traced<ShapedArray(float32[])>
    |<- multiply_add_value_and_jvp = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[])))
  call multiply_add_prim(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_transpose(Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, UndefinedPrimal(ShapedArray(float32[], weak_type=True)), Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
  call multiply_add_prim(Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>)
    call multiply_add_abstract_eval(ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True))
    |<- multiply_add_abstract_eval = ShapedArray(float32[])
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
|<- multiply_add_transpose = (None, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>)
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): 
loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, 
'/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d600>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(%0 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%1 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afc6d1b30>]
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51c360>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc50ba30>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc508d30>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc50bc30>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afc69ef20>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a56fb00>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))))))))))), <jaxlib.xla_extension.Traceback object at 0x56229a611410>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0) at 
callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 88): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc6694d0, file "/tmp/ipykernel_1319/3085343041.py", line 1>, 18): loc("<module>"("/tmp/ipykernel_1319/3085343041.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)), (<code object multiply_add_value_and_jvp at 0x7f0b2cd8ae40, file "/tmp/ipykernel_1319/3197095916.py", line 4>, 86): loc("multiply_add_value_and_jvp"("/tmp/ipykernel_1319/3197095916.py":41:0))}, 
canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/3197095916.py': '/tmp/ipykernel_1319/3197095916.py', '/tmp/ipykernel_1319/3085343041.py': '/tmp/ipykernel_1319/3085343041.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, '/tmp/ipykernel_1319/3197095916.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/ad.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/3085343041.py': True, 
'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='transpose'), Transform(name='jvp'))), primitive=multiply_add, avals_in=[ShapedArray(float32[], weak_type=True), ShapedArray(float32[]), ShapedArray(float32[], weak_type=True)], avals_out=[ShapedArray(float32[])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69d930>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<f32>' at index: 0), Value(%4 = "stablehlo.constant"() <{value = dense<1.000000e+00> : tensor<f32>}> : () -> tensor<f32>), Value(%5 = "stablehlo.constant"() <{value = dense<0.000000e+00> : tensor<f32>}> : () -> tensor<f32>))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8e6f70>] 
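A minimal analogue with built-in `jnp` operations instead of the custom primitive (the function name `square_add` is an assumption): under `jax.jit(jax.grad(...))` the function body receives tracers that carry only shape and dtype information, not the concrete value 2.0. The body is traced once, and later calls with arguments of the same type should reuse the compiled computation.

```python
import jax

seen = []

def square_add(a, b):
    # Under jit(grad(...)), `a` is a tracer with an abstract shape/dtype,
    # not the concrete number passed by the caller.
    seen.append(type(a).__name__)
    return a * a + b

grad_fn = jax.jit(jax.grad(square_add))
print(grad_fn(2.0, 10.0))  # 4.0
print(seen)  # one entry: the body ran once, during tracing
```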

Batching

The batching transformation turns a point-wise computation into a computation on vectors. If we try it right now, we get a NotImplementedError:

# The arguments are two vectors instead of two scalars
with expectNotImplementedError():
  api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
                                               np.array([10., 20.])) 
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)

Found expected exception: 
Traceback (most recent call last):
  File "/tmp/ipykernel_1319/2641678767.py", line 3, in <module>
    api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py", line 1214, in vmap_f
    out_flat = batching.batch(
NotImplementedError: Batching rule for 'multiply_add' not implemented 
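For contrast, the same computation written with ordinary `jnp` operations batches without any extra work, because the underlying lax primitives already have batching rules. The sketch also checks the intended semantics: `vmap` should agree with applying the scalar function to each row and stacking the results. The name `square_add` is an assumption.

```python
import jax
import jax.numpy as jnp

def square_add(a, b):
    # Built from existing lax primitives, which already have batching rules.
    return a * a + b

xs = jnp.array([2., 3.])
ys = jnp.array([10., 20.])

batched = jax.vmap(square_add, in_axes=0, out_axes=0)(xs, ys)
# vmap should agree with mapping the scalar function over the rows by hand.
looped = jnp.stack([square_add(x, y) for x, y in zip(xs, ys)])
print(batched)  # [14. 29.]
```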

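We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, multiply_add_prim already operates pointwise on input vectors of any dimension, so the batched version can use the same multiply_add_prim implementation.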

from jax.interpreters import batching

@trace("multiply_add_batch")
def multiply_add_batch(vector_arg_values, batch_axes):
  """Computes the batched version of the primitive.

  This must be a JAX-traceable function.

  Since the multiply_add primitive already operates pointwise on arbitrary
  dimension tensors, to batch it we can use the primitive itself. This works as
  long as all the inputs have the same dimensions and are batched along the
  same axes. The result is batched along the axis that the inputs are batched.

  Args:
    vector_arg_values: a tuple of three arguments, each being a tensor of
      matching shape.
    batch_axes: the axes that are being batched. See vmap documentation.
  Returns:
    a tuple of the result, and the result axis that was batched.
 """
  assert batch_axes[0] == batch_axes[1]
  assert batch_axes[0] == batch_axes[2]
  _trace("Using multiply_add to compute the batch:")
  res = multiply_add_prim(*vector_arg_values)
  return res, batch_axes[0]

batching.primitive_batchers[multiply_add_p] = multiply_add_batch 
assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
  np.array([2., 3.]),
  np.array([10., 20.])),
  [14., 29.]) 
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])
        call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])
        |<- multiply_add_impl = [14. 29.]
      |<- multiply_add_prim = [14. 29.]
    |<- multiply_add_batch = ([14. 29.], 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])> 
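The rule above asserts that all three inputs share one batch axis. A more general rule for an elementwise `x * y + z` operation could first move each batched axis to position 0 and broadcast any unbatched input (batch axis `None`, which is what `vmap` passes for arguments mapped with `in_axes=None`). The sketch below is hypothetical: it is written as a plain function over `jnp` arrays so it runs standalone, it computes `x * y + z` directly rather than calling `multiply_add_prim`, and the name `general_multiply_add_batch` is an assumption.

```python
import jax.numpy as jnp

def general_multiply_add_batch(vector_arg_values, batch_axes):
    """Hypothetical batching rule for an elementwise x * y + z operation.

    Returns (result, out_batch_axis) like the rule in the text, but tolerates
    different batch axes and unbatched (None) arguments.
    """
    # Size of the mapped dimension, read from any batched argument.
    size = next(arg.shape[ax]
                for arg, ax in zip(vector_arg_values, batch_axes)
                if ax is not None)
    aligned = []
    for arg, ax in zip(vector_arg_values, batch_axes):
        if ax is None:
            # Unbatched argument: give it a leading batch dimension.
            arg = jnp.broadcast_to(arg, (size,) + jnp.shape(arg))
        else:
            # Batched argument: move its batch axis to the front.
            arg = jnp.moveaxis(arg, ax, 0)
        aligned.append(arg)
    x, y, z = aligned
    return x * y + z, 0

res, out_axis = general_multiply_add_batch(
    (jnp.array([2., 3.]), jnp.array(2.), jnp.array([10., 20.])), (0, None, 0))
print(res, out_axis)  # [14. 26.] 0
```

Always returning batch axis 0 keeps the rule simple; a rule that preserves the caller's axis would avoid a transpose but needs more bookkeeping.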

JIT of batching

assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
                    (np.array([2., 3.]),
                     np.array([10., 20.])),
                    [14., 29.]) 
call square_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
  call multiply_add_prim(Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>, Traced<ShapedArray(float32[])>)
    call multiply_add_batch((Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>), (0, 0, 0))
      Using multiply_add to compute the batch:
      call multiply_add_prim(Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>)
        call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))
        |<- multiply_add_abstract_eval = ShapedArray(float32[2])
      |<- multiply_add_prim = Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>
    |<- multiply_add_batch = (Traced<ShapedArray(float32[2])>with<DynamicJaxprTrace(level=1/0)>, 0)
  |<- multiply_add_prim = Traced<ShapedArray(float32[])>
|<- square_add_prim = Traced<ShapedArray(float32[])>
call multiply_add_lowering(LoweringRuleContext(module_context=ModuleContext(context=<jaxlib.mlir._mlir_libs._site_initialize.<locals>.Context object at 0x7f0afc51cd10>, module=<jaxlib.mlir._mlir_libs._mlir.ir.Module object at 0x7f0afc68bdb0>, ip=<jaxlib.mlir._mlir_libs._mlir.ir.InsertionPoint object at 0x7f0afc68aaf0>, symbol_table=<jaxlib.mlir._mlir_libs._mlir.ir.SymbolTable object at 0x7f0afc689eb0>, backend_or_name=<jaxlib.xla_extension.Client object at 0x7f0afd95b880>, platforms=('cpu',), axis_context=ShardingContext(num_devices=1, device_assignment=None), keepalives=[], channel_iterator=count(1), host_callbacks=[], shape_poly_state=<jax._src.interpreters.mlir.ShapePolyLoweringState object at 0x7f0afd8b7190>, cached_primitive_lowerings={}, traceback_caches=TracebackCaches(traceback_cache={<jaxlib.xla_extension.Traceback object at 0x56229a884960>: loc(callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0) at callsite("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0) at callsite("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0) at "run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0)))))))))))}, location_cache={(<code object multiply_add_prim at 0x7f0afd853b50, file "/tmp/ipykernel_1319/1308506715.py", line 4>, 10): loc("multiply_add_prim"("/tmp/ipykernel_1319/1308506715.py":11:0)), (<code object func_wrapper at 0x7f0b2cd8b260, file "/tmp/ipykernel_1319/1393342955.py", line 45>, 24): 
loc("func_wrapper"("/tmp/ipykernel_1319/1393342955.py":48:0)), (<code object multiply_add_batch at 0x7f0afc6687c0, file "/tmp/ipykernel_1319/184469370.py", line 4>, 52): loc("multiply_add_batch"("/tmp/ipykernel_1319/184469370.py":25:0)), (<code object square_add_prim at 0x7f0afd8d5c60, file "/tmp/ipykernel_1319/1308506715.py", line 13>, 8): loc("square_add_prim"("/tmp/ipykernel_1319/1308506715.py":16:0)), (<code object <module> at 0x7f0afc668a80, file "/tmp/ipykernel_1319/1392464762.py", line 1>, 48): loc("<module>"("/tmp/ipykernel_1319/1392464762.py":1:0)), (<code object run_code at 0x7f0b3686e550, file "/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3541>, 76): loc("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0))}, canonical_name_cache={'/tmp/ipykernel_1319/1308506715.py': '/tmp/ipykernel_1319/1308506715.py', '/tmp/ipykernel_1319/1393342955.py': '/tmp/ipykernel_1319/1393342955.py', '/tmp/ipykernel_1319/184469370.py': '/tmp/ipykernel_1319/184469370.py', '/tmp/ipykernel_1319/1392464762.py': '/tmp/ipykernel_1319/1392464762.py', '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py'}, is_user_file_cache={'/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/source_info_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/core.py': False, '/tmp/ipykernel_1319/1308506715.py': True, '/tmp/ipykernel_1319/1393342955.py': True, 
'/tmp/ipykernel_1319/184469370.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/interpreters/batching.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/linear_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/api.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/traceback_util.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/profiler.py': False, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/jax/_src/pjit.py': False, '/tmp/ipykernel_1319/1392464762.py': True, '/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py': True}), lowering_parameters=LoweringParameters(override_lowering_rules=None, global_constant_computation=False, for_export=False)), name_stack=NameStack(stack=(Scope(name='jit(square_add_prim)'), Scope(name='jit(main)'), Transform(name='vmap'))), primitive=multiply_add, avals_in=[ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2])], avals_out=[ShapedArray(float32[2])], tokens_in=<jax._src.interpreters.mlir.TokenSet object at 0x7f0afc69e860>, tokens_out=None, axis_size_env=None, dim_var_values=[], compute_type=None, platforms=None), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 0), Value(<block argument> of type 'tensor<2xf32>' at index: 1))
|<- multiply_add_lowering = [<jaxlib.mlir._mlir_libs._mlir.ir.OpResult object at 0x7f0afd8920f0>] 

Writing custom Jaxpr interpreters in JAX

Original: jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html


JAX offers several composable function transformations (jit, grad, vmap, etc.) that make it possible to write concise, accelerated code.

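Here we show how to add your own function transformations to the system by writing a custom Jaxpr interpreter, and we get composability with all the other transformations for free.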
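A small illustration of that composability, using only public transformations: `grad` differentiates a scalar function, `vmap` maps it over a batch, and `jit` compiles the result. The composed function should match the closed-form derivative 1 - tanh(x)^2. The names `dtanh` and `xs` are illustrative.

```python
import jax
import jax.numpy as jnp

# grad handles one scalar, vmap maps it over a batch, jit compiles the result.
dtanh = jax.jit(jax.vmap(jax.grad(jnp.tanh)))

xs = jnp.linspace(-2.0, 2.0, 5)
print(dtanh(xs))
print(1.0 - jnp.tanh(xs) ** 2)  # closed-form derivative, for comparison
```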

This example uses internal JAX APIs, which may break at any time. Anything not in the API documentation should be assumed internal.

import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax import random 

What is JAX doing?

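JAX provides a NumPy-like API for numerical computing that can be used as is, but JAX's real power comes from composable function transformations. Take the jit transformation: it takes a function and returns a semantically identical function that is lazily compiled by XLA for accelerators.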

x = random.normal(random.key(0), (5000, 5000))
def f(w, b, x):
  return jnp.tanh(jnp.dot(x, w) + b)
fast_f = jit(f) 

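What happens when we call fast_f? JAX traces the function and constructs an XLA computation graph. The graph is then just-in-time (JIT) compiled and executed. Other transformations work similarly: they first trace the function and then handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the "How it works" section of the README.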
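One way to observe the tracing step is a Python side effect inside a jitted function: it runs while JAX traces the function, not when the cached compiled version is reused. In this sketch (the names `traced_tanh` and `trace_log` are illustrative), a second call with the same shape and dtype should not retrace, while a new shape should.

```python
import jax
import jax.numpy as jnp

trace_log = []

@jax.jit
def traced_tanh(x):
    # This Python side effect runs only while JAX traces the function.
    trace_log.append(x.shape)
    return jnp.tanh(x)

traced_tanh(jnp.ones(3))
traced_tanh(jnp.zeros(3))  # same shape and dtype: reuses the compiled computation
traced_tanh(jnp.ones(4))   # new shape: traced and compiled again
print(trace_log)  # [(3,), (4,)]
```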

Jaxpr tracer

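A particularly important tracer in Jax is the Jaxpr tracer, which records operations into a Jaxpr (Jax expression). A Jaxpr is a data structure that can be evaluated like a small functional programming language, which makes Jaxprs a useful intermediate representation for function transformations.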

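To get a first look at Jaxprs, consider the make_jaxpr transformation. make_jaxpr is essentially a "pretty-printing" transformation: it transforms a function into one that, given example arguments, produces a Jaxpr representation of its computation. make_jaxpr is useful for debugging and introspection. Let's use it to look at how some example Jaxprs are structured.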

def examine_jaxpr(closed_jaxpr):
  jaxpr = closed_jaxpr.jaxpr
  print("invars:", jaxpr.invars)
  print("outvars:", jaxpr.outvars)
  print("constvars:", jaxpr.constvars)
  for eqn in jaxpr.eqns:
    print("equation:", eqn.invars, eqn.primitive, eqn.outvars, eqn.params)
  print()
  print("jaxpr:", jaxpr)

def foo(x):
  return x + 1
print("foo")
print("=====")
examine_jaxpr(jax.make_jaxpr(foo)(5))

print()

def bar(w, b, x):
  return jnp.dot(w, x) + b + jnp.ones(5), x
print("bar")
print("=====")
examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10))) 
foo
=====
invars: [Var(id=140117887103104):int32[]]
outvars: [Var(id=140117887103296):int32[]]
constvars: []
equation: [Var(id=140117887103104):int32[], 1] add [Var(id=140117887103296):int32[]] {}

jaxpr: { lambda ; a:i32[]. let b:i32[] = add a 1 in (b,) }

bar
=====
invars: [Var(id=140117843771968):float32[5,10], Var(id=140117843772032):float32[5], Var(id=140117843772096):float32[10]]
outvars: [Var(id=140117843772352):float32[5], Var(id=140117843772096):float32[10]]
constvars: []
equation: [Var(id=140117843771968):float32[5,10], Var(id=140117843772096):float32[10]] dot_general [Var(id=140117843772160):float32[5]] {'dimension_numbers': (((1,), (0,)), ((), ())), 'precision': None, 'preferred_element_type': dtype('float32')}
equation: [Var(id=140117843772160):float32[5], Var(id=140117843772032):float32[5]] add [Var(id=140117843772224):float32[5]] {}
equation: [1.0] broadcast_in_dim [Var(id=140117843772288):float32[5]] {'shape': (5,), 'broadcast_dimensions': ()}
equation: [Var(id=140117843772224):float32[5], Var(id=140117843772288):float32[5]] add [Var(id=140117843772352):float32[5]] {}

jaxpr: { lambda ; a:f32[5,10] b:f32[5] c:f32[10]. let
    d:f32[5] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] a c
    e:f32[5] = add d b
    f:f32[5] = broadcast_in_dim[broadcast_dimensions=() shape=(5,)] 1.0
    g:f32[5] = add e f
  in (g, c) } 
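  • jaxpr.invars - the invars of a Jaxpr are a list of its input variables, analogous to the arguments of a Python function.

  • jaxpr.outvars - the outvars of a Jaxpr are the variables it returns. A Jaxpr always returns a list of outputs, which may contain more than one variable.

  • jaxpr.constvars - the constvars are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).

  • jaxpr.eqns - a list of equations, which are essentially let-bindings. Each equation has a list of input variables, a list of output variables, and a primitive that is applied to the inputs to produce the outputs. Each equation also has params, a dictionary of parameters.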
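The fields listed above can be walked directly. This sketch traces a hypothetical two-output function (the name `two_outputs` is illustrative) and prints the number of invars and outvars, then each equation's primitive, inputs, outputs, and params.

```python
import jax
import jax.numpy as jnp

def two_outputs(x):
    return jnp.sin(x), x * 2.0

jaxpr = jax.make_jaxpr(two_outputs)(jnp.ones(3)).jaxpr
print(len(jaxpr.invars), len(jaxpr.outvars))  # 1 2
for eqn in jaxpr.eqns:
    # Each equation: a primitive, input vars/literals, output vars, a params dict.
    print(eqn.primitive.name, eqn.invars, eqn.outvars, eqn.params)
```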

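Taken together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce outputs. We'll go over exactly how to do this later. The important thing to note now is that a Jaxpr is a data structure that we can manipulate and evaluate in whatever way we want.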

Why are Jaxprs useful?

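Jaxprs are simple program representations that are easy to transform. And because Jax lets us stage out Jaxprs from Python functions, it gives us a way to transform numerical programs written in Python.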

Your first interpreter: invert

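Let's try to implement a simple function "inverter", which takes the output of the original function and returns the inputs that produced those outputs. For now, let's focus on simple unary functions composed of other invertible unary functions.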

Goal:

def f(x):
  return jnp.exp(jnp.tanh(x))
f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0) 

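We'll implement this by (1) tracing f into a Jaxpr, then (2) interpreting the Jaxpr backwards. While interpreting the Jaxpr backwards, for each equation we'll look up the primitive's inverse in a table and apply it.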
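Before touching Jaxprs, the same plan can be sketched by hand for this one function: keep the forward operations in order, pair each with its inverse, and undo them in reverse order. The hand-written `ops` list (and the names `f_manual` and `f_manual_inv`) are assumptions standing in for what tracing will extract automatically.

```python
import jax.numpy as jnp

# (forward op, its inverse) in the order f applies them: f(x) = exp(tanh(x)).
ops = [(jnp.tanh, jnp.arctanh), (jnp.exp, jnp.log)]

def f_manual(x):
    for forward, _ in ops:
        x = forward(x)
    return x

def f_manual_inv(y):
    # Interpret the "program" backwards, applying each op's inverse.
    for _, inverse in reversed(ops):
        y = inverse(y)
    return y

print(f_manual_inv(f_manual(1.0)))  # approximately 1.0
```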

1. Tracing a function

Let's use make_jaxpr to trace a function into a Jaxpr.

# Importing Jax functions useful for tracing/interpreting.
import numpy as np
from functools import wraps

from jax import core
from jax import lax
from jax._src.util import safe_map 

jax.make_jaxpr returns a closed Jaxpr, which is a Jaxpr that has been bundled with the constants (literals) from the trace.

def f(x):
  return jnp.exp(jnp.tanh(x))

closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
print(closed_jaxpr.jaxpr)
print(closed_jaxpr.literals) 
{ lambda ; a:f32[5]. let b:f32[5] = tanh a; c:f32[5] = exp b in (c,) }
[] 
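In the example above the constant list is empty because f closes over nothing. A function that captures a concrete array will typically produce a non-empty constant list instead. How constants are split between literals and constvars can vary across JAX versions, so this sketch relies on only one thing: each constvar is paired with one bundled value. It reads those values through the consts attribute of the closed Jaxpr. The names w and scale are illustrative.

```python
import jax
import jax.numpy as jnp

w = jnp.arange(3.0)  # concrete array captured by closure, not passed as an argument

def scale(x):
    return x * w

closed = jax.make_jaxpr(scale)(jnp.ones(3))
print(closed.jaxpr)
print(closed.consts)

# Each constvar is paired positionally with one bundled constant value.
n_constvars = len(closed.jaxpr.constvars)
n_consts = len(closed.consts)
```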

2. Evaluating a Jaxpr

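Before we write a custom Jaxpr interpreter, let's first implement the "default" interpreter, eval_jaxpr. It evaluates the Jaxpr as-is, computing the same values that the original, untransformed Python function would.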

To do this, we first create an environment that stores the value of each variable, and we update it as we evaluate each equation in the Jaxpr.

def eval_jaxpr(jaxpr, consts, *args):
  # Mapping from variable -> value
  env = {}

  def read(var):
    # Literals are values baked into the Jaxpr
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val

  # Bind args and consts to environment
  safe_map(write, jaxpr.invars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Loop through equations and evaluate primitives using `bind`
  for eqn in jaxpr.eqns:
    # Read inputs to equation from environment
    invals = safe_map(read, eqn.invars)  
    # `bind` is how a primitive is called
    outvals = eqn.primitive.bind(*invals, **eqn.params)
    # Primitives may return multiple outputs or not
    if not eqn.primitive.multiple_results: 
      outvals = [outvals]
    # Write the results of the primitive into the environment
    safe_map(write, eqn.outvars, outvals) 
  # Read the final result of the Jaxpr from the environment
  return safe_map(read, jaxpr.outvars) 
closed_jaxpr = jax.make_jaxpr(f)(jnp.ones(5))
eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)) 
[Array([2.1416876, 2.1416876, 2.1416876, 2.1416876, 2.1416876], dtype=float32)] 

Notice that eval_jaxpr always returns a flat list, even if the original function does not.
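If the caller needs the original output structure back, one option is to record it while tracing. This sketch assumes the return_shape=True option of jax.make_jaxpr, which also returns the output shapes as a pytree. The tree structure of that pytree can then regroup any flat list of results that uses the same leaf order. The function h is purely illustrative, and its own flattened outputs stand in for what an interpreter would return.

```python
import jax
import jax.numpy as jnp

def h(x):
    # Nested output: a dict holding an array and a tuple of arrays
    return {"a": jnp.sin(x), "b": (jnp.cos(x), x * 2.0)}

x = jnp.ones(3)
closed, out_shape = jax.make_jaxpr(h, return_shape=True)(x)
out_tree = jax.tree_util.tree_structure(out_shape)

# The Jaxpr itself only knows about a flat list of three outputs.
num_outvars = len(closed.jaxpr.outvars)

# Stand-in for the flat list an interpreter such as eval_jaxpr would return:
flat = jax.tree_util.tree_leaves(h(x))
restored = jax.tree_util.tree_unflatten(out_tree, flat)
print(num_outvars, out_tree)
```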

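Furthermore, this interpreter does not handle higher-order primitives (like jit and pmap), which we will not cover in this guide. You can refer to core.eval_jaxpr (link) to see the edge cases that this interpreter does not cover.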
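To see why higher-order primitives need extra handling, the sketch below traces a function that calls a jax.jit-compiled function internally. The resulting Jaxpr typically holds a single equation whose params carry a nested Jaxpr. An interpreter like eval_jaxpr above binds that primitive as a whole instead of recursing into it. The primitive's printed name (pjit in many JAX versions) may differ between releases, so the sketch only inspects its params.

```python
import jax
import jax.numpy as jnp

inner = jax.jit(lambda y: jnp.sin(y) * 2.0)

def outer(x):
    return inner(x)

jaxpr = jax.make_jaxpr(outer)(jnp.ones(3)).jaxpr
eqn = jaxpr.eqns[0]
print(eqn.primitive.name)   # name of the call-like primitive
print(eqn.params["jaxpr"])  # nested Jaxpr of the jitted function
```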

Custom inverse Jaxpr interpreter

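An inverse interpreter doesn't look too different from eval_jaxpr. We'll first set up a registry that maps primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.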

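It turns out that this interpreter also looks similar to the "transpose" interpreter used in reverse-mode autodifferentiation, which can be found here: link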

inverse_registry = {} 

We'll now register inverses for some of the primitives. By convention, primitives in JAX end in _p, and many of the popular ones live in lax.

inverse_registry[lax.exp_p] = jnp.log
inverse_registry[lax.tanh_p] = jnp.arctanh 
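More entries can be registered the same way. This sketch uses a hypothetical helper table that pairs each primitive with a forward function and its inverse. It then checks numerically that each inverse undoes its forward function on a range where both are defined; arcsin, for instance, only inverts sin on [-pi/2, pi/2]. The extra primitives lax.log_p and lax.sin_p are additions for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical table: primitive -> (forward function, inverse function)
candidate_inverses = {
    lax.exp_p: (lax.exp, jnp.log),
    lax.tanh_p: (lax.tanh, jnp.arctanh),
    lax.log_p: (lax.log, jnp.exp),
    lax.sin_p: (lax.sin, jnp.arcsin),  # only valid for inputs in [-pi/2, pi/2]
}

# Positive inputs below 1 keep every pair inside its valid domain.
x = jnp.linspace(0.1, 0.9, 5)

max_errors = {}
for prim, (fwd, inv) in candidate_inverses.items():
    max_errors[prim.name] = float(jnp.max(jnp.abs(inv(fwd(x)) - x)))

print(max_errors)

# Entries that pass the check could then be copied into the tutorial's registry:
inverse_registry = {prim: inv for prim, (_, inv) in candidate_inverses.items()}
```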

inverse will first trace the function, then custom-interpret the Jaxpr. Let's set up a simple skeleton.

def inverse(fun):
  @wraps(fun)
  def wrapped(*args, **kwargs):
    # Since we assume unary functions, we won't worry about flattening and
    # unflattening arguments.
    closed_jaxpr = jax.make_jaxpr(fun)(*args, **kwargs)
    out = inverse_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, *args)
    return out[0]
  return wrapped 

Now we just need to define inverse_jaxpr, which walks through the Jaxpr backward and inverts primitives where it can.

def inverse_jaxpr(jaxpr, consts, *args):
  env = {}

  def read(var):
    if type(var) is core.Literal:
      return var.val
    return env[var]

  def write(var, val):
    env[var] = val
  # Args now correspond to Jaxpr outvars
  safe_map(write, jaxpr.outvars, args)
  safe_map(write, jaxpr.constvars, consts)

  # Looping backward
  for eqn in jaxpr.eqns[::-1]:
    #  outvars are now invars 
    invals = safe_map(read, eqn.outvars)
    if eqn.primitive not in inverse_registry:
      raise NotImplementedError(
          f"{eqn.primitive} does not have registered inverse.")
    # Assuming a unary function 
    # Look up the inverse of this primitive and apply it to the equation's outputs
    outval = inverse_registry[eqn.primitive](*invals)
    safe_map(write, eqn.invars, [outval])
  return safe_map(read, jaxpr.invars) 

That's it!

def f(x):
  return jnp.exp(jnp.tanh(x))

f_inv = inverse(f)
assert jnp.allclose(f_inv(f(1.0)), 1.0) 

Importantly, you can trace through a Jaxpr interpreter.

jax.make_jaxpr(inverse(f))(f(1.)) 
{ lambda ; a:f32[]. let b:f32[] = log a; c:f32[] = atanh b in (c,) } 

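That's all it takes to add a new transformation to the system, and you get composition with all the other transformations for free! For example, we can use jit, vmap, and grad with inverse!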

jax.jit(jax.vmap(jax.grad(inverse(f))))((jnp.arange(5) + 1.) / 5.) 
Array([-3.1440797, 15.584931 ,  2.2551253,  1.3155028,  1.       ],      dtype=float32, weak_type=True) 

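Exercises for the reader

  • Handle primitives with multiple arguments where some inputs are known, for example lax.add_p and lax.mul_p.

  • Handle the xla_call and xla_pmap primitives, which will not work with eval_jaxpr or inverse_jaxpr as written.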

Custom operations for GPUs with C++ and CUDA

Original: jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html

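JAX ships with a large number of built-in operations, but users occasionally run into situations where they need a new operation that JAX does not support.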

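To accommodate such scenarios, JAX allows users to define custom operations. This tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.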

This tutorial contains information from Extending JAX with custom C++ and CUDA code, and assumes that you are familiar with JAX primitives.

RMS normalization

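This tutorial adds RMS normalization as a custom operation in JAX. Note that RMS normalization can be expressed directly with jax.numpy; we use it here only as an example of the process of creating a custom operation for GPUs. The CUDA code for this operation in gpu_ops/rms_norm_kernels.cu is borrowed from Apex and adapted to remove any dependency on PyTorch.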
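Since the operation can be written with jax.numpy, a pure-jnp reference is handy for checking the custom kernel. The sketch below is an assumption modeled on the shapes used later in this tutorial:

- it normalizes over the trailing dimensions covered by weight;
- it computes in float32;
- it returns the output in weight's dtype, together with a flat invvar holding one value per normalized row.

It may differ numerically from the Apex kernel. The name rms_norm_reference is hypothetical.

```python
import jax
import jax.numpy as jnp

def rms_norm_reference(x, weight, eps=1e-5):
    # Normalize over the trailing dims that weight covers (n2 elements per row).
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    x32 = x.astype(jnp.float32)
    # invvar = 1 / sqrt(mean(x^2) + eps), one value per row
    invvar = jax.lax.rsqrt(jnp.mean(x32 * x32, axis=axes, keepdims=True) + eps)
    out = x32 * invvar * weight.astype(jnp.float32)
    return out.astype(weight.dtype), invvar.reshape(-1)

x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 16), dtype=jnp.float32)
weight = jnp.ones((8, 16), dtype=jnp.float32)
out, invvar = rms_norm_reference(x, weight)
```

With weight set to ones, every normalized row should have a mean square very close to 1.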

High-level steps

This tutorial shows how to write both a custom operation and its gradient.

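In C: each new JAX primitive requires the following steps.

  • Have the CUDA kernel(s).

  • Create a C function that dispatches the CUDA kernel; XLA will call this function.

  • Create a descriptor to convey the information needed for the computation.

    • Types, shapes, and other attributes.
  • Bind the C functions to Python

    • to create the descriptor and to call the primitive during execution.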

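In Python: you need to follow these steps.

  • Define a new JAX primitive (instruction/operation).

  • Write Python functions that build graph nodes with the primitive.

  • Define its abstract evaluation.

  • Define its lowering to MLIR.

  • [Optional] Define its gradient.

  • [Optional] Use the custom_partitioning or shard_map functions for fast multi-GPU support.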

C code

See the gpu_ops code listing for the complete C++ and CUDA source files. gpu_ops/rms_norm_kernels.cu defines the following functions, which launch the RMS normalization kernels on the given stream using the given buffers.

namespace  gpu_ops  {

void  rms_forward_affine_mixed_dtypes(cudaStream_t  stream,  void  **buffers,
  const  char  *opaque,
  std::size_t  opaque_len);

void  rms_backward_affine(cudaStream_t  stream,  void  **buffers,
  const  char  *opaque,  std::size_t  opaque_len);

}  // namespace gpu_ops 
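  • stream is the CUDA stream used to execute any kernel on the GPU.

  • buffers holds all pointers to the input buffers, followed by all pointers to the output buffers.

  • opaque is a buffer holding any extra information passed to the custom function, and opaque_len is the length of opaque.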

In this tutorial, an RMSNormDescriptor object is passed to these functions through opaque.

namespace  gpu_ops  {

enum  ElementType  {  BF16,  F16,  F32,  F64  };

struct  RMSNormDescriptor  {
  int  n1;
  int  n2;
  double  eps;
  ElementType  x_type;
  ElementType  w_type;
  int  part_grad_size;
};

}  // namespace gpu_ops 

Now we need to expose these functions, as well as ElementType and RMSNormDescriptor, as a Python module gpu_ops through pybind11.

pybind11::dict  RMSNormRegistrations()  {
  pybind11::dict  dict;
  dict["rms_forward_affine_mixed_dtype"]  =
  gpu_ops::EncapsulateFunction(gpu_ops::rms_forward_affine_mixed_dtypes);
  dict["rms_backward_affine"]  =
  gpu_ops::EncapsulateFunction(gpu_ops::rms_backward_affine);
  return  dict;
}

PYBIND11_MODULE(gpu_ops,  m)  {
  m.def("get_rms_norm_registrations",  &RMSNormRegistrations);
  m.def("create_rms_norm_descriptor",
  [](int  n1,  int  n2,  double  eps,  gpu_ops::ElementType  x_type,
  gpu_ops::ElementType  w_type,  int  part_grad_size)  {
  return  gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
  n1,  n2,  eps,  x_type,  w_type,  part_grad_size});
  });

  pybind11::enum_<gpu_ops::ElementType>(m,  "ElementType")
  .value("BF16",  gpu_ops::ElementType::BF16)
  .value("F16",  gpu_ops::ElementType::F16)
  .value("F32",  gpu_ops::ElementType::F32)
  .value("F64",  gpu_ops::ElementType::F64);

} 
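To make the opaque bytes concrete, the sketch below mirrors RMSNormDescriptor with Python's struct module. It rests on these assumptions:

- PackDescriptor copies the struct's bytes verbatim;
- ElementType is stored as a 4-byte int;
- the host is little-endian and aligns the double on 8 bytes, so the struct occupies 32 bytes including 4 trailing padding bytes.

Under those assumptions, the first 28 bytes line up with the backend_config strings in this tutorial's HLO dumps: n1, n2, eps = 1e-05, the two element types, and part_grad_size. The helper names are hypothetical.

```python
import struct

# Hypothetical mirror of gpu_ops::ElementType (declaration order BF16, F16, F32, F64).
ELEMENT_TYPES = {"BF16": 0, "F16": 1, "F32": 2, "F64": 3}

# int n1, int n2, double eps, ElementType x_type, ElementType w_type,
# int part_grad_size, then 4 padding bytes rounding the struct up to 32 bytes.
DESCRIPTOR_LAYOUT = "<iidiii4x"

def pack_rms_norm_descriptor(n1, n2, eps, x_type, w_type, part_grad_size):
    return struct.pack(DESCRIPTOR_LAYOUT, n1, n2, eps,
                       ELEMENT_TYPES[x_type], ELEMENT_TYPES[w_type], part_grad_size)

def unpack_rms_norm_descriptor(blob):
    n1, n2, eps, x_code, w_code, part = struct.unpack(DESCRIPTOR_LAYOUT, blob)
    names = {code: name for name, code in ELEMENT_TYPES.items()}
    return n1, n2, eps, names[x_code], names[w_code], part

# Forward call on 32 rows of 512x512 bf16 values, as in the multi-device test.
blob = pack_rms_norm_descriptor(32, 512 * 512, 1e-5, "BF16", "BF16", 0)
print(len(blob), blob[:16])
```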

Build the gpu_ops extension module

We build the gpu_ops Python extension module from the code above. (See the gpu_ops code listing for the complete C++ and CUDA source files.)

python  -m  pip  install  pybind11==2.10.1
mkdir  -p  build
pybind_include_path=$(python  -c  "import pybind11; print(pybind11.get_include())")
python_executable=$(python  -c  'import sys; print(sys.executable)')

nvcc  --threads  4  -Xcompiler  -Wall  -ldl  --expt-relaxed-constexpr  -O3  -DNDEBUG  -Xcompiler  -O3  --generate-code=arch=compute_70,code=[compute_70,sm_70]  --generate-code=arch=compute_75,code=[compute_75,sm_75]  --generate-code=arch=compute_80,code=[compute_80,sm_80]  --generate-code=arch=compute_86,code=[compute_86,sm_86]  -Xcompiler=-fPIC  -Xcompiler=-fvisibility=hidden  -x  cu  -c  gpu_ops/rms_norm_kernels.cu  -o  build/rms_norm_kernels.cu.o
c++  -I/usr/local/cuda/include  -I$pybind_include_path  $(${python_executable}-config  --cflags)  -O3  -DNDEBUG  -O3  -fPIC  -fvisibility=hidden  -flto  -fno-fat-lto-objects  -o  build/gpu_ops.cpp.o  -c  gpu_ops/gpu_ops.cpp
c++  -fPIC  -O3  -DNDEBUG  -O3  -flto  -shared  -o  build/gpu_ops$(${python_executable}-config  --extension-suffix)  build/gpu_ops.cpp.o  build/rms_norm_kernels.cu.o  -L/usr/local/cuda/lib64  -lcudadevrt  -lcudart_static  -lrt  -lpthread  -ldl
strip  build/gpu_ops$(${python_executable}-config  --extension-suffix) 

Add RMS normalization to JAX as a custom call

gpu_ops is just a Python extension module; more work is needed to plug it into JAX.

Create primitives

我们首先创建了原语_rms_norm_fwd_p_rms_norm_bwd_p,这些原语可以映射到自定义函数。我们为这些操作设置了multiple_results属性为True,表示该操作作为元组产生多个输出。当设置为False时,该操作将产生单个输出而不是元组。有关更多详细信息,请参见How JAX primitives work

from functools import partial

import jax
import jax.numpy as jnp
import jax._src.test_util as jtu
from build import gpu_ops
from jax import core, dtypes
from jax.interpreters import xla
from jax.lib import xla_client

# Create _rms_norm_fwd_p for forward operation.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))

def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output

# Create _rms_norm_bwd_p for backward operation.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))

def rms_norm_bwd(g, invvar, x, weight, eps):
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight 
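To see in isolation what multiple_results=True changes, the sketch below defines a toy primitive whose bind returns a list of two arrays. The name two_outputs and its implementation are hypothetical. The sketch assumes Primitive is importable from jax.extend.core, falling back to jax.core on older releases. Its abstract evaluation rule simply reuses the input's abstract value for both outputs.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive         # older JAX releases

two_outputs_p = Primitive("two_outputs")
two_outputs_p.multiple_results = True  # bind returns a list, not a single value

@two_outputs_p.def_impl
def _two_outputs_impl(x):
    # Eager implementation: returns one entry per output
    return [x * 2.0, x + 1.0]

@two_outputs_p.def_abstract_eval
def _two_outputs_abstract(x_aval):
    # Both outputs share the input's shape and dtype
    return [x_aval, x_aval]

outs = two_outputs_p.bind(jnp.arange(3.0))
closed = jax.make_jaxpr(lambda x: two_outputs_p.bind(x))(jnp.arange(3.0))
print(outs)
print(closed)
```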

Lowering to an MLIR custom call

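To map the custom functions to the new primitives _rms_norm_fwd_p and _rms_norm_bwd_p, we need to:

  • register the custom functions as custom call targets with xla_client.register_custom_call_target, and

  • register lowering functions that lower the primitives to MLIR custom calls targeting the registered custom call targets.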

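The functions _rms_norm_fwd_cuda_lowering and _rms_norm_bwd_cuda_lowering below lower the primitives to MLIR custom call operations that use the custom targets from gpu_ops. These functions are registered with jax.interpreters.mlir.register_lowering.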

Note that an RMSNormDescriptor object is created in each lowering function and passed to the custom call as opaque.

from functools import reduce

from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.hlo_helpers import custom_call

# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")

def element_type_to_descriptor_type_mapping(element_type):
    _element_type_to_descriptor_type_mapping = {
        ir.BF16Type.get(): gpu_ops.ElementType.BF16,
        ir.F16Type.get(): gpu_ops.ElementType.F16,
        ir.F32Type.get(): gpu_ops.ElementType.F32,
        ir.F64Type.get(): gpu_ops.ElementType.F64,
    }
    return _element_type_to_descriptor_type_mapping.get(element_type)

def default_layouts(*shapes):
    return [range(len(shape) - 1, -1, -1) for shape in shapes]

def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_element_type = (
        ir.F32Type.get()
        if x_type.element_type in [ir.F16Type.get(), ir.BF16Type.get()]
        else x_type.element_type
    )

    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2

    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        0,  # unused
    )
    out = custom_call(
        b"rms_forward_affine_mixed_dtype",
        result_types=[
            ir.RankedTensorType.get(x_shape, w_type.element_type),
            ir.RankedTensorType.get((n1,), iv_element_type),
        ],
        operands=[x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, w_shape),
        result_layouts=default_layouts(x_shape, (n1,)),
    ).results
    return out

mlir.register_lowering(
    _rms_norm_fwd_p,
    _rms_norm_fwd_cuda_lowering,
    platform="gpu",
)

def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
    x_type = ir.RankedTensorType(x.type)
    x_shape = x_type.shape
    w_type = ir.RankedTensorType(weight.type)
    w_shape = w_type.shape
    iv_type = ir.RankedTensorType(invvar.type)

    n2 = reduce(lambda x, y: x * y, w_shape)
    n1 = reduce(lambda x, y: x * y, x_shape) // n2

    part_grad_shape = ctx.avals_out[-1].shape

    opaque = gpu_ops.create_rms_norm_descriptor(
        n1,
        n2,
        eps,
        element_type_to_descriptor_type_mapping(x_type.element_type),
        element_type_to_descriptor_type_mapping(w_type.element_type),
        part_grad_shape[0],
    )
    out = custom_call(
        b"rms_backward_affine",
        result_types=[
            ir.RankedTensorType.get(x_shape, x_type.element_type),
            ir.RankedTensorType.get(w_shape, w_type.element_type),
            ir.RankedTensorType.get(part_grad_shape, iv_type.element_type),
        ],
        operands=[grad_output, invvar, x, weight],
        backend_config=opaque,
        operand_layouts=default_layouts(x_shape, (n1,), x_shape, w_shape),
        result_layouts=default_layouts(x_shape, w_shape, part_grad_shape),
    ).results
    return out

mlir.register_lowering(
    _rms_norm_bwd_p,
    _rms_norm_bwd_cuda_lowering,
    platform="gpu",
) 
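The helper default_layouts produces row-major layouts in XLA's minor-to-major order, which is what the {2,1,0} and {1,0} annotations in this tutorial's HLO dumps denote. The lowering also collapses the shapes into n1 rows of n2 elements each. The pure-Python check below evaluates both for the bf16[32,512,512] input and the bf16[512,512] weight of the multi-device test. The function rows_and_cols is a hypothetical name for the computation the lowering performs inline.

```python
from functools import reduce
from operator import mul

def default_layouts(*shapes):
    # Minor-to-major dimension order: the last dimension varies fastest (row-major).
    return [range(len(shape) - 1, -1, -1) for shape in shapes]

def rows_and_cols(x_shape, w_shape):
    # n2: elements normalized together (size of weight); n1: number of such rows.
    n2 = reduce(mul, w_shape)
    n1 = reduce(mul, x_shape) // n2
    return n1, n2

x_shape, w_shape = (32, 512, 512), (512, 512)
layouts = [list(layout) for layout in default_layouts(x_shape, w_shape)]
n1, n2 = rows_and_cols(x_shape, w_shape)
print(layouts, n1, n2)
```

Here n1 = 32 also matches the f32[32] invvar output of the forward custom call.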

Let's test it

per_core_batch_size=4
seq_len=512
emb_dim=512
x = jax.random.normal(
    jax.random.PRNGKey(0),
    shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim),
    dtype=jnp.bfloat16,
)
norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16) 

Test the forward function

out = rms_norm_fwd(x, weight) 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [5], line 1
----> 1 out = rms_norm_fwd(x, weight)

...

NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented 

Abstract evaluation

The test above failed with NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented. Why did the test fail, and what does this mean?

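As part of execution, JAX performs abstract evaluation. JAX has no knowledge of the new primitives, so it does not know how to compute their output shapes and data types, and therefore cannot evaluate these operations abstractly.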

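We need to provide an abstract evaluation function for each primitive. These functions compute the shape and data type of the outputs, but not the actual values of the operation.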

These functions are passed to the .def_abstract_eval method to register them with the corresponding primitives.

See How JAX primitives work for more information on abstract evaluation.

from functools import reduce
from operator import mul

from jax.core import ShapedArray

def _rms_norm_fwd_abstract(x, weight, eps):
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    iv_dtype = dtypes.canonicalize_dtype(x.dtype)
    if iv_dtype in [jnp.float16, jnp.bfloat16]:
        iv_dtype = jnp.float32
    n2 = reduce(mul, weight.shape)
    n1 = reduce(mul, x.shape) // n2
    return (
        ShapedArray(x.shape, w_dtype, named_shape=x.named_shape),  # output
        ShapedArray((n1,), iv_dtype, named_shape=x.named_shape),  # invvar
    )

_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)

def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
    iv_dtype = dtypes.canonicalize_dtype(invvar.dtype)
    w_dtype = dtypes.canonicalize_dtype(weight.dtype)
    x_dtype = dtypes.canonicalize_dtype(x.dtype)
    n2 = reduce(lambda x, y: x * y, weight.shape)
    n1 = reduce(lambda x, y: x * y, x.shape) // n2
    part_grad_shape = (16, n2)
    assert dtypes.canonicalize_dtype(grad_output.dtype) == w_dtype
    assert grad_output.shape == x.shape
    assert invvar.shape == (n1,)
    assert (
        iv_dtype == jnp.float32 if x_dtype in [jnp.float16, jnp.bfloat16] else x_dtype
    )
    assert grad_output.named_shape == x.named_shape
    # Fall back to x's named shape when weight has none (fixes a NameError in the original).
    weight_named_shape = (
        weight.named_shape if weight.named_shape else x.named_shape
    )
    return (
        ShapedArray(
            x.shape, x_dtype, named_shape=x.named_shape
        ),  # grad input
        ShapedArray(
            weight.shape, w_dtype, named_shape=weight_named_shape
        ),  # grad weight
        ShapedArray(
            part_grad_shape, iv_dtype, named_shape=weight_named_shape
        ),  # part grad
    )

_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract) 
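A pure-jnp version of the operation offers a cross-check for these rules. jax.eval_shape runs abstract evaluation only, returning shapes and dtypes without computing any values. The sketch below uses rms_norm_reference, a hypothetical stand-in written with jax.numpy rather than the custom kernel. For bf16 inputs it should report the same output shape and dtype, and the same (n1,) float32 invvar, as _rms_norm_fwd_abstract.

```python
import jax
import jax.numpy as jnp

def rms_norm_reference(x, weight, eps=1e-5):
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    x32 = x.astype(jnp.float32)
    invvar = jax.lax.rsqrt(jnp.mean(x32 * x32, axis=axes, keepdims=True) + eps)
    out = x32 * invvar * weight.astype(jnp.float32)
    # Output in weight's dtype; invvar kept in float32, one entry per row
    return out.astype(weight.dtype), invvar.reshape(-1)

x_spec = jax.ShapeDtypeStruct((32, 512, 512), jnp.bfloat16)
w_spec = jax.ShapeDtypeStruct((512, 512), jnp.bfloat16)

# No arrays are allocated and nothing is computed: only shapes/dtypes are propagated.
out_spec, invvar_spec = jax.eval_shape(rms_norm_reference, x_spec, w_spec)
print(out_spec, invvar_spec)
```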

Let's test it again

Test the forward function

out = rms_norm_fwd(x, weight) 

Test the backward function

现在让我们使用jax.gradjtu.check_grads测试反向操作。

def loss(x, weight):
    predictions = rms_norm_fwd(x, weight)
    return -jnp.mean(predictions**2)

loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1) 
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [8], line 7
      3     return -jnp.mean(predictions**2)
      6 loss_grad = jax.grad(loss)
----> 7 out = loss_grad(x, weight)

...

NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented 

Differentiation rule

The backward operation failed with NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented. This means that, although we have defined rms_norm_fwd and rms_norm_bwd, JAX does not know the relationship between them.

We can use jax.custom_vjp and its conventions to teach JAX that rms_norm_bwd is the backward operation of rms_norm_fwd. As a first step, we need to refine the definitions of rms_norm_fwd and rms_norm_bwd.

# rms_norm_fwd was previously defined as
#
# def rms_norm_fwd(x, weight, eps=1e-05):
#     output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
#     return output
#
def rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)

# rms_norm_bwd was previously defined as
#
# def rms_norm_bwd(g, invvar, x, weight, eps):
#     grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
#         g, invvar, x, weight, eps=eps
#     )
#     return grad_input, grad_weight
#
def rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
        g, invvar, x, weight, eps=eps
    )
    return grad_input, grad_weight 

rms_norm_fwd now returns an extra output, (invvar, x, weight), as residual data, and rms_norm_bwd takes eps, res, and g as parameters.

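Once jax.custom_vjp establishes the relationship between rms_norm_fwd and rms_norm_bwd, JAX ensures that the residual data from rms_norm_fwd is passed to rms_norm_bwd as res for the backward operation. Non-differentiable parameters such as eps are passed to the backward operation before the residual data. That is why eps precedes res in the parameter list of rms_norm_bwd.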

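Because rms_norm_fwd now returns residual data that the plain RMS normalization operation does not need, we define a wrapper around it, rms_norm, that simply calls rms_norm_fwd and returns only output. Note that rms_norm is decorated with @partial(jax.custom_vjp, nondiff_argnums=(2,)), and that we pass rms_norm_fwd and rms_norm_bwd to rms_norm.defvjp. This teaches JAX to use rms_norm_fwd for the forward operation and rms_norm_bwd for the backward operation when differentiating rms_norm.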

See Custom derivative rules for JAX-transformable Python functions for more information on defining custom derivative rules with jax.custom_vjp.

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps=eps)
    return output

rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) 
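The argument-order convention is easier to see on a much smaller function. In the sketch below, safe_sqrt and its rules are hypothetical and not part of this tutorial. eps is marked non-differentiable with nondiff_argnums, so:

- the forward rule receives eps in its original position;
- the backward rule receives eps first, before the residuals and the cotangent;
- the backward rule returns one cotangent per differentiable argument.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def safe_sqrt(x, eps=1e-6):
    return jnp.sqrt(x + eps)

def safe_sqrt_fwd(x, eps):
    # Same signature as safe_sqrt; returns (output, residuals)
    y = jnp.sqrt(x + eps)
    return y, (y,)

def safe_sqrt_bwd(eps, res, g):
    # Non-differentiable eps comes first, then residuals, then the cotangent g
    (y,) = res
    return (g * 0.5 / y,)  # one entry per differentiable argument (x only)

safe_sqrt.defvjp(safe_sqrt_fwd, safe_sqrt_bwd)

print(jax.grad(safe_sqrt)(4.0))
```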

With this refinement, the backward operation test works after one modification: loss now calls rms_norm instead of rms_norm_fwd.

def loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)

loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1) 

Let's test it on multiple devices

We use jax.experimental.pjit.pjit for parallel execution on multiple devices, and produce reference values with sequential execution on a single device.

Test the forward function

Let's first test the forward operation on multiple devices. We create a simple 1D mesh and shard x across all devices.

from jax.sharding import Mesh, PartitionSpec
from jax.experimental.pjit import pjit

mesh = Mesh(jax.local_devices(), ("x",))
ref = rms_norm(x, weight)
pjitted = pjit(
    rms_norm,
    # Shard x by batch dimension and replicate weight on all devices.
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    # Shard the output by batch dimension.
    out_shardings=PartitionSpec("x", None, None),
)

with mesh:
    print(pjitted.lower(x, weight).compile().runtime_executable().hlo_modules()[0].to_string())
    out = pjitted(x, weight)

jnp.allclose(ref, out, atol=1e-5, rtol=1e-5) 
HloModule pjit_rms_norm, entry_computation_layout={(bf16[4,512,512]{2,1,0},bf16[512,512]{1,0})->bf16[4,512,512]{2,1,0}}

%fused_computation (param_1: bf16[32,512,512], param_1.3: u32[]) -> bf16[4,512,512] {
  %param_1 = bf16[32,512,512]{2,1,0} parameter(0)
  %param_1.3 = u32[] parameter(1)
  %convert.2 = s32[] convert(u32[] %param_1.3), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_9 = s32[] constant(4), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %multiply.3 = s32[] multiply(s32[] %convert.2, s32[] %constant_9), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %constant_8 = s32[] constant(0), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %dynamic-slice.2 = bf16[4,512,512]{2,1,0} dynamic-slice(bf16[32,512,512]{2,1,0} %param_1, s32[] %multiply.3, s32[] %constant_8, s32[] %constant_8), dynamic_slice_sizes={4,512,512}, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
}

ENTRY %main.7_spmd (param: bf16[4,512,512], param.1: bf16[512,512]) -> bf16[4,512,512] {
  %param = bf16[4,512,512]{2,1,0} parameter(0), sharding={devices=[8,1,1]0,1,2,3,4,5,6,7}
  %all-gather = bf16[32,512,512]{2,1,0} all-gather(bf16[4,512,512]{2,1,0} %param), channel_id=1, replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, use_global_device_ids=true, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %param.1 = bf16[512,512]{1,0} parameter(1), sharding={replicated}
  %custom-call.0 = (bf16[32,512,512]{2,1,0}, f32[32]{0}) custom-call(bf16[32,512,512]{2,1,0} %all-gather, bf16[512,512]{1,0} %param.1), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={bf16[32,512,512]{2,1,0}, bf16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}, backend_config=" \000\000\000\000\000\004\000\361h\343\210\265\370\344>\000\000\000\000\000\000\000\000\000\000\000\000\255\177\000\000"
  %get-tuple-element = bf16[32,512,512]{2,1,0} get-tuple-element((bf16[32,512,512]{2,1,0}, f32[32]{0}) %custom-call.0), index=0, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  %partition-id = u32[] partition-id(), metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
  ROOT %fusion = bf16[4,512,512]{2,1,0} fusion(bf16[32,512,512]{2,1,0} %get-tuple-element, u32[] %partition-id), kind=kLoop, calls=%fused_computation, metadata={op_name="pjit(rms_norm)/jit(main)/rms_norm_fwd[eps=1e-05]" source_file="/tmp/ipykernel_25235/3343076723.py" source_line=8}
} 
True 
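The per-device block shapes in the HLO above follow directly from the PartitionSpec. With 8 devices, sharding a batch of 32 on its first axis leaves bf16[4,512,512] on each device. The sketch below computes the same split for whatever devices are available, using NamedSharding.shard_shape. The smaller array sizes are illustrative.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
n = devices.size

global_shape = (n * 4, 16, 16)  # batch divisible by the number of devices
x_sharding = NamedSharding(mesh, PartitionSpec("x", None, None))  # shard batch
w_sharding = NamedSharding(mesh, PartitionSpec(None, None))       # replicate

x_block = x_sharding.shard_shape(global_shape)
w_block = w_sharding.shard_shape((16, 16))
print(n, x_block, w_block)
```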

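For the forward operation, the values are computed correctly. However, the generated HLO module contains an all-gather operation that replicates x on every device, which incurs a large communication overhead.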

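Because XLA does not know enough about the custom function to shard its input tensors, it replicates them before making the custom call so that the result is correct.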

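To avoid this duplication, we can either:

  • use custom_partitioning, which makes the operation behave like all native JAX operations (but is more complicated), or

  • use manual sharding.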

This example demonstrates the use of custom_partitioning.
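For comparison, the manual-sharding route can be sketched with shard_map on a pure-jnp stand-in for the kernel (rms_norm_rows below is hypothetical). Each batch entry is normalized independently, so every device can process its own batch shard without any communication. The import fallback assumes jax.shard_map on newer releases and jax.experimental.shard_map on older ones.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                         # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

def rms_norm_rows(x, weight, eps=1e-5):
    # Normalize each batch entry over its trailing (seq, emb) dimensions.
    ms = jnp.mean(x * x, axis=(1, 2), keepdims=True)
    return x * jax.lax.rsqrt(ms + eps) * weight

mesh = Mesh(np.array(jax.devices()), ("x",))
n = len(jax.devices())

x = jax.random.normal(jax.random.PRNGKey(0), (n * 2, 8, 16))
weight = jnp.ones((8, 16))

sharded = jax.jit(shard_map(
    rms_norm_rows,
    mesh=mesh,
    in_specs=(P("x", None, None), P(None, None)),  # shard batch, replicate weight
    out_specs=P("x", None, None),                  # output stays batch-sharded
))
out = sharded(x, weight)
ref = rms_norm_rows(x, weight)
```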

Shard the forward function with custom_partitioning

First, create a helper function that handles all the required JAX/XLA callback registrations.

# Extra imports used by register_primitive (module paths from the JAX version this tutorial targets).
from jax._src import dispatch
from jax.interpreters import batching
from jax.experimental.custom_partitioning import custom_partitioning

def register_primitive(cls):
    """
    Register a JAX primitive.

    The order of calls. Each operation is composed of two primitives: Inner and Outer.

    Inner: only the basics to wrap the custom_call itself.
    - impl to XLA custom_call in C.
    - abstract to know the static shapes
    - lower to StableHLO XLA custom_call.
    Outer: mostly all the rest:
    - impl: Bind to the inner primitive. Not used for real computation, but only for tracing. So we only need to bind.
    - abstract: same
    - lower to StableHLO custom_p. (XLA will call the python callback from it)
    - custom_p
    - vmap: could be added here.
    VJP is based on Outer, but not handled in this function.
    """

    def name_of_wrapper_p():
        return cls.name + "_wrapper"

    inner_p = core.Primitive(cls.name)
    dispatch.prim_requires_devices_during_lowering.add(inner_p)
    inner_p.multiple_results = cls.multiple_results
    inner_p.def_impl(partial(xla.apply_primitive, inner_p))
    inner_p.def_abstract_eval(cls.abstract)
    mlir.register_lowering(inner_p, cls.lowering, platform='cuda')
    cls.inner_primitive = inner_p

    outer_p = core.Primitive(name_of_wrapper_p())
    dispatch.prim_requires_devices_during_lowering.add(outer_p)
    outer_p.multiple_results = cls.multiple_results
    outer_p.def_impl(cls.impl)
    outer_p.def_abstract_eval(cls.abstract)
    batching.primitive_batchers[outer_p] = cls.batcher
    outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args)
    outer_p_lower.def_partition(infer_sharding_from_operands=cls.infer_sharding_from_operands,
                                partition=cls.partition)
    mlir.register_lowering(outer_p,
                           mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results))
    cls.outer_primitive = outer_p
... 

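We define two JAX primitives. The inner primitive maps to the real kernel we want to wrap in JAX. The outer primitive is used with the custom partitioning registration and for the gradient. (If you implement the interface to support vmap, it will also live on the outer primitive.)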

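The JAX custom_partitioning implementation is a callback from XLA to Python during XLA's sharding logic. XLA sharding happens in two phases: a sharding propagation phase and a partition phase. In the propagation phase, XLA plans the shardings to be created. In the partition phase, it creates the sharded graph. For XLA to shard our custom operation, we must define two extra functions, infer_sharding_from_operands() and partition(). They are used in the first and second phase, respectively.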

The infer_sharding_from_operands() function must do what its name says: infer the output sharding from the input sharding.

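The partition() function does several things:

  • It tells which input shardings are expected. XLA will reshard if necessary.

  • It tells the final version of the output sharding.

  • It gives a function that creates the new instruction from the sharded inputs.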

See the code comments for more explanation:

from typing import Tuple

from jax.sharding import NamedSharding, PartitionSpec

class RmsNormFwdClass:
    name = "rms_forward_affine_mixed_dtype"
    multiple_results = True
    impl_static_args = (2,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims of all inputs to be replicated except the
        # first dim of x that will be kept as is.
        # This is because the implementation can only be sharded on the batch dimensions.

        x_spec = arg_infos[0].sharding.spec
        # None means that we replicate on that dimension.
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        return (output_sharding, invvar_sharding)

    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        x_info, weight_info = arg_infos
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        x_spec = arg_infos[0].sharding.spec
        # We only support sharding on the batch dimensions.
        # Force replication on all other dimensions with None.
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0]))
        output_shardings = (arg_shardings[0], invvar_sharding)
        # Sharded_impl only accepts positional arguments
        # And they should be Jax traceable variables
        impl = partial(RmsNormFwdClass.impl, eps=eps)

        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormFwdClass) 
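The core of the forward rule is plain PartitionSpec manipulation, so it can be exercised without a mesh or devices. The hypothetical helper below restates, as specs, what infer_sharding_from_operands() and partition() above compute. It keeps whatever mesh axis shards the batch dimension of x, replicates everything else, and gives invvar the batch sharding only.

```python
from jax.sharding import PartitionSpec as P

def rms_norm_fwd_specs(x_spec):
    """Hypothetical restatement of the forward sharding rule on PartitionSpecs."""
    batch_axis = x_spec[0] if len(x_spec) > 0 else None
    arg_specs = (P(batch_axis, None, None), P(None, None))  # x, weight
    out_specs = (P(batch_axis, None, None), P(batch_axis))  # output, invvar
    return arg_specs, out_specs

# Even if the sequence dimension is also sharded, only the batch sharding survives;
# XLA would reshard x to match arg_specs before the custom call.
args, outs = rms_norm_fwd_specs(P("x", "y", None))
print(args, outs)
```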

Next, we define the primitive for the RMSNorm backward pass.

Shard the backward function with custom_partitioning

class RmsNormBwdClass:
    name = "rms_norm_bwd"
    multiple_results = True
    impl_static_args = (4,)    # eps
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def infer_sharding_from_operands(eps : float, mesh : jax.sharding.Mesh,
                                     arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                                     result_infos : Tuple[jax._src.core.ShapedArray]):
        del eps, result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2
        # partition() will force all dims to be replicated except the batch dimension.
        x_spec = x_info.sharding.spec
        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
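        # part_grad is rank 2 like the weight gradient, so it gets the same replicated
        # sharding here as in partition() below.
        return (output_sharding, invvar_sharding, invvar_sharding)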

    @staticmethod
    def partition(eps : float, mesh : jax.sharding.Mesh,
                  arg_infos : Tuple[jax._src.api.ShapeDtypeStruct],
                  result_infos : Tuple[jax._src.api.ShapeDtypeStruct]):
        del result_infos  # Not needed for this example.
        g_info, invvar_info, x_info, weight_info = arg_infos
        assert len(g_info.shape) == 3
        assert len(invvar_info.shape) == 1
        assert len(x_info.shape) == 3
        assert len(weight_info.shape) == 2

        # We only support sharding on the batch dimensions.
        # Force replication on all other dimensions with None.
        # Also force gx, x and invvar to have the same batch sharding/replication.
        x_spec = x_info.sharding.spec
        arg_shardings = (NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0],)),
                         NamedSharding(mesh, PartitionSpec(x_spec[0], None, None)),
                         NamedSharding(mesh, PartitionSpec(None, None)))

        output_sharding = NamedSharding(mesh, PartitionSpec(x_spec[0], None, None))
        invvar_sharding = NamedSharding(mesh, PartitionSpec(None, None))
        output_shardings = (output_sharding, invvar_sharding, invvar_sharding)

        # Sharded_impl only accepts positional arguments
        # And they should be Jax traceable variables
        def impl(g, invvar, x, weight):
            grad_input, grad_weight, part_grad = _rms_norm_bwd_p.bind(
                g, invvar, x, weight, eps=eps
            )
            # We need to sum the weight gradient from all partitions.
            global_weight = grad_weight
            if x_spec[0]:
                global_weight = jax.lax.psum(grad_weight, x_spec[0])
            return grad_input, global_weight, part_grad
        return mesh, impl, output_shardings, arg_shardings
register_primitive(RmsNormBwdClass) 
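The psum in the backward impl reflects a simple identity. The weight gradient is a sum over batch rows, so each batch shard can only produce a partial sum, and those partial sums must be added across the mesh axis. The sketch below checks the identity with shard_map on a toy gradient: each row contributes x[b], so the full gradient is x.sum(axis=0). The helper name and array sizes are hypothetical. The import fallback assumes jax.shard_map on newer releases and jax.experimental.shard_map on older ones.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map                         # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX releases

mesh = Mesh(np.array(jax.devices()), ("x",))
n = len(jax.devices())
x = jax.random.normal(jax.random.PRNGKey(1), (n * 3, 5))

def weight_grad_block(x_block):
    partial_grad = jnp.sum(x_block, axis=0)  # contribution of this shard's rows only
    return jax.lax.psum(partial_grad, "x")   # add partial sums from every shard

weight_grad = jax.jit(shard_map(
    weight_grad_block,
    mesh=mesh,
    in_specs=(P("x", None),),  # each device sees a slice of the batch
    out_specs=P(),             # after psum the result is identical on every device
))(x)
```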

Connect the forward and backward primitives with the same kind of custom_vjp rule as before:

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def custom_p_rms_norm(x, weight, eps=1e-05):
    output, _ = custom_p_rms_norm_fwd(x, weight, eps=eps)
    return output

def custom_p_rms_norm_fwd(x, weight, eps=1e-05):
    output, invvar = RmsNormFwdClass.outer_primitive.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)

def custom_p_rms_norm_bwd(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight, part_grad = RmsNormBwdClass.outer_primitive.bind(
        g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight

custom_p_rms_norm.defvjp(custom_p_rms_norm_fwd, custom_p_rms_norm_bwd) 

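With that, our custom RMS norm primitive with custom partitioning is fully defined. To check correctness, we define the following loss functions: ref_loss is the reference to compare against, while custom_p_loss uses our new primitive that implements custom_partitioning.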

def ref_loss(x, weight):
    predictions = rms_norm(x, weight)
    return -jnp.mean(predictions**2)

ref = jax.grad(ref_loss, argnums=(0, 1))(x, weight)

def custom_p_loss(x, weight):
    predictions = custom_p_rms_norm(x, weight)
    return -jnp.mean(predictions**2) 

Check correctness

with Mesh(jax.local_devices(), ("x",)):
    def run_and_verify(loss):
        pjitted = pjit(
            jax.grad(loss, argnums=(0, 1)),
            # Shard x by batch dimension and replicate weight on all devices.
            in_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
            # Shard the output by batch dimension and replicate weight grad on all devices.
            out_shardings=(
                PartitionSpec("x", None, None),
                PartitionSpec(None, None),
            ),
        )
        hlo = pjitted.lower(x, weight).compile().as_text()
        out = pjitted(x, weight)
        print(hlo)
        assert "all-reduce-done" in hlo, "The gradient will produce wrong value!"
        if "all-gather-start" in hlo:
            print("NOT OPTIMIZED, ALL_GATHER in the graph!")
        return out

    custom_p_out = run_and_verify(custom_p_loss)

for r, o in zip(ref_out, custom_p_out):
    print(jnp.allclose(r, o, atol=1e-6, rtol=1e-6)) 
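The check above uses pjit. In recent JAX versions, jax.jit accepts in_shardings and out_shardings directly when they are given as NamedSharding objects. The sketch below reproduces the same sharding setup (batch axis sharded, weight replicated) with jax.jit. It uses a hypothetical pure-JAX rms_norm in place of the custom primitive, so it runs on any backend. With a single device, the compiled HLO contains no collectives at all. With several devices, you would expect an all-reduce for the weight gradient, although the exact op names in the HLO text can vary by backend.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def rms_norm(x, weight, eps=1e-05):
    # Hypothetical stand-in for the custom primitive: normalize over the trailing axis.
    invvar = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * invvar * weight


def loss(x, weight):
    return -jnp.mean(rms_norm(x, weight) ** 2)


devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
x_sharding = NamedSharding(mesh, PartitionSpec("x", None))  # shard the batch axis
w_sharding = NamedSharding(mesh, PartitionSpec())  # replicate weight on every device

grad_fn = jax.jit(
    jax.grad(loss, argnums=(0, 1)),
    in_shardings=(x_sharding, w_sharding),
    out_shardings=(x_sharding, w_sharding),
)

n = 2 * devices.size  # batch size must be divisible by the number of devices
x = jax.random.normal(jax.random.PRNGKey(0), (n, 16), dtype=jnp.float32)
weight = jnp.ones((16,), dtype=jnp.float32)

hlo = grad_fn.lower(x, weight).compile().as_text()
print("all-reduce in HLO:", "all-reduce" in hlo, "| all-gather in HLO:", "all-gather" in hlo)

grad_x, grad_w = grad_fn(x, weight)
ref_x, ref_w = jax.grad(loss, argnums=(0, 1))(x, weight)
print(jnp.allclose(grad_x, ref_x, atol=1e-6), jnp.allclose(grad_w, ref_w, atol=1e-6))
```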
HloModule pjit_custom_p_loss, is_scheduled=true, entry_computation_layout={(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})->(f16[4,512,512]{2,1,0}, f16[512,512]{1,0})}, allow_spmd_sharding_propagation_to_parameters={false,false}, allow_spmd_sharding_propagation_to_output={false,false}, num_partitions=4, frontend_attributes={fingerprint_before_lhs="d7b9bc40de002332dd665ff2ab537b76"}

%fused_multiply (param_0: f16[4,512,512]) -> f16[4,512,512] {
  %param_0 = f16[4,512,512]{2,1,0} parameter(0)
  %constant_4_1 = f16[] constant(-4.7684e-07)
  %broadcast.8.1 = f16[4,512,512]{2,1,0} broadcast(f16[] %constant_4_1), dimensions={}, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  ROOT %multiply.5.1 = f16[4,512,512]{2,1,0} multiply(f16[4,512,512]{2,1,0} %param_0, f16[4,512,512]{2,1,0} %broadcast.8.1), metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
}

%region_0.9._custom_call_lowering_rule (Arg_0.10.0: f16[], Arg_1.11.0: f16[]) -> f16[] {
  %Arg_1.11.0 = f16[] parameter(1)
  %Arg_0.10.0 = f16[] parameter(0)
  ROOT %add.2.0 = f16[] add(f16[] %Arg_0.10.0, f16[] %Arg_1.11.0), metadata={op_name="jit(main)/add" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=433}
}

ENTRY %main.23_spmd (param.2: f16[4,512,512], param.1.0: f16[512,512]) -> (f16[4,512,512], f16[512,512]) {
  %param.1.0 = f16[512,512]{1,0} parameter(1), sharding={replicated}
  %param.2 = f16[4,512,512]{2,1,0} parameter(0), sharding={devices=[4,1,1]<=[4]}
  %custom-call.3.0 = (f16[4,512,512]{2,1,0}, f32[4]{0}) custom-call(f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_forward_affine_mixed_dtype", operand_layout_constraints={f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\000\000\000\000$V\000\000"
  %get-tuple-element.14 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %loop_multiply_fusion = f16[4,512,512]{2,1,0} fusion(f16[4,512,512]{2,1,0} %get-tuple-element.14), kind=kLoop, calls=%fused_multiply, metadata={op_name="pjit(custom_p_loss)/jit(main)/mul" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=484}
  %get-tuple-element.1.0 = f32[4]{0} get-tuple-element((f16[4,512,512]{2,1,0}, f32[4]{0}) %custom-call.3.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormFwdClass.partition at 0x7ff99e3980d0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormFwdClass.infer_sharding_from_operands at 0x7ff99e398040> decode_shardings=True in_tree=PyTreeDef((*, *)) out_tree=PyTreeDef((*, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=440}
  %custom-call.5.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) custom-call(f16[4,512,512]{2,1,0} %loop_multiply_fusion, f32[4]{0} %get-tuple-element.1.0, f16[4,512,512]{2,1,0} %param.2, f16[512,512]{1,0} %param.1.0), custom_call_target="rms_backward_affine", operand_layout_constraints={f16[4,512,512]{2,1,0}, f32[4]{0}, f16[4,512,512]{2,1,0}, f16[512,512]{1,0}}, api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config="\004\000\000\000\000\000\004\000\361h\343\210\265\370\344>\001\000\000\000\001\000\000\000\020\000\000\000$V\000\000"
  %get-tuple-element.7.0 = f16[512,512]{1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=1, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %all-reduce-start = f16[512,512]{1,0} all-reduce-start(f16[512,512]{1,0} %get-tuple-element.7.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%region_0.9._custom_call_lowering_rule, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"collective_backend_config":{"is_sync":true,"no_parallel_custom_call":false}}
  %all-reduce-done = f16[512,512]{1,0} all-reduce-done(f16[512,512]{1,0} %all-reduce-start), metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  %get-tuple-element.12.0 = f16[4,512,512]{2,1,0} get-tuple-element((f16[4,512,512]{2,1,0}, f16[512,512]{1,0}, f32[16,262144]{1,0}) %custom-call.5.0), index=0, metadata={op_name="pjit(custom_p_loss)/jit(main)/custom_partitioning[partition=<function RmsNormBwdClass.partition at 0x7ff99e3985e0> propagate_user_sharding=None infer_sharding_from_operands=<function RmsNormBwdClass.infer_sharding_from_operands at 0x7ff99e398550> decode_shardings=True in_tree=PyTreeDef((*, *, *, *)) out_tree=PyTreeDef((*, *, *)) static_args=[1e-05]]" source_file="/opt/jax/docs/Custom_Operation_for_GPUs.py" source_line=483}
  ROOT %tuple.1.0 = (f16[4,512,512]{2,1,0}, f16[512,512]{1,0}) tuple(f16[4,512,512]{2,1,0} %get-tuple-element.12.0, f16[512,512]{1,0} %all-reduce-done)
} 
True
True 

The HLO now contains no all-gather, so the requested sharding is respected. The gradient is accumulated only through an all-reduce.
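The all-reduce is needed because weight is replicated while x is sharded along the batch axis. The weight gradient is a sum over all batch rows, so each device can only compute a partial sum over its local rows, and the all-reduce adds those partial sums together. The single-device sketch below simulates this arithmetic. It splits the batch into hypothetical "shards" with jnp.split, computes each shard's contribution to the global loss, and checks that summing the per-shard weight gradients matches the full-batch gradient. rms_norm is again a hypothetical pure-JAX stand-in.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, weight, eps=1e-05):
    invvar = jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)
    return x * invvar * weight


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16), dtype=jnp.float32)
weight = jax.random.normal(jax.random.PRNGKey(1), (16,), dtype=jnp.float32)
total = x.size  # the global loss averages over every element of the full batch


def global_loss(x, weight):
    return -jnp.mean(rms_norm(x, weight) ** 2)


def shard_contribution(x_shard, weight):
    # This shard's share of the same global loss (divided by the global element count).
    return -jnp.sum(rms_norm(x_shard, weight) ** 2) / total


full_grad_w = jax.grad(global_loss, argnums=1)(x, weight)

num_shards = 4
shards = jnp.split(x, num_shards, axis=0)  # batch sharding, like PartitionSpec("x", ...)
partial_grads = [jax.grad(shard_contribution, argnums=1)(s, weight) for s in shards]
all_reduced = sum(partial_grads)  # what the all-reduce computes across devices
print(jnp.allclose(all_reduced, full_grad_w, atol=1e-6))
```

The input gradient needs no such step, because each row of grad_x depends only on the matching row of x and can stay sharded.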

Putting it all together

The complete definition of the primitive using custom_partitioning is in Custom_Operation_for_GPUs.py. The corresponding C++ code that defines the Python bindings is listed below:

gpu_ops code listing

gpu_ops/kernel_helpers.h

gpu_ops/kernels.h

gpu_ops/pybind11_kernel_helpers.h

gpu_ops/gpu_ops.cpp

gpu_ops/rms_norm_kernels.cu

From: https://www.cnblogs.com/apachecn/p/18260392
