首页 > 其他分享 >JAX-中文文档-三-

JAX-中文文档-三-

时间:2024-06-21 14:25:31浏览次数:24  
标签:中文 jnp JAX jax print 文档 debug checkify

JAX 中文文档(三)

原文:jax.readthedocs.io/en/latest/

有状态计算

原文:jax.readthedocs.io/en/latest/stateful-computations.html

JAX 的转换(如jit()vmap()grad())要求它们包装的函数是纯粹的:即,函数的输出仅依赖于输入,并且没有副作用,比如更新全局状态。您可以在JAX sharp bits: Pure functions中找到关于这一点的讨论。

在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可以以多种形式存在。例如:

  • 模型参数,

  • 优化器状态,以及

  • BatchNorm这样的有状态层。

本节提供了如何在 JAX 程序中正确处理状态的一些建议。

一个简单的例子:计数器

让我们首先看一个简单的有状态程序:一个计数器。

import jax
import jax.numpy as jnp

class Counter:
  """A simple counter."""

  def __init__(self):
    self.n = 0

  def count(self) -> int:
  """Increments the counter and returns the new value."""
    self.n += 1
    return self.n

  def reset(self):
  """Resets the counter to zero."""
    self.n = 0

counter = Counter()

for _ in range(3):
  print(counter.count()) 
1
2
3 

计数器的n属性在连续调用count时维护计数器的状态。调用count的副作用是修改它。

假设我们想要快速计数,所以我们即时编译count方法。(在这个例子中,这实际上不会以任何方式加快速度,由于很多原因,但把它看作是模型参数更新的玩具模型,jit()确实产生了巨大的影响)。

counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  print(fast_count()) 
1
1
1 

哦不!我们的计数器不能工作了。这是因为

self.n += 1 

count中涉及副作用:它直接修改了输入的计数器,因此此函数不受jit支持。这样的副作用仅在首次跟踪函数时执行一次,后续调用将不会重复该副作用。那么,我们该如何修复它呢?

解决方案:显式状态

问题的一部分在于我们的计数器返回值不依赖于参数,这意味着编译输出中包含了一个常数。但它不应该是一个常数 - 它应该依赖于状态。那么,为什么我们不将状态作为一个参数呢?

CounterState = int

class CounterV2:

  def count(self, n: CounterState) -> tuple[int, CounterState]:
    # You could just return n+1, but here we separate its role as 
    # the output and as the counter state for didactic purposes.
    return n+1, n+1

  def reset(self) -> CounterState:
    return 0

counter = CounterV2()
state = counter.reset()

for _ in range(3):
  value, state = counter.count(state)
  print(value) 
1
2
3 

在这个Counter的新版本中,我们将n移动到count的参数中,并添加了另一个返回值,表示新的、更新的状态。现在,为了使用这个计数器,我们需要显式地跟踪状态。但作为回报,我们现在可以安全地使用jax.jit这个计数器:

state = counter.reset()
fast_count = jax.jit(counter.count)

for _ in range(3):
  value, state = fast_count(state)
  print(value) 
1
2
3 

一个一般的策略

我们可以将同样的过程应用到任何有状态方法中,将其转换为无状态方法。我们拿一个形式如下的类

class StatefulClass

  state: State

  def stateful_method(*args, **kwargs) -> Output: 

并将其转换为以下形式的类

class StatelessClass

  def stateless_method(state: State, *args, **kwargs) -> (Output, State): 

这是一个常见的函数式编程模式,本质上就是处理所有 JAX 程序中状态的方式。

注意,一旦我们按照这种方式重写它,类的必要性就不那么明显了。我们可以只保留stateless_method,因为类不再执行任何工作。这是因为,像我们刚刚应用的策略一样,面向对象编程(OOP)是帮助程序员理解程序状态的一种方式。

在我们的情况下,CounterV2 类只是一个名称空间,将所有使用 CounterState 的函数集中在一个位置。读者可以思考:将其保留为类是否有意义?

顺便说一句,你已经在 JAX 伪随机性 API 中看到了这种策略的示例,即 jax.random,在 :ref:pseudorandom-numbers 部分展示。与 Numpy 不同,后者使用隐式更新的有状态类管理随机状态,而 JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。

简单的工作示例:线性回归

现在让我们将这种策略应用到一个简单的机器学习模型上:通过梯度下降进行线性回归。

这里,我们只处理一种状态:模型参数。但通常情况下,你会看到许多种状态在 JAX 函数中交替出现,比如优化器状态、批归一化的层统计数据等。

