JAX 中文文档(三)
有状态计算
JAX 的转换(如jit()
、vmap()
、grad()
)要求它们包装的函数是纯粹的:即,函数的输出仅依赖于输入,并且没有副作用,比如更新全局状态。您可以在JAX sharp bits: Pure functions中找到关于这一点的讨论。
在机器学习的背景下,这种约束可能会带来一些挑战,因为状态可以以多种形式存在。例如:
-
模型参数,
-
优化器状态,以及
-
像BatchNorm这样的有状态层。
本节提供了如何在 JAX 程序中正确处理状态的一些建议。
一个简单的例子:计数器
让我们首先看一个简单的有状态程序:一个计数器。
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
1
2
3
计数器的n
属性在连续调用count
时维护计数器的状态。调用count
的副作用是修改它。
假设我们想要快速计数,所以我们即时编译count
方法。(在这个例子中,这实际上不会以任何方式加快速度,由于很多原因,但把它看作是模型参数更新的玩具模型,jit()
确实产生了巨大的影响)。
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
1
1
1
哦不!我们的计数器不能工作了。这是因为
self.n += 1
在count
中涉及副作用:它直接修改了输入的计数器,因此此函数不受jit
支持。这样的副作用仅在首次跟踪函数时执行一次,后续调用将不会重复该副作用。那么,我们该如何修复它呢?
解决方案:显式状态
问题的一部分在于我们的计数器返回值不依赖于参数,这意味着编译输出中包含了一个常数。但它不应该是一个常数 - 它应该依赖于状态。那么,为什么我们不将状态作为一个参数呢?
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
1
2
3
在这个Counter
的新版本中,我们将n
移动到count
的参数中,并添加了另一个返回值,表示新的、更新的状态。现在,为了使用这个计数器,我们需要显式地跟踪状态。但作为回报,我们现在可以安全地使用jax.jit
这个计数器:
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
1
2
3
一个一般的策略
我们可以将同样的过程应用到任何有状态方法中,将其转换为无状态方法。我们拿一个形式如下的类
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
并将其转换为以下形式的类
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
这是一个常见的函数式编程模式,本质上就是处理所有 JAX 程序中状态的方式。
注意,一旦我们按照这种方式重写它,类的必要性就不那么明显了。我们可以只保留stateless_method
,因为类不再执行任何工作。这是因为,像我们刚刚应用的策略一样,面向对象编程(OOP)是帮助程序员理解程序状态的一种方式。
在我们的情况下,CounterV2
类只是一个名称空间,将所有使用 CounterState
的函数集中在一个位置。读者可以思考:将其保留为类是否有意义?
顺便说一句,你已经在 JAX 伪随机性 API 中看到了这种策略的示例,即 jax.random
,在 :ref:pseudorandom-numbers
部分展示。与 Numpy 不同,后者使用隐式更新的有状态类管理随机状态,而 JAX 要求程序员直接使用随机生成器状态——PRNG 密钥。
简单的工作示例:线性回归
现在让我们将这种策略应用到一个简单的机器学习模型上:通过梯度下降进行线性回归。
这里,我们只处理一种状态:模型参数。但通常情况下,你会看到许多种状态在 JAX 函数中交替出现,比如优化器状态、批归一化的层统计数据等。
需要仔细查看的函数是 update
。
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
#
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
#
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
注意,我们手动地将参数输入和输出到更新函数中。
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
/tmp/ipykernel_2992/721844192.py:37: DeprecationWarning: jax.tree_map is deprecated: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).
new_params = jax.tree_map(
进一步探讨
上述描述的策略是任何使用 jit
、vmap
、grad
等转换的 JAX 程序必须处理状态的方式。
如果只涉及两个参数,手动处理参数似乎还可以接受,但如果是有数十层的神经网络呢?你可能已经开始担心两件事情:
-
我们是否应该手动初始化它们,基本上是在前向传播定义中已经编写过的内容?
-
我们是否应该手动处理所有这些事情?
处理这些细节可能有些棘手,但有一些库的示例可以为您解决这些问题。请参阅JAX 神经网络库获取一些示例。
进一步资源
用户指南
用户指南是对 JAX 内特定主题的深入探讨,随着您的 JAX 项目发展成为更大或部署代码库,这些主题变得更为相关。
调试和性能
-
如何在 JAX 中思考
-
对 JAX 程序进行性能分析
-
设备内存分析
-
JAX 中的运行时值调试
-
GPU 性能技巧
-
持久化编译缓存
开发
-
理解 Jaxprs
-
JAX 中的外部回调
-
类型提升语义
-
Pytrees
运行时间
-
提前降低和编译
-
导出和序列化
-
JAX 错误
-
转移保护
自定义操作
- Pallas:一种 JAX 内核语言
如何在 JAX 中思考
原文:
jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
JAX 提供了一个简单而强大的 API 用于编写加速数值代码,但在 JAX 中有效工作有时需要额外考虑。本文档旨在帮助建立对 JAX 如何运行的基础理解,以便您更有效地使用它。
JAX vs. NumPy
关键概念:
-
JAX 提供了一个方便的类似于 NumPy 的接口。
-
通过鸭子类型,JAX 数组通常可以直接替换 NumPy 数组。
-
不像 NumPy 数组,JAX 数组总是不可变的。
NumPy 提供了一个众所周知且功能强大的 API 用于处理数值数据。为方便起见,JAX 提供了 jax.numpy
,它紧密反映了 NumPy API,并为进入 JAX 提供了便捷的入口。几乎可以用 jax.numpy
完成 numpy
可以完成的任何事情:
import matplotlib.pyplot as plt
import numpy as np
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np);
import jax.numpy as jnp
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
代码块除了用 jnp
替换 np
外,其余完全相同。正如我们所见,JAX 数组通常可以直接替换 NumPy 数组,用于诸如绘图等任务。
这些数组本身是作为不同的 Python 类型实现的:
type(x_np)
numpy.ndarray
type(x_jnp)
jaxlib.xla_extension.ArrayImpl
Python 的 鸭子类型 允许在许多地方可互换使用 JAX 数组和 NumPy 数组。
然而,JAX 和 NumPy 数组之间有一个重要的区别:JAX 数组是不可变的,一旦创建,其内容无法更改。
这里有一个在 NumPy 中突变数组的例子:
# NumPy: mutable arrays
x = np.arange(10)
x[0] = 10
print(x)
[10 1 2 3 4 5 6 7 8 9]
在 JAX 中,等效操作会导致错误,因为 JAX 数组是不可变的:
%xmode minimal
Exception reporting mode: Minimal
# JAX: immutable arrays
x = jnp.arange(10)
x[0] = 10
TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html
对于更新单个元素,JAX 提供了一个 索引更新语法,返回一个更新后的副本:
y = x.at[0].set(10)
print(x)
print(y)
[0 1 2 3 4 5 6 7 8 9]
[10 1 2 3 4 5 6 7 8 9]
NumPy、lax 和 XLA:JAX API 层次结构
关键概念:
-
jax.numpy
是一个提供熟悉接口的高级包装器。 -
jax.lax
是一个更严格且通常更强大的低级 API。 -
所有 JAX 操作都是基于 XLA – 加速线性代数编译器中的操作实现的。
如果您查看 jax.numpy
的源代码,您会看到所有操作最终都是以 jax.lax
中定义的函数形式表达的。您可以将 jax.lax
视为更严格但通常更强大的 API,用于处理多维数组。
例如,虽然jax.numpy
将隐式促进参数以允许不同数据类型之间的操作,但jax.lax
不会:
import jax.numpy as jnp
jnp.add(1, 1.0) # jax.numpy API implicitly promotes mixed types.
Array(2., dtype=float32, weak_type=True)
from jax import lax
lax.add(1, 1.0) # jax.lax API requires explicit type promotion.
MLIRError: Verification failed:
error: "jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))): op requires the same element type for all operands and results
The above exception was the direct cause of the following exception:
ValueError: Cannot lower jaxpr with verifier errors:
op requires the same element type for all operands and results
at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_2814/3435837498.py":2:0) at callsite("run_code"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3577:0) at callsite("run_ast_nodes"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3517:0) at callsite("run_cell_async"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3334:0) at callsite("_pseudo_sync_runner"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/async_helpers.py":129:0) at callsite("_run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3130:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/IPython/core/interactiveshell.py":3075:0) at callsite("run_cell"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/zmqshell.py":549:0) at callsite("do_execute"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/ipkernel.py":449:0) at "execute_request"("/home/docs/checkouts/readthedocs.org/user_builds/jax/envs/latest/lib/python3.10/site-packages/ipykernel/kernelbase.py":778:0))))))))))))
Define JAX_DUMP_IR_TO to dump the module.
如果直接使用jax.lax
,在这种情况下你将需要显式地进行类型提升:
lax.add(jnp.float32(1), 1.0)
Array(2., dtype=float32)
除了这种严格性外,jax.lax
还提供了一些比 NumPy 支持的更一般操作更高效的 API。
例如,考虑一个 1D 卷积,在 NumPy 中可以这样表达:
x = jnp.array([1, 2, 1])
y = jnp.ones(10)
jnp.convolve(x, y)
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
在幕后,这个 NumPy 操作被转换为由lax.conv_general_dilated
实现的更通用的卷积:
from jax import lax
result = lax.conv_general_dilated(
x.reshape(1, 1, 3).astype(float), # note: explicit promotion
y.reshape(1, 1, 10),
window_strides=(1,),
padding=[(len(y) - 1, len(y) - 1)]) # equivalent of padding='full' in NumPy
result[0, 0]
Array([1., 3., 4., 4., 4., 4., 4., 4., 4., 4., 3., 1.], dtype=float32)
这是一种批处理卷积操作,专为深度神经网络中经常使用的卷积类型设计,需要更多的样板代码,但比 NumPy 提供的卷积更灵活和可扩展(有关 JAX 卷积的更多细节,请参见Convolutions in JAX)。
从本质上讲,所有jax.lax
操作都是 XLA 中操作的 Python 包装器;例如,在这里,卷积实现由XLA:ConvWithGeneralPadding提供。每个 JAX 操作最终都是基于这些基本 XLA 操作表达的,这就是使得即时(JIT)编译成为可能的原因。
要 JIT 或不要 JIT
关键概念:
-
默认情况下,JAX 按顺序逐个执行操作。
-
使用即时(JIT)编译装饰器,可以优化操作序列并一次运行:
-
并非所有 JAX 代码都可以进行 JIT 编译,因为它要求数组形状在编译时是静态且已知的。
所有 JAX 操作都是基于 XLA 表达的事实,使得 JAX 能够使用 XLA 编译器非常高效地执行代码块。
例如,考虑此函数,它对二维矩阵的行进行标准化,表达为jax.numpy
操作:
import jax.numpy as jnp
def norm(X):
X = X - X.mean(0)
return X / X.std(0)
可以使用jax.jit
变换创建函数的即时编译版本:
from jax import jit
norm_compiled = jit(norm)
此函数返回与原始函数相同的结果,达到标准浮点精度:
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
np.allclose(norm(X), norm_compiled(X), atol=1E-6)
True
但由于编译(其中包括操作的融合、避免分配临时数组以及其他许多技巧),在 JIT 编译的情况下,执行时间可以比非常数级别快得多(请注意使用block_until_ready()
以考虑 JAX 的异步调度):
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()
319 μs ± 1.98 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
272 μs ± 849 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
话虽如此,jax.jit
确实存在一些限制:特别是,它要求所有数组具有静态形状。这意味着一些 JAX 操作与 JIT 编译不兼容。
例如,此操作可以在逐操作模式下执行:
def get_negatives(x):
return x[x < 0]
x = jnp.array(np.random.randn(10))
get_negatives(x)
Array([-0.10570311, -0.59403396, -0.8680282 , -0.23489487], dtype=float32)
但如果您尝试在 jit 模式下执行它,则会返回错误:
jit(get_negatives)(x)
NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[10])
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError
这是因为该函数生成的数组形状在编译时未知:输出的大小取决于输入数组的值,因此与 JIT 不兼容。
JIT 机制:跟踪和静态变量
关键概念:
-
JIT 和其他 JAX 转换通过跟踪函数来确定其对特定形状和类型输入的影响。
-
不希望被追踪的变量可以标记为静态
要有效使用 jax.jit
,理解其工作原理是很有用的。让我们在一个 JIT 编译的函数中放几个 print()
语句,然后调用该函数:
@jit
def f(x, y):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
result = jnp.dot(x + 1, y + 1)
print(f" result = {result}")
return result
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
Running f():
x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Array([0.25773212, 5.3623195 , 5.403243 ], dtype=float32)
注意,打印语句执行,但打印的不是我们传递给函数的数据,而是打印追踪器对象,这些对象代替它们。
这些追踪器对象是 jax.jit
用来提取函数指定的操作序列的基本替代物,编码数组的形状和dtype,但对值是不可知的。然后可以有效地将这个记录的计算序列应用于具有相同形状和 dtype 的新输入,而无需重新执行 Python 代码。
当我们在匹配的输入上再次调用编译函数时,无需重新编译,也不打印任何内容,因为结果在编译的 XLA 中计算,而不是在 Python 中:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)
Array([1.4344584, 4.3004413, 7.9897013], dtype=float32)
提取的操作序列编码在 JAX 表达式中,简称为 jaxpr。您可以使用 jax.make_jaxpr
转换查看 jaxpr:
from jax import make_jaxpr
def f(x, y):
return jnp.dot(x + 1, y + 1)
make_jaxpr(f)(x, y)
{ lambda ; a:f32[3,4] b:f32[4]. let
c:f32[3,4] = add a 1.0
d:f32[4] = add b 1.0
e:f32[3] = dot_general[
dimension_numbers=(([1], [0]), ([], []))
preferred_element_type=float32
] c d
in (e,) }
注意这一后果:因为 JIT 编译是在没有数组内容信息的情况下完成的,所以函数中的控制流语句不能依赖于追踪的值。例如,这将失败:
@jit
def f(x, neg):
return -x if neg else x
f(1, True)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_2814/2422663986.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument neg.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
如果有不希望被追踪的变量,可以将它们标记为静态以供 JIT 编译使用:
from functools import partial
@partial(jit, static_argnums=(1,))
def f(x, neg):
return -x if neg else x
f(1, True)
Array(-1, dtype=int32, weak_type=True)
请注意,使用不同的静态参数调用 JIT 编译函数会导致重新编译,所以函数仍然如预期般工作:
f(1, False)
Array(1, dtype=int32, weak_type=True)
理解哪些值和操作将是静态的,哪些将被追踪,是有效使用 jax.jit
的关键部分。
静态与追踪操作
关键概念:
-
就像值可以是静态的或者被追踪的一样,操作也可以是静态的或者被追踪的。
-
静态操作在 Python 中在编译时评估;跟踪操作在 XLA 中在运行时编译并评估。
-
使用
numpy
进行您希望静态的操作;使用jax.numpy
进行您希望被追踪的操作。
静态和追踪值的区别使得重要的是考虑如何保持静态值的静态。考虑这个函数:
import jax.numpy as jnp
from jax import jit
@jit
def f(x):
return x.reshape(jnp.array(x.shape).prod())
x = jnp.ones((2, 3))
f(x)
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
The error occurred while tracing the function f at /tmp/ipykernel_2814/1983583872.py:4 for jit. This value became a tracer due to JAX operations on these lines:
operation a:i32[2] = convert_element_type[new_dtype=int32 weak_type=False] b
from line /tmp/ipykernel_2814/1983583872.py:6 (f)
这会因为找到追踪器而不是整数类型的具体值的 1D 序列而失败。让我们向函数中添加一些打印语句,以了解其原因:
@jit
def f(x):
print(f"x = {x}")
print(f"x.shape = {x.shape}")
print(f"jnp.array(x.shape).prod() = {jnp.array(x.shape).prod()}")
# comment this out to avoid the error:
# return x.reshape(jnp.array(x.shape).prod())
f(x)
x = Traced<ShapedArray(float32[2,3])>with<DynamicJaxprTrace(level=1/0)>
x.shape = (2, 3)
jnp.array(x.shape).prod() = Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
注意尽管x
被追踪,x.shape
是一个静态值。然而,当我们在这个静态值上使用jnp.array
和jnp.prod
时,它变成了一个被追踪的值,在这种情况下,它不能用于像reshape()
这样需要静态输入的函数(回想:数组形状必须是静态的)。
一个有用的模式是使用numpy
进行应该是静态的操作(即在编译时完成),并使用jax.numpy
进行应该被追踪的操作(即在运行时编译和执行)。对于这个函数,可能会像这样:
from jax import jit
import jax.numpy as jnp
import numpy as np
@jit
def f(x):
return x.reshape((np.prod(x.shape),))
f(x)
Array([1., 1., 1., 1., 1., 1.], dtype=float32)
因此,在 JAX 程序中的一个标准约定是import numpy as np
和import jax.numpy as jnp
,这样两个接口都可以用来更精细地控制操作是以静态方式(使用numpy
,一次在编译时)还是以追踪方式(使用jax.numpy
,在运行时优化)执行。
对 JAX 程序进行性能分析
使用 Perfetto 查看程序跟踪
我们可以使用 JAX 分析器生成可以使用Perfetto 可视化工具查看的 JAX 程序的跟踪。目前,此方法会阻塞程序,直到点击链接并加载 Perfetto UI 以打开跟踪为止。如果您希望获取性能分析信息而无需任何交互,请查看下面的 Tensorboard 分析器。
with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
计算完成后,程序会提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。
加载链接后,程序执行将继续。链接在打开一次后将不再有效,但将重定向到一个保持有效的新 URL。然后,您可以在 Perfetto UI 中单击“共享”按钮,创建可与他人共享的跟踪的永久链接。
远程分析
在对远程运行的代码进行性能分析(例如在托管的虚拟机上)时,您需要在端口 9001 上建立 SSH 隧道以使链接工作。您可以使用以下命令执行此操作:
$ ssh -L 9001:127.0.0.1:9001 <user>@<host>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 9001:127.0.0.1:9001
手动捕获
而不是使用jax.profiler.trace
以编程方式捕获跟踪,您可以通过在感兴趣的脚本中调用jax.profiler.start_server(<port>)
来启动分析服务器。如果您只需在脚本的某部分保持分析服务器活动,则可以通过调用jax.profiler.stop_server()
来关闭它。
脚本运行后并且分析服务器已启动后,我们可以通过运行以下命令手动捕获和跟踪:
$ python -m jax.collect_profile <port> <duration_in_ms>
默认情况下,生成的跟踪信息会被转储到临时目录中,但可以通过传递--log_dir=<自定义目录>
来覆盖此设置。另外,默认情况下,程序将提示您打开链接到ui.perfetto.dev
。打开链接后,Perfetto UI 将加载跟踪文件并打开可视化工具。通过传递--no_perfetto_link
命令可以禁用此功能。或者,您也可以将 Tensorboard 指向log_dir
以分析跟踪(参见下面的“Tensorboard 分析”部分)。
TensorBoard 性能分析
TensorBoard 的分析器可用于分析 JAX 程序。Tensorboard 是获取和可视化程序性能跟踪和分析(包括 GPU 和 TPU 上的活动)的好方法。最终结果看起来类似于这样:
安装
TensorBoard 分析器仅与捆绑有 TensorFlow 的 TensorBoard 版本一起提供。
pip install tensorflow tensorboard-plugin-profile
如果您已安装了 TensorFlow,则只需安装tensorboard-plugin-profile
pip 包。请注意仅安装一个版本的 TensorFlow 或 TensorBoard,否则可能会遇到下面描述的“重复插件”错误。有关安装 TensorBoard 的更多信息,请参见www.tensorflow.org/guide/profiler
。
程序化捕获
您可以通过jax.profiler.start_trace()
和jax.profiler.stop_trace()
方法来配置您的代码以捕获性能分析器的追踪。调用start_trace()
时需要指定写入追踪文件的目录。这个目录应该与启动 TensorBoard 时使用的--logdir
目录相同。然后,您可以使用 TensorBoard 来查看这些追踪信息。
例如,要获取性能分析器的追踪:
import jax
jax.profiler.start_trace("/tmp/tensorboard")
# Run the operations to be profiled
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
jax.profiler.stop_trace()
注意block_until_ready()
调用。我们使用这个函数来确保设备上的执行被追踪到。有关为什么需要这样做的详细信息,请参见异步调度部分。
您还可以使用jax.profiler.trace()
上下文管理器作为start_trace
和stop_trace
的替代方法:
import jax
with jax.profiler.trace("/tmp/tensorboard"):
key = jax.random.key(0)
x = jax.random.normal(key, (5000, 5000))
y = x @ x
y.block_until_ready()
要查看追踪信息,请首先启动 TensorBoard(如果尚未启动):
$ tensorboard --logdir=/tmp/tensorboard
[...]
Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.5.0 at http://localhost:6006/ (Press CTRL+C to quit)
在这个示例中,您应该能够在localhost:6006/
加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。
然后,要么在右上角的下拉菜单中选择“Profile”,要么直接访问localhost:6006/#profile
。可用的追踪信息会显示在左侧的“Runs”下拉菜单中。选择您感兴趣的运行,并在“Tools”下选择trace_viewer
。现在您应该能看到执行时间轴。您可以使用 WASD 键来导航追踪信息,点击或拖动以选择事件并查看底部的更多详细信息。有关使用追踪查看器的更多详细信息,请参阅这些 TensorFlow 文档。
您还可以使用memory_viewer
、op_profile
和graph_viewer
工具。
通过 TensorBoard 手动捕获
以下是从运行中的程序中手动触发 N 秒追踪的捕获说明。
-
启动 TensorBoard 服务器:
tensorboard --logdir /tmp/tensorboard/
在
localhost:6006/
处应该能够加载 TensorBoard。您可以使用--port
标志指定不同的端口。如果在远程服务器上运行 JAX,请参见下面的远程机器上的分析。 -
在您希望进行分析的 Python 程序或进程中,将以下内容添加到开头的某个位置:
import jax.profiler jax.profiler.start_server(9999)
这将启动 TensorBoard 连接到的性能分析器服务器。在继续下一步之前,必须先运行性能分析器服务器。完成后,可以调用
jax.profiler.stop_server()
来关闭它。如果你想要分析一个长时间运行的程序片段(例如长时间的训练循环),你可以将此代码放在程序开头并像往常一样启动程序。如果你想要分析一个短程序(例如微基准测试),一种选择是在 IPython shell 中启动分析器服务器,并在下一步开始捕获后用
%run
运行短程序。另一种选择是在程序开头启动分析器服务器,并使用time.sleep()
给你足够的时间启动捕获。 -
打开
localhost:6006/#profile
,并点击左上角的“CAPTURE PROFILE”按钮。将“localhost:9999”作为分析服务的 URL(这是你在上一步中启动的分析器服务器的地址)。输入你想要进行分析的毫秒数,然后点击“CAPTURE”。 -
如果你想要分析的代码尚未运行(例如在 Python shell 中启动了分析器服务器),请在进行捕获时运行它。
-
捕获完成后,TensorBoard 应会自动刷新。(并非所有 TensorBoard 分析功能都与 JAX 连接,所以初始时看起来可能没有捕获到任何内容。)在左侧的“工具”下,选择
trace_viewer
。现在你应该可以看到执行的时间轴。你可以使用 WASD 键来导航跟踪,点击或拖动选择事件以在底部查看更多详细信息。参见这些 TensorFlow 文档获取有关使用跟踪查看器的更多详细信息。
你也可以使用
memory_viewer
、op_profile
和graph_viewer
工具。
添加自定义跟踪事件
默认情况下,跟踪查看器中的事件大多是低级内部 JAX 函数。你可以使用 jax.profiler.TraceAnnotation
和 jax.profiler.annotate_function()
在你的代码中添加自定义事件和函数。
故障排除
GPU 分析
运行在 GPU 上的程序应该在跟踪查看器顶部附近生成 GPU 流的跟踪。如果只看到主机跟踪,请检查程序日志和/或输出,查看以下错误消息。
如果出现类似 Could not load dynamic library 'libcupti.so.10.1'
的错误
完整错误:
W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcupti.so.10.1'; dlerror: libcupti.so.10.1: cannot open shared object file: No such file or directory
2020-06-12 13:19:59.822799: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1422] function cupti_interface_->Subscribe( &subscriber_, (CUpti_CallbackFunc)ApiCallback, this)failed with error CUPTI could not be loaded or symbol could not be found.
将libcupti.so
的路径添加到环境变量LD_LIBRARY_PATH
中。(尝试使用locate libcupti.so
来找到路径。)例如:
export LD_LIBRARY_PATH=/usr/local/cuda-10.1/extras/CUPTI/lib64/:$LD_LIBRARY_PATH
即使在做了以上步骤后仍然收到 Could not load dynamic library
错误消息,请检查 GPU 跟踪是否仍然显示在跟踪查看器中。有时即使一切正常,它也会出现此消息,因为它在多个位置查找 libcupti
库。
如果出现类似 failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
的错误
完整错误:
E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1445] function cupti_interface_->EnableCallback( 0 , subscriber_, CUPTI_CB_DOMAIN_DRIVER_API, cbid)failed with error CUPTI_ERROR_INSUFFICIENT_PRIVILEGES
2020-06-12 14:31:54.097791: E external/org_tensorflow/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc:1487] function cupti_interface_->ActivityDisable(activity)failed with error CUPTI_ERROR_NOT_INITIALIZED
运行以下命令(注意这将需要重新启动):
echo 'options nvidia "NVreg_RestrictProfilingToAdminUsers=0"' | sudo tee -a /etc/modprobe.d/nvidia-kernel-common.conf
sudo update-initramfs -u
sudo reboot now
查看更多关于此错误的信息,请参阅NVIDIA 的文档。
在远程机器上进行性能分析
如果要分析的 JAX 程序正在远程机器上运行,一种选择是在远程机器上执行上述所有说明(特别是在远程机器上启动 TensorBoard 服务器),然后使用 SSH 本地端口转发从本地访问 TensorBoard Web UI。使用以下 SSH 命令将默认的 TensorBoard 端口 6006 从本地转发到远程机器:
ssh -L 6006:localhost:6006 <remote server address>
或者如果您正在使用 Google Cloud:
$ gcloud compute ssh <machine-name> -- -L 6006:localhost:6006
``` #### 多个 TensorBoard 安装
**如果启动 TensorBoard 失败,并出现类似于`ValueError: Duplicate plugins for name projector`的错误**
这通常是因为安装了两个版本的 TensorBoard 和/或 TensorFlow(例如,`tensorflow`、`tf-nightly`、`tensorboard`和`tb-nightly` pip 包都包含 TensorBoard)。卸载一个 pip 包可能会导致`tensorboard`可执行文件被移除,难以替换,因此可能需要卸载所有内容并重新安装单个版本:
```py
pip uninstall tensorflow tf-nightly tensorboard tb-nightly
pip install tensorflow
Nsight
NVIDIA 的Nsight
工具可用于跟踪和分析 GPU 上的 JAX 代码。有关详情,请参阅Nsight
文档。
设备内存分析
原文:
jax.readthedocs.io/en/latest/device_memory_profiling.html
注意
2023 年 5 月更新:我们建议使用 Tensorboard 进行设备内存分析。在进行分析后,打开 Tensorboard 分析器的 memory_viewer
标签以获取更详细和易于理解的设备内存使用情况。
JAX 设备内存分析器允许我们探索 JAX 程序如何以及为何使用 GPU 或 TPU 内存。例如,它可用于:
-
查明在特定时间点哪些数组和可执行文件位于 GPU 内存中,或者
-
追踪内存泄漏。
安装
JAX 设备内存分析器生成的输出可使用 pprof (google/pprof) 解释。首先按照其 安装说明 安装 pprof
。撰写时,安装 pprof
需要先安装版本为 1.16+ 的 Go,Graphviz,然后运行
go install github.com/google/pprof@latest
安装 pprof
作为 $GOPATH/bin/pprof
,其中 GOPATH
默认为 ~/go
。
注意
来自 google/pprof 的 pprof
版本与作为 gperftools
软件包一部分分发的同名旧工具不同。gperftools
版本的 pprof
不适用于 JAX。
理解 JAX 程序如何使用 GPU 或 TPU 内存
设备内存分析器的常见用途是找出为何 JAX 程序使用大量 GPU 或 TPU 内存,例如调试内存不足问题。
要将设备内存分析保存到磁盘,使用 jax.profiler.save_device_memory_profile()
。例如,考虑以下 Python 程序:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.key(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
如果我们首先运行上述程序,然后执行
pprof --web memory.prof
pprof
打开一个包含设备内存分析调用图格式的 Web 浏览器:
调用图是在每个活动缓冲区分配的 Python 栈的可视化。例如,在这个特定情况下,可视化显示 func2
及其被调用者负责分配了 76.30MB,其中 38.15MB 是在从 func1
到 func2
的调用中分配的。有关如何解释调用图可视化的更多信息,请参阅 pprof 文档。
使用 jax.jit()
编译的函数对设备内存分析器不透明。也就是说,任何在 jit
编译函数内部分配的内存都将归因于整个函数。
在本例中,调用 block_until_ready()
是为了确保在收集设备内存分析之前 func2
完成。有关更多详细信息,请参阅异步调度。
调试内存泄漏
我们还可以使用 JAX 设备内存分析器,通过使用 pprof
来可视化在不同时间点获取的两个设备内存配置文件中的内存使用情况变化,以追踪内存泄漏。例如,考虑以下程序,该程序将 JAX 数组累积到一个不断增长的 Python 列表中。
import jax
import jax.numpy as jnp
import jax.profiler
def afunction():
return jax.random.normal(jax.random.key(77), (1000000,))
z = afunction()
def anotherfunc():
arrays = []
for i in range(1, 10):
x = jax.random.normal(jax.random.key(42), (i, 10000))
arrays.append(x)
x.block_until_ready()
jax.profiler.save_device_memory_profile(f"memory{i}.prof")
anotherfunc()
如果我们仅在执行结束时可视化设备内存配置文件(memory9.prof
),则可能不明显,即 anotherfunc
中的每次循环迭代都会累积更多的设备内存分配:
pprof --web memory9.prof
在 afunction
内部的大型但固定分配主导配置文件,但不会随时间增长。
通过使用 pprof
的 --diff_base
功能 来可视化循环迭代中内存使用情况的变化,我们可以找出程序内存使用量随时间增加的原因:
pprof --web --diff_base memory1.prof memory9.prof
可视化显示,内存增长可以归因于 anotherfunc
中对 normal
的调用。
在 JAX 中进行运行时值调试
是否遇到梯度爆炸?NaN 使你牙齿咬紧?只想查看计算中间值?请查看以下 JAX 调试工具!本页提供了 TL;DR 摘要,并且您可以点击底部的“阅读更多”链接了解更多信息。
目录:
-
使用
jax.debug
进行交互式检查 -
使用 jax.experimental.checkify 进行功能错误检查
-
使用 JAX 的调试标志抛出 Python 错误
使用 jax.debug
进行交互式检查
TL;DR 使用 jax.debug.print()
在 jax.jit
、jax.pmap
和 pjit
装饰的函数中将值打印到 stdout,并使用 jax.debug.breakpoint()
暂停执行编译函数以检查调用堆栈中的值:
import jax
import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("
标签:中文,jnp,JAX,jax,print,文档,debug,checkify
From: https://www.cnblogs.com/apachecn/p/18260395