需要仔细查看的函数是 update

from typing import NamedTuple

class Params(NamedTuple):
  weight: jnp.ndarray
  bias: jnp.ndarray

def init(rng) -> Params:
  """Returns the initial model params."""
  weights_key, bias_key = jax.random.split(rng)
  weight = jax.random.normal(weights_key, ())
  bias = jax.random.normal(bias_key, ())
  return Params(weight, bias)

def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  """Computes the least squares error of the model's predictions on x against y."""
  pred = params.weight * x + params.bias
  return jnp.mean((pred - y) ** 2)

LEARNING_RATE = 0.005

@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
  """Performs one SGD update step on params using the given data."""
  grad = jax.grad(loss)(params, x, y)

  # If we were using Adam or another stateful optimizer,
  # we would also do something like
  #
  #   updates, new_optimizer_state = optimizer(grad, optimizer_state)
  # 
  # and then use `updates` instead of `grad` to actually update the params.
  # (And we'd include `new_optimizer_state` in the output, naturally.)

  new_params = jax.tree_map(
      lambda param, g: param - g * LEARNING_RATE, params, grad)

  return new_params 

注意,我们手动地将参数输入和输出到更新函数中。

import matplotlib.pyplot as plt

rng = jax.random.key(42)

# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise

# Fit regression
params = init(rng)
for _ in range(1000):
  params = update(params, xs, ys)

plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend(); 
/tmp/ipykernel_2992/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
  new_params = jax.tree_map( 

_images/9d9c2471be1e4c9b8597cfff1433de0fe7ad2ef5b99cc6897ee153d7533d6521.png

进一步探讨

上述描述的策略是任何使用 jitvmapgrad 等转换的 JAX 程序必须处理状态的方式。

如果只涉及两个参数,手动处理参数似乎还可以接受,但如果是有数十层的神经网络呢?你可能已经开始担心两件事情:

  1. 我们是否应该手动初始化它们,基本上是在前向传播定义中已经编写过的内容?

  2. 我们是否应该手动处理所有这些事情?

处理这些细节可能有些棘手,但有一些库的示例可以为您解决这些问题。请参阅JAX 神经网络库获取一些示例。

进一步资源

用户指南

原文:jax.readthedocs.io/en/latest/user_guides.html

用户指南是对 JAX 内特定主题的深入探讨,随着您的 JAX 项目发展成为更大或部署代码库,这些主题变得更为相关。

调试和性能

  • 如何在 JAX 中思考

  • 对 JAX 程序进行性能分析

  • 设备内存分析

  • JAX 中的运行时值调试

  • GPU 性能技巧

  • 持久化编译缓存

开发

  • 理解 Jaxprs

  • JAX 中的外部回调

  • 类型提升语义

  • Pytrees

运行时间

  • 提前降低和编译

  • 导出和序列化

  • JAX 错误

  • 转移保护

自定义操作

  • Pallas:一种 JAX 内核语言

如何在 JAX 中思考

原文:jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

在 Colab 中打开 在 Kaggle 中打开

JAX 提供了一个简单而强大的 API 用于编写加速数值代码,但在 JAX 中有效工作有时需要额外考虑。本文档旨在帮助建立对 JAX 如何运行的基础理解,以便您更有效地使用它。

JAX vs. NumPy

关键概念:

  • JAX 提供了一个方便的类似于 NumPy 的接口。

  • 通过鸭子类型,JAX 数组通常可以直接替换 NumPy 数组。

  • 不像 NumPy 数组,JAX 数组总是不可变的。

NumPy 提供了一个众所周知且功能强大的 API 用于处理数值数据。为方便起见,JAX 提供了 jax.numpy,它紧密反映了 NumPy API,并为进入 JAX 提供了便捷的入口。几乎可以用 jax.numpy 完成 numpy 可以完成的任何事情:

import matplotlib.pyplot as plt
import numpy as np

x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np); 

../_images/a5186b574da824b332762dc6cb1c9829034a3a27f1f5dc7ba79658eb4c9e3715.png

import jax.numpy as jnp

x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp); 

../_images/4d0f0e6e14cb62a4e3105aeb1c1f8fe608774b82fd3073f26a9813de35f414eb.png

代码块除了用 jnp 替换 np 外,其余完全相同。正如我们所见,JAX 数组通常可以直接替换 NumPy 数组,用于诸如绘图等任务。

这些数组本身是作为不同的 Python 类型实现的:

type(x_np) 
numpy.ndarray 
type(x_jnp) 
jaxlib.xla_extension.ArrayImpl 

Python 的 鸭子类型 允许在许多地方可互换使用 JAX 数组和 NumPy 数组。

然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,一旦创建,其内容无法更改。

这里有一个在 NumPy 中突变数组的例子:

# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x) 
[10  1  2  3  4  5  6  7  8  9] 

在 JAX 中,等效操作会导致错误,因为 JAX 数组是不可变的:

%xmode minimal 
Exception reporting mode: Minimal 
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10 
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html 

对于更新单个元素,JAX 提供了一个 索引更新语法,返回一个更新后的副本:

y = x.at[0].set(10)
print(x)
print(y) 
[0 1 2 3 4 5 6 7 8 9]
[10  1  2  3  4  5  6  7  8  9] 

NumPy、lax 和 XLA:JAX API 层次结构

关键概念:

  • jax.numpy 是一个提供熟悉接口的高级包装器。

  • jax.lax 是一个更严格且通常更强大的低级 API。

  • 所有 JAX 操作都是基于 XLA – 加速线性代数编译器中的操作实现的。

如果您查看 jax.numpy 的源代码,您会看到所有操作最终都是以 jax.lax 中定义的函数形式表达的。您可以将 jax.lax 视为更严格但通常更强大的 API,用于处理多维数组。

例如,虽然jax.numpy将隐式促进参数以允许不同数据类型之间的操作,但jax.lax不会:

import jax.numpy as jnp
jnp.add(1, 1.0)  # jax.numpy API implicitly promotes mixed types. 
Array(2., dtype=float32, weak_type=True) 
from jax import lax
lax.add(1, 1.0)  # jax.lax API requires explicit type promotion. 
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))): op requires the same element type for all operands and results

The above exception was the direct cause of the following exception:

ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results
		at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module. 

如果直接使用jax.lax,在这种情况下你将需要显式地进行类型提升:

lax.add(jnp.float32(1), 1.0) 
Array(2., dtype=float32) 

除了这种严格性外,jax.lax还提供了一些比 NumPy 支持的更一般操作更高效的 API。

例如,考虑一个 1D 卷积,在 NumPy 中可以这样表达:

x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y) 
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32) 

在幕后,这个 NumPy 操作被转换为由lax.conv_general_dilated实现的更通用的卷积:

from jax import lax
result = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float),  # note: explicit promotion
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)])  # equivalent of padding='full' in NumPy
result[0, 0] 
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32) 

这是一种批处理卷积操作,专为深度神经网络中经常使用的卷积类型设计,需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多细节,请参见Convolutions in JAX)。

从本质上讲,所有jax.lax操作都是 XLA 中操作的 Python 包装器;例如,在这里,卷积实现由XLA:ConvWithGeneralPadding提供。每个 JAX 操作最终都是基于这些基本 XLA 操作表达的,这就是使得即时(JIT)编译成为可能的原因。

要 JIT 或不要 JIT

关键概念:

  • 默认情况下,JAX 按顺序逐个执行操作。

  • 使用即时(JIT)编译装饰器,可以优化操作序列并一次运行:

  • 并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。

所有 JAX 操作都是基于 XLA 表达的事实,使得 JAX 能够使用 XLA 编译器非常高效地执行代码块。

例如,考虑此函数,它对二维矩阵的行进行标准化,表达为jax.numpy操作:

import jax.numpy as jnp

def norm(X):
  X = X - X.mean(0)
  return X / X.std(0) 

可以使用jax.jit变换创建函数的即时编译版本:

from jax import jit
norm_compiled = jit(norm) 

此函数返回与原始函数相同的结果,达到标准浮点精度:

np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6) 
True 

但由于编译(其中包括操作的融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可以比非常数级别快得多(请注意使用block_until_ready()以考虑 JAX 的异步调度):

%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready() 
319 μs ± 1.98 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
272 μs ± 849 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each) 

话虽如此,jax.jit确实存在一些限制:特别是,它要求所有数组具有静态形状。这意味着一些 JAX 操作与 JIT 编译不兼容。

例如,此操作可以在逐操作模式下执行:

def get_negatives(x):
  return x[x < 0]

x = jnp.array(np.random.randn(10))
get_negatives(x) 
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32) 

但如果您尝试在 jit 模式下执行它,则会返回错误:

jit(get_negatives)(x) 
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError 

这是因为该函数生成的数组形状在编译时未知:输出的大小取决于输入数组的值,因此与 JIT 不兼容。

JIT 机制:跟踪和静态变量

关键概念:

  • JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型输入的影响。

  • 不希望被追踪的变量可以标记为静态

要有效使用 jax.jit,理解其工作原理是很有用的。让我们在一个 JIT 编译的函数中放几个 print() 语句,然后调用该函数:

@jit
def f(x, y):
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  result = jnp.dot(x + 1, y + 1)
  print(f"  result = {result}")
  return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y) 
Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)> 
Array([0.25773212, 5.3623195 , 5.403243  ], dtype=float32) 

注意,打印语句执行,但打印的不是我们传递给函数的数据,而是打印追踪器对象,这些对象代替它们。

这些追踪器对象是 jax.jit 用来提取函数指定的操作序列的基本替代物,编码数组的形状dtype,但对值是不可知的。然后可以有效地将这个记录的计算序列应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。

当我们在匹配的输入上再次调用编译函数时,无需重新编译,也不打印任何内容,因为结果在编译的 XLA 中计算,而不是在 Python 中:

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2) 
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32) 

提取的操作序列编码在 JAX 表达式中,简称为 jaxpr。您可以使用 jax.make_jaxpr 转换查看 jaxpr:

from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y) 
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  in (e,) } 

注意这一后果:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于追踪的值。例如,这将失败:

@jit
def f(x, neg):
  return -x if neg else x

f(1, True) 
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2814/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError 

如果有不希望被追踪的变量,可以将它们标记为静态以供 JIT 编译使用:

from functools import partial

@partial(jit, static_argnums=(1,))
def f(x, neg):
  return -x if neg else x

f(1, True) 
Array(-1, dtype=int32, weak_type=True) 

请注意,使用不同的静态参数调用 JIT 编译函数会导致重新编译,所以函数仍然如预期般工作:

f(1, False) 
Array(1, dtype=int32, weak_type=True) 

理解哪些值和操作将是静态的,哪些将被追踪,是有效使用 jax.jit 的关键部分。

静态与追踪操作

关键概念:

  • 就像值可以是静态的或者被追踪的一样,操作也可以是静态的或者被追踪的。

  • 静态操作在 Python 中在编译时评估;跟踪操作在 XLA 中在运行时编译并评估。

  • 使用 numpy 进行您希望静态的操作;使用 jax.numpy 进行您希望被追踪的操作。

静态和追踪值的区别使得重要的是考虑如何保持静态值的静态。考虑这个函数:

import jax.numpy as jnp
from jax import jit

@jit
def f(x):
  return x.reshape(jnp.array(x.shape).prod())

x = jnp.ones((2, 3))
f(x) 
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2814/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
    from line /tmp/ipykernel_2814/1983583872.py:6 (f) 

这会因为找到追踪器而不是整数类型的具体值的 1D 序列而失败。让我们向函数中添加一些打印语句,以了解其原因:

@jit
def f(x):
  print(f"x = {x}")
  print(f"x.shape = {x.shape}")
  print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
  # comment this out to avoid the error:
  # return x.reshape(jnp.array(x.shape).prod())

f(x) 
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)> 

注意尽管x被追踪,x.shape是一个静态值。然而,当我们在这个静态值上使用jnp.arrayjnp.prod时,它变成了一个被追踪的值,在这种情况下,它不能用于像reshape()这样需要静态输入的函数(回想:数组形状必须是静态的)。

一个有用的模式是使用numpy进行应该是静态的操作(即在编译时完成),并使用jax.numpy进行应该被追踪的操作(即在运行时编译和执行)。对于这个函数,可能会像这样:

from jax import jit
import jax.numpy as jnp
import numpy as np

@jit
def f(x):
  return x.reshape((np.prod(x.shape),))

f(x) 
Array([1., 1., 1., 1., 1., 1.], dtype=float32) 

因此,在 JAX 程序中的一个标准约定是import numpy as npimport jax.numpy as jnp,这样两个接口都可以用来更精细地控制操作是以静态方式(使用numpy,一次在编译时)还是以追踪方式(使用jax.numpy,在运行时优化)执行。

对 JAX 程序进行性能分析

原文:jax.readthedocs.io/en/latest/profiling.html

使用 Perfetto 查看程序跟踪

我们可以使用 JAX 分析器生成可以使用Perfetto 可视化工具查看的 JAX 程序的跟踪。目前,此方法会阻塞程序,直到点击链接并加载 Perfetto UI 以打开跟踪为止。如果您希望获取性能分析信息而无需任何交互,请查看下面的 Tensorboard 分析器。

with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
  # Run the operations to be profiled
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready() 

计算完成后,程序会提示您打开链接到ui.perfetto.dev。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。

Perfetto 跟踪查看器

加载链接后,程序执行将继续。链接在打开一次后将不再有效,但将重定向到一个保持有效的新 URL。然后,您可以在 Perfetto UI 中单击“共享”按钮,创建可与他人共享的跟踪的永久链接。

远程分析

在对远程运行的代码进行性能分析(例如在托管的虚拟机上)时,您需要在端口 9001 上建立 SSH 隧道以使链接工作。您可以使用以下命令执行此操作:

$  ssh  -L  9001:127.0.0.1:9001  <user>@<host> 

或者如果您正在使用 Google Cloud:

$  gcloud  compute  ssh  <machine-name>  --  -L  9001:127.0.0.1:9001 

手动捕获

而不是使用jax.profiler.trace以编程方式捕获跟踪,您可以通过在感兴趣的脚本中调用jax.profiler.start_server(<port>)来启动分析服务器。如果您只需在脚本的某部分保持分析服务器活动,则可以通过调用jax.profiler.stop_server()来关闭它。

脚本运行后并且分析服务器已启动后,我们可以通过运行以下命令手动捕获和跟踪:

$  python  -m  jax.collect_profile  <port>  <duration_in_ms> 

默认情况下,生成的跟踪信息会被转储到临时目录中,但可以通过传递--log_dir=<自定义目录>来覆盖此设置。另外,默认情况下,程序将提示您打开链接到ui.perfetto.dev。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。通过传递--no_perfetto_link命令可以禁用此功能。或者,您也可以将 Tensorboard 指向log_dir以分析跟踪(参见下面的“Tensorboard 分析”部分)。

TensorBoard 性能分析

TensorBoard 的分析器可用于分析 JAX 程序。Tensorboard 是获取和可视化程序性能跟踪和分析(包括 GPU 和 TPU 上的活动)的好方法。最终结果看起来类似于这样:

TensorBoard 分析器示例

安装

TensorBoard 分析器仅与捆绑有 TensorFlow 的 TensorBoard 版本一起提供。

pip  install  tensorflow  tensorboard-plugin-profile 

如果您已安装了 TensorFlow,则只需安装tensorboard-plugin-profile pip 包。请注意仅安装一个版本的 TensorFlow 或 TensorBoard,否则可能会遇到下面描述的“重复插件”错误。有关安装 TensorBoard 的更多信息,请参见www.tensorflow.org/guide/profiler

程序化捕获

您可以通过jax.profiler.start_trace()jax.profiler.stop_trace()方法来配置您的代码以捕获性能分析器的追踪。调用start_trace()时需要指定写入追踪文件的目录。这个目录应该与启动 TensorBoard 时使用的--logdir目录相同。然后,您可以使用 TensorBoard 来查看这些追踪信息。

例如,要获取性能分析器的追踪:

import jax

jax.profiler.start_trace("/tmp/tensorboard")

# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()

jax.profiler.stop_trace() 

注意block_until_ready()调用。我们使用这个函数来确保设备上的执行被追踪到。有关为什么需要这样做的详细信息,请参见异步调度部分。

您还可以使用jax.profiler.trace()上下文管理器作为start_tracestop_trace的替代方法:

import jax

with jax.profiler.trace("/tmp/tensorboard"):
  key = jax.random.key(0)
  x = jax.random.normal(key, (5000, 5000))
  y = x @ x
  y.block_until_ready() 

要查看追踪信息,请首先启动 TensorBoard(如果尚未启动):

$  tensorboard  --logdir=/tmp/tensorboard
[...]
Serving  TensorBoard  on  localhost;  to  expose  to  the  network,  use  a  proxy  or  pass  --bind_all
TensorBoard  2.5.0  at  http://localhost:6006/  (Press  CTRL+C  to  quit) 

在这个示例中,您应该能够在localhost:6006/加载 TensorBoard。您可以使用--port标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。

然后,要么在右上角的下拉菜单中选择“Profile”,要么直接访问localhost:6006/#profile。可用的追踪信息会显示在左侧的“Runs”下拉菜单中。选择您感兴趣的运行,并在“Tools”下选择trace_viewer。现在您应该能看到执行时间轴。您可以使用 WASD 键来导航追踪信息,点击或拖动以选择事件并查看底部的更多详细信息。有关使用追踪查看器的更多详细信息,请参阅这些 TensorFlow 文档

您还可以使用memory_viewerop_profilegraph_viewer工具。

通过 TensorBoard 手动捕获

以下是从运行中的程序中手动触发 N 秒追踪的捕获说明。

  1. 启动 TensorBoard 服务器:

    tensorboard  --logdir  /tmp/tensorboard/ 
    

    localhost:6006/处应该能够加载 TensorBoard。您可以使用--port标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。

  2. 在您希望进行分析的 Python 程序或进程中,将以下内容添加到开头的某个位置:

    import jax.profiler
    jax.profiler.start_server(9999) 
    

    这将启动 TensorBoard 连接到的性能分析器服务器。在继续下一步之前,必须先运行性能分析器服务器。完成后,可以调用jax.profiler.stop_server()来关闭它。

    如果你想要分析一个长时间运行的程序片段(例如长时间的训练循环),你可以将此代码放在程序开头并像往常一样启动程序。如果你想要分析一个短程序(例如微基准测试),一种选择是在 IPython shell 中启动分析器服务器,并在下一步开始捕获后用 %run 运行短程序。另一种选择是在程序开头启动分析器服务器,并使用 time.sleep() 给你足够的时间启动捕获。

  3. 打开localhost:6006/#profile,并点击左上角的“CAPTURE PROFILE”按钮。将“localhost:9999”作为分析服务的 URL(这是你在上一步中启动的分析器服务器的地址)。输入你想要进行分析的毫秒数,然后点击“CAPTURE”。

  4. 如果你想要分析的代码尚未运行(例如在 Python shell 中启动了分析器服务器),请在进行捕获时运行它。

  5. 捕获完成后,TensorBoard 应会自动刷新。(并非所有 TensorBoard 分析功能都与 JAX 连接,所以初始时看起来可能没有捕获到任何内容。)在左侧的“工具”下,选择 trace_viewer

    现在你应该可以看到执行的时间轴。你可以使用 WASD 键来导航跟踪,点击或拖动选择事件以在底部查看更多详细信息。参见这些 TensorFlow 文档获取有关使用跟踪查看器的更多详细信息。

    你也可以使用 memory_viewerop_profilegraph_viewer 工具。

添加自定义跟踪事件

默认情况下,跟踪查看器中的事件大多是低级内部 JAX 函数。你可以使用 jax.profiler.TraceAnnotationjax.profiler.annotate_function() 在你的代码中添加自定义事件和函数。

故障排除

GPU 分析

运行在 GPU 上的程序应该在跟踪查看器顶部附近生成 GPU 流的跟踪。如果只看到主机跟踪,请检查程序日志和/或输出,查看以下错误消息。

如果出现类似 Could not load dynamic library 'libcupti.so.10.1' 的错误

完整错误:

W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found. 

libcupti.so的路径添加到环境变量LD_LIBRARY_PATH中。(尝试使用locate libcupti.so来找到路径。)例如:

export  LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH 

即使在做了以上步骤后仍然收到 Could not load dynamic library 错误消息,请检查 GPU 跟踪是否仍然显示在跟踪查看器中。有时即使一切正常,它也会出现此消息,因为它在多个位置查找 libcupti 库。

如果出现类似 failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES 的错误

完整错误:

E  external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445]  function  cupti_interface_->EnableCallback(  0  ,  subscriber_,  CUPTI_CB_DOMAIN_DRIVER_API,  cbid)failed  with  error  CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12  14:31:54.097791:  E  external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487]  function  cupti_interface_->ActivityDisable(activity)failed  with  error  CUPTI_ERROR_NOT_INITIALIZED 

运行以下命令(注意这将需要重新启动):

echo  'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"'  |  sudo  tee  -a  /etc/modprobe.d/nvidia-kernel-common.conf
sudo  update-initramfs  -u
sudo  reboot  now 

查看更多关于此错误的信息,请参阅NVIDIA 的文档

在远程机器上进行性能分析

如果要分析的 JAX 程序正在远程机器上运行,一种选择是在远程机器上执行上述所有说明(特别是在远程机器上启动 TensorBoard 服务器),然后使用 SSH 本地端口转发从本地访问 TensorBoard Web UI。使用以下 SSH 命令将默认的 TensorBoard 端口 6006 从本地转发到远程机器:

ssh  -L  6006:localhost:6006  <remote  server  address> 

或者如果您正在使用 Google Cloud:

$  gcloud  compute  ssh  <machine-name>  --  -L  6006:localhost:6006 
```  #### 多个 TensorBoard 安装

**如果启动 TensorBoard 失败,并出现类似于`ValueError: Duplicate plugins for name projector`的错误**

这通常是因为安装了两个版本的 TensorBoard 和/或 TensorFlow(例如,`tensorflow`、`tf-nightly`、`tensorboard`和`tb-nightly` pip 包都包含 TensorBoard)。卸载一个 pip 包可能会导致`tensorboard`可执行文件被移除,难以替换,因此可能需要卸载所有内容并重新安装单个版本:

```py
pip  uninstall  tensorflow  tf-nightly  tensorboard  tb-nightly
pip  install  tensorflow 

Nsight

NVIDIA 的Nsight工具可用于跟踪和分析 GPU 上的 JAX 代码。有关详情,请参阅Nsight文档

设备内存分析

原文:jax.readthedocs.io/en/latest/device_memory_profiling.html

注意

2023 年 5 月更新:我们建议使用 Tensorboard 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer 标签以获取更详细和易于理解的设备内存使用情况。

JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可用于:

  • 查明在特定时间点哪些数组和可执行文件位于 GPU 内存中,或者

  • 追踪内存泄漏。

安装

JAX 设备内存分析器生成的输出可使用 pprof (google/pprof) 解释。首先按照其 安装说明 安装 pprof。撰写时,安装 pprof 需要先安装版本为 1.16+ 的 GoGraphviz,然后运行

go  install  github.com/google/pprof@latest 

安装 pprof 作为 $GOPATH/bin/pprof,其中 GOPATH 默认为 ~/go

注意

来自 google/pprofpprof 版本与作为 gperftools 软件包一部分分发的同名旧工具不同。gperftools 版本的 pprof 不适用于 JAX。

理解 JAX 程序如何使用 GPU 或 TPU 内存

设备内存分析器的常见用途是找出为何 JAX 程序使用大量 GPU 或 TPU 内存,例如调试内存不足问题。

要将设备内存分析保存到磁盘,使用 jax.profiler.save_device_memory_profile()。例如,考虑以下 Python 程序:

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof") 

如果我们首先运行上述程序,然后执行

pprof  --web  memory.prof 

pprof 打开一个包含设备内存分析调用图格式的 Web 浏览器:

设备内存分析示例

调用图是在每个活动缓冲区分配的 Python 栈的可视化。例如,在这个特定情况下,可视化显示 func2 及其被调用者负责分配了 76.30MB,其中 38.15MB 是在从 func1func2 的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档

使用 jax.jit() 编译的函数对设备内存分析器不透明。也就是说,任何在 jit 编译函数内部分配的内存都将归因于整个函数。

在本例中,调用 block_until_ready() 是为了确保在收集设备内存分析之前 func2 完成。有关更多详细信息,请参阅异步调度。

调试内存泄漏

我们还可以使用 JAX 设备内存分析器,通过使用 pprof 来可视化在不同时间点获取的两个设备内存配置文件中的内存使用情况变化,以追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。

import jax
import jax.numpy as jnp
import jax.profiler

def afunction():
  return jax.random.normal(jax.random.key(77), (1000000,))

z = afunction()

def anotherfunc():
  arrays = []
  for i in range(1, 10):
    x = jax.random.normal(jax.random.key(42), (i, 10000))
    arrays.append(x)
    x.block_until_ready()
    jax.profiler.save_device_memory_profile(f"memory{i}.prof")

anotherfunc() 

如果我们仅在执行结束时可视化设备内存配置文件(memory9.prof),则可能不明显,即 anotherfunc 中的每次循环迭代都会累积更多的设备内存分配:

pprof  --web  memory9.prof 

执行结束时的设备内存配置文件

afunction 内部的大型但固定分配主导配置文件,但不会随时间增长。

通过使用 pprof--diff_base 功能 来可视化循环迭代中内存使用情况的变化,我们可以找出程序内存使用量随时间增加的原因:

pprof  --web  --diff_base  memory1.prof  memory9.prof 

执行结束时的设备内存配置文件

可视化显示,内存增长可以归因于 anotherfunc 中对 normal 的调用。

在 JAX 中进行运行时值调试

原文:jax.readthedocs.io/en/latest/debugging/index.html

是否遇到梯度爆炸?NaN 使你牙齿咬紧?只想查看计算中间值?请查看以下 JAX 调试工具!本页提供了 TL;DR 摘要,并且您可以点击底部的“阅读更多”链接了解更多信息。

目录:

  • 使用 jax.debug 进行交互式检查

  • 使用 jax.experimental.checkify 进行功能错误检查

  • 使用 JAX 的调试标志抛出 Python 错误

使用 jax.debug 进行交互式检查

TL;DR 使用 jax.debug.print()jax.jitjax.pmappjit 装饰的函数中将值打印到 stdout,并使用 jax.debug.breakpoint() 暂停执行编译函数以检查调用堆栈中的值:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("

标签:中文,jnp,JAX,jax,print,文档,debug,checkify
From: https://www.cnblogs.com/apachecn/p/18260395

相关文章

  • JAX-中文文档-六-
    JAX中文文档(六)原文:jax.readthedocs.io/en/latest/高级教程原文:jax.readthedocs.io/en/latest/advanced_guide.html本节包含更高级主题的示例和教程,如多核计算、自定义操作及更深入的应用示例使用tensorflow/datasets进行简单神经网络训练使用PyTorch数据加载......
  • JAX-中文文档-九-
    JAX中文文档(九)原文:jax.readthedocs.io/en/latest/使用jax.checkpoint控制自动微分的保存数值(又名jax.remat)原文:jax.readthedocs.io/en/latest/notebooks/autodiff_remat.htmlimportjaximportjax.numpyasjnp简而言之使用jax.checkpoint装饰器(别名为jax.remat),结合......
  • JAX-中文文档-二-
    JAX中文文档(二)原文:jax.readthedocs.io/en/latest/JAX教程原文:jax.readthedocs.io/en/latest/tutorials.html快速入门关键概念即时编译自动向量化自动微分调试入门伪随机数使用pytrees工作分片计算入门有状态计算关键概念原文:jax.re......
  • JAX-中文文档-八-
    JAX中文文档(八)原文:jax.readthedocs.io/en/latest/自动微分手册原文:jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.htmlalexbw@,mattjj@JAX拥有非常通用的自动微分系统。在这本手册中,我们将介绍许多巧妙的自动微分思想,您可以根据自己的工作进行选择。i......
  • prometheus 中文说明告警指标
    https://blog.51cto.com/qiangsh/1977449主机和硬件监控可用内存指标主机中可用内存容量不足10%-alert:HostOutOfMemoryexpr:node_memory_MemAvailable_bytes/node_memory_MemTotal_bytes*100<10for:5mlabels:severity:warningannotations:......
  • 软件开发项目全套文档资料参考(规格说明书、详细设计、测试计划、验收报告)
     前言:在软件开发过程中,文档资料是非常关键的一部分,它们帮助团队成员理解项目需求、设计、实施、测试、验收等各个环节,确保项目的顺利进行。以下是各个阶段的文档资料概述:软件项目管理部分文档清单: 工作安排任务书,可行性分析报告,立项申请审批表,产品需求规格说明书,需求调研......
  • 中文检测插件
    大家都知道,做出海应用,尤其是在一些对中国不友好的国家做业务。全面去中文化至关重要。对于开发而言,在代码层如果只靠人为控制这个变量,尤其艰难。所以给大家安利一个我们自研的中文检测插件,他能在您开发过程中时刻检测您的输入是否含有中文。大家先看下效果。如果您有需要,烦......
  • 搜索硬编码中文
    老项目中常常有直接在代码里或者xml布局中硬编码中文的,在后期业务扩展做国际化翻译时,这就是一个巨大的坑,因为我们需要知道哪里硬编码了,然后提取到strings.xml中刚好我最近在弄这个,如何找到代码中所有的硬编码就是核心问题,下面记录下我的步骤 1.首先写好正则,直接百度也行^((?!......
  • MestReNova14.0中文版安装教程
    MestReNova14是一款专业级的核磁共振(NMR)与质谱(MS)数据分析软件,专注于化合物结构解析和验证。该软件以卓越的谱图处理能力和智能化算法为核心,提供自定义参数调整、自动峰识别、精准积分、耦合常数计算等功能。支持多种仪器数据格式导入,可高效处理一维至四维NMR谱图以及各类质谱数据......
  • PDF英语文档怎么翻译成中文?
    外语文献是我们学习和工作中经常遇到的难题,其中包含许多重要工作信息,精确地理解和翻译非常重要。但并不是所有格式的文件都能直接编辑和翻译。例如PDF格式的文件就无法直接进行编辑,当我们需要翻译PDF格式的外语文档时,应该使用什么工具呢?本篇文章就为你提供几个快速翻译PDF文件的方......