JAX 中文文档(八)
自动微分手册
原文:
jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
alexbw@, mattjj@
JAX 拥有非常通用的自动微分系统。在这本手册中,我们将介绍许多巧妙的自动微分思想,您可以根据自己的工作进行选择。
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.key(0)
梯度
从grad
开始
您可以使用grad
对函数进行微分:
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
0.070650816
grad
接受一个函数并返回一个函数。如果您有一个评估数学函数 ( f ) 的 Python 函数 f
,那么 grad(f)
是一个评估数学函数 ( \nabla f ) 的 Python 函数。这意味着 grad(f)(x)
表示值 ( \nabla f(x) )。
由于grad
操作函数,您可以将其应用于其自身的输出以多次进行微分:
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
-0.13621868
0.25265405
让我们看看如何在线性逻辑回归模型中使用grad
计算梯度。首先,设置:
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())
使用argnums
参数的grad
函数来相对于位置参数微分函数。
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644 -1.4901346 ]
b_grad -0.29227245
此grad
API 直接对应于 Spivak 经典著作Calculus on Manifolds(1965)中的优秀符号,也用于 Sussman 和 Wisdom 的Structure and Interpretation of Classical Mechanics(2015)及其Functional Differential Geometry(2013)。这两本书都是开放获取的。特别是参见Functional Differential Geometry的“序言”部分,以了解此符号的辩护。
当使用argnums
参数时,如果f
是一个用于计算数学函数 ( f ) 的 Python 函数,则 Python 表达式grad(f, i)
用于评估 ( \partial_i f ) 的 Python 函数。
相对于嵌套列表、元组和字典进行微分
使用标准的 Python 容器进行微分是完全有效的,因此可以随意使用元组、列表和字典(以及任意嵌套)。
def loss2(params_dict):
preds = predict(params_dict['W'], params_dict['b'], inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
print(grad(loss2)({'W': W, 'b': b}))
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}
您可以注册您自己的容器类型以便不仅与grad
一起工作,还可以与所有 JAX 转换(jit
、vmap
等)一起工作。
使用value_and_grad
评估函数及其梯度
另一个方便的函数是value_and_grad
,可以高效地计算函数值及其梯度值:
from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b))
loss value 3.0519385
loss value 3.0519385
与数值差分进行对比
导数的一个很好的特性是它们很容易用有限差分进行检查:
# Set a step size for finite differences calculations
eps = 1e-4
# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))
# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117
JAX 提供了一个简单的便利函数,本质上执行相同的操作,但可以检查任何您喜欢的微分顺序:
from jax.test_util import check_grads
check_grads(loss, (W, b), order=2) # check up to 2nd order derivatives
使用 grad
-of-grad
进行 Hessian 向量乘积
使用高阶 grad
可以构建一个 Hessian 向量乘积函数。 (稍后我们将编写一个更高效的实现,该实现混合了前向和反向模式,但这个实现将纯粹使用反向模式。)
在最小化平滑凸函数的截断牛顿共轭梯度算法或研究神经网络训练目标的曲率(例如1,2,3,4)中,Hessian 向量乘积函数非常有用。
对于一个标量值函数 ( f : \mathbb{R}^n \to \mathbb{R} ),具有连续的二阶导数(因此 Hessian 矩阵是对称的),点 ( x \in \mathbb{R}^n ) 处的 Hessian 被写为 (\partial² f(x))。然后,Hessian 向量乘积函数能够评估
(\qquad v \mapsto \partial² f(x) \cdot v)
对于任意 ( v \in \mathbb{R}^n )。
窍门在于不要实例化完整的 Hessian 矩阵:如果 ( n ) 很大,例如在神经网络的背景下可能是百万或十亿级别,那么可能无法存储。
幸运的是,grad
已经为我们提供了一种编写高效的 Hessian 向量乘积函数的方法。我们只需使用下面的身份证
(\qquad \partial² f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)),
其中 ( g(x) = \partial f(x) \cdot v ) 是一个新的标量值函数,它将 ( f ) 在 ( x ) 处的梯度与向量 ( v ) 点乘。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad
高效的地方。
在 JAX 代码中,我们可以直接写成这样:
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这个例子表明,您可以自由使用词汇闭包,而 JAX 绝不会感到不安或困惑。
一旦我们看到如何计算密集的 Hessian 矩阵,我们将在几个单元格下检查此实现。我们还将编写一个更好的版本,该版本同时使用前向模式和反向模式。
使用 jacfwd
和 jacrev
计算 Jacobians 和 Hessians
您可以使用 jacfwd
和 jacrev
函数计算完整的 Jacobian 矩阵:
from jax import jacfwd, jacrev
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)
J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J)
jacfwd result, with shape (4, 3)
[[ 0.05981758 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188288 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
这两个函数计算相同的值(直到机器数学),但它们在实现上有所不同:jacfwd
使用前向模式自动微分,对于“高”的 Jacobian 矩阵更有效,而 jacrev
使用反向模式,对于“宽”的 Jacobian 矩阵更有效。对于接近正方形的矩阵,jacfwd
可能比 jacrev
有优势。
您还可以在容器类型中使用 jacfwd
和 jacrev
:
def predict_dict(params, inputs):
return predict(params['W'], params['b'], inputs)
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
print("Jacobian from {} to logits is".format(k))
print(v)
Jacobian from W to logits is
[[ 0.05981757 0.12883787 0.08857603]
[ 0.04015916 -0.04928625 0.00684531]
[ 0.12188289 0.01406341 -0.3047072 ]
[ 0.00140431 -0.00472531 0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]
关于前向模式和反向模式的更多细节,以及如何尽可能高效地实现 jacfwd
和 jacrev
,请继续阅读!
使用两个这些函数的复合给我们一种计算密集的 Hessian 矩阵的方法:
def hessian(f):
return jacfwd(jacrev(f))
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)
hessian, with shape (4, 3, 3)
[[[ 0.02285465 0.04922541 0.03384247]
[ 0.04922541 0.10602397 0.07289147]
[ 0.03384247 0.07289147 0.05011288]]
[[-0.03195215 0.03921401 -0.00544639]
[ 0.03921401 -0.04812629 0.00668421]
[-0.00544639 0.00668421 -0.00092836]]
[[-0.01583708 -0.00182736 0.03959271]
[-0.00182736 -0.00021085 0.00456839]
[ 0.03959271 0.00456839 -0.09898177]]
[[-0.00103524 0.00348343 -0.00194457]
[ 0.00348343 -0.01172127 0.0065432 ]
[-0.00194457 0.0065432 -0.00365263]]]
这种形状是合理的:如果我们从一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m) 开始,那么在点 (x \in \mathbb{R}^n) 我们期望得到以下形状
-
(f(x) \in \mathbb{R}^m),在 (x) 处的 (f) 的值,
-
(\partial f(x) \in \mathbb{R}^{m \times n}),在 (x) 处的雅可比矩阵,
-
(\partial² f(x) \in \mathbb{R}^{m \times n \times n}),在 (x) 处的 Hessian 矩阵,
以及其他一些内容。
要实现 hessian
,我们可以使用 jacfwd(jacrev(f))
或 jacrev(jacfwd(f))
或这两者的任何组合。但是前向超过反向通常是最有效的。这是因为在内部雅可比计算中,我们通常是在不同 iating 一个函数宽雅可比(也许像损失函数 (f : \mathbb{R}^n \to \mathbb{R})),而在外部雅可比计算中,我们是在不同 iating 具有方雅可比的函数(因为 (\nabla f : \mathbb{R}^n \to \mathbb{R}^n)),这就是前向模式胜出的地方。
制造过程:两个基础的自动微分函数
雅可比-向量积(JVPs,也称为前向模式自动微分)
JAX 包括前向模式和反向模式自动微分的高效和通用实现。熟悉的 grad
函数建立在反向模式之上,但要解释两种模式的区别,以及每种模式何时有用,我们需要一些数学背景。
数学中的雅可比向量积
在数学上,给定一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m),在输入点 (x \in \mathbb{R}^n) 处评估的雅可比矩阵 (\partial f(x)),通常被视为一个 (\mathbb{R}^m \times \mathbb{R}^n) 中的矩阵:
(\qquad \partial f(x) \in \mathbb{R}^{m \times n}).
但我们也可以将 (\partial f(x)) 看作是一个线性映射,它将 (f) 的定义域在点 (x) 的切空间(即另一个 (\mathbb{R}^n) 的副本)映射到 (f) 的值域在点 (f(x)) 的切空间(一个 (\mathbb{R}^m) 的副本):
(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m).
此映射称为 (f) 在 (x) 处的推前映射。雅可比矩阵只是标准基中这个线性映射的矩阵。
如果我们不确定一个特定的输入点 (x),那么我们可以将函数 (\partial f) 视为首先接受一个输入点并返回该输入点处的雅可比线性映射:
(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m)。
特别是,我们可以解开事物,这样给定输入点 (x \in \mathbb{R}^n) 和切向量 (v \in \mathbb{R}^n),我们得到一个输出切向量在 (\mathbb{R}^m) 中。我们称这种映射,从 ((x, v)) 对到输出切向量,为雅可比向量积,并将其写为
(\qquad (x, v) \mapsto \partial f(x) v)
JAX 代码中的雅可比向量积
回到 Python 代码中,JAX 的 jvp
函数模拟了这种转换。给定一个评估 (f) 的 Python 函数,JAX 的 jvp
是获取评估 ((x, v) \mapsto (f(x), \partial f(x) v)) 的 Python 函数的一种方法。
from jax import jvp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
key, subkey = random.split(key)
v = random.normal(subkey, W.shape)
# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,))
用类似 Haskell 的类型签名来说,我们可以写成
jvp :: (a -> b) -> a -> T a -> (b, T b)
在这里,我们使用 T a
来表示 a
的切线空间的类型。简言之,jvp
接受一个类型为 a -> b
的函数作为参数,一个类型为 a
的值,以及一个类型为 T a
的切线向量值。它返回一个由类型为 b
的值和类型为 T b
的输出切线向量组成的对。
jvp
转换后的函数的评估方式与原函数类似,但与每个类型为 a
的原始值配对时,它会沿着类型为 T a
的切线值进行推进。对于原始函数将应用的每个原始数值操作,jvp
转换后的函数会执行一个“JVP 规则”,该规则同时在这些原始值上评估原始数值,并应用其 JVP。
该评估策略对计算复杂度有一些直接影响:由于我们在进行评估时同时评估 JVP,因此我们不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp
转换后的函数的 FLOP 成本约为评估函数的成本的 3 倍(例如对于评估原始函数的一个单位工作,如 sin(x)
;一个单位用于线性化,如 cos(x)
;和一个单位用于将线性化函数应用于向量,如 cos_x * v
)。换句话说,对于固定的原始点 (x),我们可以以大致相同的边际成本评估 (v \mapsto \partial f(x) \cdot v),如同评估 (f) 一样。
那么内存复杂度听起来非常有说服力!那为什么我们在机器学习中很少见到正向模式呢?
要回答这个问题,首先考虑如何使用 JVP 构建完整的 Jacobian 矩阵。如果我们将 JVP 应用于一个单位切线向量,它会显示出我们输入的非零条目对应的 Jacobian 矩阵的一列。因此,我们可以逐列地构建完整的 Jacobian 矩阵,获取每列的成本大约与一个函数评估相同。对于具有“高”Jacobian 的函数来说,这将是高效的,但对于“宽”Jacobian 来说则效率低下。
如果你在机器学习中进行基于梯度的优化,你可能想要最小化一个从 (\mathbb{R}^n) 中的参数到 (\mathbb{R}) 中标量损失值的损失函数。这意味着这个函数的雅可比矩阵是一个非常宽的矩阵:(\partial f(x) \in \mathbb{R}^{1 \times n}),我们通常将其视为梯度向量 (\nabla f(x) \in \mathbb{R}^n)。逐列构建这个矩阵,每次调用需要类似数量的浮点运算来评估原始函数,看起来确实效率低下!特别是对于训练神经网络,其中 (f) 是一个训练损失函数,而 (n) 可以是百万或十亿级别,这种方法根本不可扩展。
为了更好地处理这类函数,我们只需要使用反向模式。### 向量-雅可比积(VJPs,又称反向自动微分)
在前向模式中,我们得到了一个用于评估雅可比向量积的函数,然后我们可以使用它逐列构建雅可比矩阵;而反向模式则是一种获取用于评估向量-雅可比积(或等效地雅可比-转置向量积)的函数的方式,我们可以用它逐行构建雅可比矩阵。
数学中的 VJPs
再次考虑一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m)。从我们对 JVP 的表示开始,对于 VJP 的表示非常简单:
(\qquad (x, v) \mapsto v \partial f(x)),
其中 (v) 是在 (x) 处 (f) 的余切空间的元素(同构于另一个 (\mathbb{R}^m) 的副本)。在严格时,我们应该将 (v) 视为一个线性映射 (v : \mathbb{R}^m \to \mathbb{R}),当我们写 (v \partial f(x)) 时,我们意味着函数复合 (v \circ \partial f(x)),其中类型之间的对应关系是因为 (\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m)。但在通常情况下,我们可以将 (v) 与 (\mathbb{R}^m) 中的一个向量等同看待,并几乎可以互换使用,就像有时我们可以在“列向量”和“行向量”之间轻松切换而不加过多评论一样。
有了这个认识,我们可以将 VJP 的线性部分看作是 JVP 线性部分的转置(或共轭伴随):
(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v).
对于给定点 (x),我们可以将签名写为
(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n).
对应的余切空间映射通常称为[拉回](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) (f) 在 (x) 处的。对我们而言,关键在于它从类似 (f) 输出的东西到类似 (f) 输入的东西,就像我们从一个转置线性函数所期望的那样。
JAX 代码中的 VJPs
从数学切换回 Python,JAX 函数 vjp
可以接受一个用于评估 (f) 的 Python 函数,并给我们返回一个用于评估 VJP ((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))) 的 Python 函数。
from jax import vjp
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
y, vjp_fun = vjp(f, W)
key, subkey = random.split(key)
u = random.normal(subkey, y.shape)
# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u)
就类似 Haskell 类型签名的形式来说,我们可以写成
vjp :: (a -> b) -> a -> (b, CT b -> CT a)
在这里,我们使用 CT a
表示 a
的余切空间的类型。换句话说,vjp
接受类型为 a -> b
的函数和类型为 a
的点作为参数,并返回一个由类型为 b
的值和类型为 CT b -> CT a
的线性映射组成的对。
这很棒,因为它让我们一次一行地构建雅可比矩阵,并且评估 ((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))) 的 FLOP 成本仅约为评估 (f) 的三倍。特别是,如果我们想要函数 (f : \mathbb{R}^n \to \mathbb{R}) 的梯度,我们可以一次性完成。这就是 grad
对基于梯度的优化非常高效的原因,即使是对于数百万或数十亿个参数的神经网络训练损失函数这样的目标。
这里有一个成本:虽然 FLOP 友好,但内存随计算深度而增加。而且,该实现在传统上比前向模式更为复杂,但 JAX 对此有一些窍门(这是未来笔记本的故事!)。
关于反向模式的工作原理,可以查看2017 年深度学习暑期学校的教程视频。
使用 VJPs 的矢量值梯度
如果你对使用矢量值梯度(如 tf.gradients
)感兴趣:
from jax import vjp
def vgrad(f, x):
y, vjp_fn = vjp(f, x)
return vjp_fn(jnp.ones(y.shape))[0]
print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2))))
[[6\. 6.]
[6\. 6.]]
使用前向和反向模式的黑塞矢量积
在前面的部分中,我们仅使用反向模式实现了一个黑塞-矢量积函数(假设具有连续二阶导数):
def hvp(f, x, v):
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
这是高效的,但我们甚至可以更好地节省一些内存,通过使用前向模式和反向模式。
从数学上讲,给定一个要区分的函数 (f : \mathbb{R}^n \to \mathbb{R}),要线性化函数的一个点 (x \in \mathbb{R}^n),以及一个向量 (v \in \mathbb{R}^n),我们想要的黑塞-矢量积函数是
((x, v) \mapsto \partial² f(x) v)
考虑助手函数 (g : \mathbb{R}^n \to \mathbb{R}^n) 定义为 (f) 的导数(或梯度),即 (g(x) = \partial f(x))。我们所需的只是它的 JVP,因为这将给我们
((x, v) \mapsto \partial g(x) v = \partial² f(x) v).
我们几乎可以直接将其转换为代码:
from jax import jvp, grad
# forward-over-reverse
def hvp(f, primals, tangents):
return jvp(grad(f), primals, tangents)[1]
更好的是,由于我们不需要直接调用 jnp.dot
,这个 hvp
函数可以处理任何形状的数组以及任意的容器类型(如嵌套列表/字典/元组中存储的向量),甚至与jax.numpy
没有任何依赖。
这是如何使用它的示例:
def f(X):
return jnp.sum(jnp.tanh(X)**2)
key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)
print(jnp.allclose(ans1, ans2, 1e-4, 1e-4))
True
另一种你可能考虑写这个的方法是使用反向-前向模式:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
g = lambda primals: jvp(f, primals, tangents)[1]
return grad(g)(primals)
不过,这不是很好,因为前向模式的开销比反向模式小,由于外部区分算子要区分比内部更大的计算,将前向模式保持在外部是最好的:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
x, = primals
v, = tangents
return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))
print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)
Forward over reverse
4.74 ms ± 157 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.46 ms ± 5.05 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
14.3 ms ± 7.71 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
57.7 ms ± 1.32 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
组成 VJP、JVP 和 vmap
雅可比-矩阵和矩阵-雅可比乘积
现在我们有jvp
和vjp
变换,它们为我们提供了推送或拉回单个向量的函数,我们可以使用 JAX 的vmap
变换一次推送和拉回整个基。特别是,我们可以用它来快速编写矩阵-雅可比和雅可比-矩阵乘积。
# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)
# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
return jnp.vstack([vjp_fun(mi) for mi in M])
# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
y, vjp_fun = vjp(f, x)
outs, = vmap(vjp_fun)(M)
return outs
key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)
loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)
print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'
Non-vmapped Matrix-Jacobian product
168 ms ± 260 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Matrix-Jacobian product
6.39 ms ± 49.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
/tmp/ipykernel_1379/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0\. In a future JAX release this will be an error.
return jnp.vstack([vjp_fun(mi) for mi in M])
def loop_jmp(f, W, M):
# jvp immediately returns the primal and tangent values as a tuple,
# so we'll compute and select the tangents in a list comprehension
return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])
def vmap_jmp(f, W, M):
_jvp = lambda s: jvp(f, (W,), (s,))[1]
return vmap(_jvp)(M)
num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)
loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)
assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'
Non-vmapped Jacobian-Matrix product
290 ms ± 437 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Vmapped Jacobian-Matrix product
3.29 ms ± 22.5 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
jacfwd
和jacrev
的实现
现在我们已经看到了快速的雅可比-矩阵和矩阵-雅可比乘积,写出jacfwd
和jacrev
并不难。我们只需使用相同的技术一次推送或拉回整个标准基(等同于单位矩阵)。
from jax import jacrev as builtin_jacrev
def our_jacrev(f):
def jacfun(x):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
# entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!'
from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!'
有趣的是,Autograd做不到这一点。我们在 Autograd 中反向模式jacobian
的实现必须逐个向量地拉回,使用外层循环map
。逐个向量地通过计算远不及使用vmap
一次将所有内容批处理高效。
另一件 Autograd 做不到的事情是jit
。有趣的是,无论您在要进行微分的函数中使用多少 Python 动态性,我们总是可以在计算的线性部分上使用jit
。例如:
def f(x):
try:
if x < 3:
return 2 * x ** 3
else:
raise ValueError
except ValueError:
return jnp.pi * x
y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.))
(Array(3.1415927, dtype=float32, weak_type=True),)
复数和微分
JAX 在复数和微分方面表现出色。为了支持全纯和非全纯微分,理解 JVP 和 VJP 很有帮助。
考虑一个复到复的函数 (f: \mathbb{C} \to \mathbb{C}) 并将其与相应的函数 (g: \mathbb{R}² \to \mathbb{R}²) 对应起来,
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def g(x, y):
return (u(x, y), v(x, y))
也就是说,我们分解了 (f(z) = u(x, y) + v(x, y) i) 其中 (z = x + y i),并将 (\mathbb{C}) 与 (\mathbb{R}²) 对应起来得到了 (g)。
由于 (g) 只涉及实数输入和输出,我们已经知道如何为它编写雅可比-向量积,例如给定切向量 ((c, d) \in \mathbb{R}²),
(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \ d \end{bmatrix}).
要获得应用于切向量 (c + di \in \mathbb{C}) 的原始函数 (f) 的 JVP,我们只需使用相同的定义,并将结果标识为另一个复数,
(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \ d \end{bmatrix}).
这就是我们对复到复函数 (f) 的 JVP 的定义!注意,无论 (f) 是否全纯,JVP 都是明确的。
这里是一个检查:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# tangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_dot = c + d * 1j
# check jvp
_, ans = jvp(fun, (z,), (z_dot,))
expected = (grad(u, 0)(x, y) * c +
grad(u, 1)(x, y) * d +
grad(v, 0)(x, y) * c * 1j+
grad(v, 1)(x, y) * d * 1j)
print(jnp.allclose(ans, expected))
check(0)
check(1)
check(2)
True
True
True
那么 VJP 呢?我们做了类似的事情:对于余切向量 (c + di \in \mathbb{C}),我们将 (f) 的 VJP 定义为
((c + di)^* ; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \ -i \end{bmatrix}).
为什么要有负号?这些只是为了处理复共轭,以及我们正在处理余切向量的事实。
这里是 VJP 规则的检查:
def check(seed):
key = random.key(seed)
# random coeffs for u and v
key, subkey = random.split(key)
a, b, c, d = random.uniform(subkey, (4,))
def fun(z):
x, y = jnp.real(z), jnp.imag(z)
return u(x, y) + v(x, y) * 1j
def u(x, y):
return a * x + b * y
def v(x, y):
return c * x + d * y
# primal point
key, subkey = random.split(key)
x, y = random.uniform(subkey, (2,))
z = x + y * 1j
# cotangent vector
key, subkey = random.split(key)
c, d = random.uniform(subkey, (2,))
z_bar = jnp.array(c + d * 1j) # for dtype control
# check vjp
_, fun_vjp = vjp(fun, z)
ans, = fun_vjp(z_bar)
expected = (grad(u, 0)(x, y) * c +
grad(v, 0)(x, y) * (-d) +
grad(u, 1)(x, y) * c * (-1j) +
grad(v, 1)(x, y) * (-d) * (-1j))
assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)
check(0)
check(1)
check(2)
方便的包装器如grad
、jacfwd
和jacrev
有什么作用?
对于(\mathbb{R} \to \mathbb{R})函数,回想我们定义grad(f)(x)
为vjp(f, x)1
,这是因为将 VJP 应用于1.0
值会显示梯度(即雅可比矩阵或导数)。对于(\mathbb{C} \to \mathbb{R})函数,我们可以做同样的事情:我们仍然可以使用1.0
作为余切向量,而我们得到的只是总结完整雅可比矩阵的一个复数结果:
def f(z):
x, y = jnp.real(z), jnp.imag(z)
return x**2 + y**2
z = 3. + 4j
grad(f)(z)
Array(6.-8.j, dtype=complex64)
对于一般的(\mathbb{C} \to \mathbb{C})函数,雅可比矩阵有 4 个实值自由度(如上面的 2x2 雅可比矩阵),因此我们不能希望在一个复数中表示所有这些自由度。但对于全纯函数,我们可以!全纯函数恰好是一个(\mathbb{C} \to \mathbb{C})函数,其导数可以表示为一个单一的复数。(柯西-黎曼方程确保上述 2x2 雅可比矩阵在复平面内的作用具有复数乘法下的比例和旋转矩阵的特殊形式。)我们可以使用一个vjp
调用并带有1.0
的余切向量来揭示那一个复数。
因为这仅适用于全纯函数,为了使用这个技巧,我们需要向 JAX 保证我们的函数是全纯的;否则,在复数输出函数上使用grad
时,JAX 会引发错误:
def f(z):
return jnp.sin(z)
z = 3. + 4j
grad(f, holomorphic=True)(z)
Array(-27.034945-3.8511531j, dtype=complex64, weak_type=True)
holomorphic=True
的承诺仅仅是在输出是复数值时禁用错误。当函数不是全纯时,我们仍然可以写holomorphic=True
,但得到的答案将不表示完整的雅可比矩阵。相反,它将是在我们只丢弃输出的虚部的函数的雅可比矩阵。
def f(z):
return jnp.conjugate(z)
z = 3. + 4j
grad(f, holomorphic=True)(z) # f is not actually holomorphic!
Array(1.-0.j, dtype=complex64, weak_type=True)
在这里grad
的工作有一些有用的结论:
-
我们可以在全纯的(\mathbb{C} \to \mathbb{C})函数上使用
grad
。 -
我们可以使用
grad
来优化(\mathbb{C} \to \mathbb{R})函数,例如复参数x
的实值损失函数,通过朝着grad(f)(x)
的共轭方向迈出步伐。 -
如果我们有一个(\mathbb{R} \to \mathbb{R})的函数,它恰好在内部使用一些复数运算(其中一些必须是非全纯的,例如在卷积中使用的 FFT),那么
grad
仍然有效,并且我们得到与仅使用实数值的实现相同的结果。
在任何情况下,JVPs 和 VJPs 都是明确的。如果我们想计算非全纯函数(\mathbb{C} \to \mathbb{C})的完整 Jacobian 矩阵,我们可以用 JVPs 或 VJPs 来做到!
你应该期望复数在 JAX 中的任何地方都能正常工作。这里是通过复杂矩阵的 Cholesky 分解进行微分:
A = jnp.array([[5., 2.+3j, 5j],
[2.-3j, 7., 1.+7j],
[-5j, 1.-7j, 12.]])
def f(X):
L = jnp.linalg.cholesky(X)
return jnp.sum((L - jnp.sin(L))**2)
grad(f, holomorphic=True)(A)
Array([[-0.7534186 +0.j , -3.0509028 -10.940544j ,
5.9896846 +3.5423026j],
[-3.0509028 +10.940544j , -8.904491 +0.j ,
-5.1351523 -6.559373j ],
[ 5.9896846 -3.5423026j, -5.1351523 +6.559373j ,
0.01320427 +0.j ]], dtype=complex64)
更高级的自动微分
在这本笔记本中,我们通过一些简单的,然后逐渐复杂的应用中,使用 JAX 中的自动微分。我们希望现在您感觉在 JAX 中进行导数运算既简单又强大。
还有很多其他自动微分的技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分手册”中进行涵盖:
-
高斯-牛顿向量乘积,一次线性化
-
自定义的 VJPs 和 JVPs
-
在固定点处高效地求导
-
使用随机的 Hessian-vector products 来估计 Hessian 的迹。
-
仅使用反向模式自动微分的前向模式自动微分。
-
对自定义数据类型进行导数计算。
-
检查点(二项式检查点用于高效的反向模式,而不是模型快照)。
-
优化 VJPs 通过 Jacobian 预积累。
JAX 可转换的 Python 函数的自定义导数规则
原文:
jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
mattjj@ Mar 19 2020, last updated Oct 14 2020
JAX 中定义微分规则的两种方式:
-
使用
jax.custom_jvp
和jax.custom_vjp
来为已经可转换为 JAX 的 Python 函数定义自定义微分规则;以及 -
定义新的
core.Primitive
实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或一般数值计算系统)的函数。
本笔记本讨论的是 #1. 要了解关于 #2 的信息,请参阅关于添加原语的笔记本。
关于 JAX 自动微分 API 的介绍,请参阅自动微分手册。本笔记本假定读者已对jax.jvp和jax.grad,以及 JVPs 和 VJPs 的数学含义有一定了解。
TL;DR
使用 jax.custom_jvp
进行自定义 JVPs
import jax.numpy as jnp
from jax import custom_jvp
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
return primal_out, tangent_out
from jax import jvp, grad
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
# Equivalent alternative using the defjvps convenience wrapper
@custom_jvp
def f(x, y):
return jnp.sin(x) * y
f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot)
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.))
2.7278922
2.7278922
-1.2484405
-1.2484405
使用 jax.custom_vjp
进行自定义 VJPs
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res # Gets residuals computed in f_fwd
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
示例问题
要了解 jax.custom_jvp
和 jax.custom_vjp
所解决的问题,我们可以看几个例子。有关 jax.custom_jvp
和 jax.custom_vjp
API 的更详细介绍在下一节中。
数值稳定性
jax.custom_jvp
的一个应用是提高微分的数值稳定性。
假设我们想编写一个名为 log1pexp
的函数,用于计算 (x \mapsto \log ( 1 + e^x ))。我们可以使用 jax.numpy
来写:
import jax.numpy as jnp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp(3.)
Array(3.0485873, dtype=float32, weak_type=True)
因为它是用 jax.numpy
编写的,所以它是 JAX 可转换的:
from jax import jit, grad, vmap
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
但这里存在一个数值稳定性问题:
print(grad(log1pexp)(100.))
nan
那似乎不对!毕竟,(x \mapsto \log (1 + e^x)) 的导数是 (x \mapsto \frac{e^x}{1 + e^x}),因此对于大的 (x) 值,我们期望值约为 1。
通过查看梯度计算的 jaxpr,我们可以更深入地了解发生了什么:
from jax import make_jaxpr
make_jaxpr(grad(log1pexp))(100.)
{ lambda ; a:f32[]. let
b:f32[] = exp a
c:f32[] = add 1.0 b
_:f32[] = log c
d:f32[] = div 1.0 c
e:f32[] = mul d b
in (e,) }
通过分析 jaxpr 如何评估,我们可以看到最后一行涉及的值相乘会导致浮点数计算四舍五入为 0 和 (\infty),这从未是一个好主意。也就是说,我们实际上在评估大数值的情况下,计算的是 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x)
,这实际上会变成 0. * jnp.inf
。
而不是生成这样大和小的值,希望浮点数能够提供的取消,我们宁愿将导数函数表达为一个更稳定的数值程序。特别地,我们可以编写一个程序,更接近地评估相等的数学表达式 (1 - \frac{1}{1 + e^x}),看不到取消。
这个问题很有趣,因为即使我们的log1pexp
的定义已经可以进行 JAX 微分(并且可以使用jit
、vmap
等转换),我们对应用标准自动微分规则到组成log1pexp
并组合结果的结果并不满意。相反,我们想要指定整个函数log1pexp
如何作为一个单位进行微分,从而更好地安排这些指数。
这是关于 Python 函数的自定义导数规则的一个应用,这些函数已经可以使用 JAX 进行转换:指定如何对复合函数进行微分,同时仍然使用其原始的 Python 定义进行其他转换(如jit
、vmap
等)。
这里是使用jax.custom_jvp
的解决方案:
from jax import custom_jvp
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = log1pexp(x)
ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
return ans, ans_dot
print(grad(log1pexp)(100.))
1.0
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
这里是一个defjvps
方便包装,来表达同样的事情:
@custom_jvp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t)
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.)))
1.0
3.0485873
0.95257413
[0.5 0.7310586 0.8807971]
强制执行微分约定
一个相关的应用是强制执行微分约定,也许在边界处。
考虑函数 (f : \mathbb{R}+ \to \mathbb{R}+),其中 (f(x) = \frac{x}{1 + \sqrt{x}}),其中我们取 (\mathbb{R}_+ = [0, \infty))。我们可以像这样实现 (f) 的程序:
def f(x):
return x / (1 + jnp.sqrt(x))
作为在(\mathbb{R})上的数学函数(完整的实数线),(f) 在零点是不可微的(因为从左侧定义导数的极限不存在)。相应地,自动微分产生一个nan
值:
print(grad(f)(0.))
nan
但是数学上,如果我们将 (f) 视为 (\mathbb{R}_+) 上的函数,则它在 0 处是可微的 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和例子 10.1.6]。或者,我们可能会说,作为一个惯例,我们希望考虑从右边的方向导数。因此,对于 Python 函数grad(f)
在0.0
处返回 1.0 是有意义的值。默认情况下,JAX 对微分的机制假设所有函数在(\mathbb{R})上定义,因此这里并不会产生1.0
。
我们可以使用自定义的 JVP 规则!特别地,我们可以定义 JVP 规则,关于导数函数 (x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)²}) 在 (\mathbb{R}_+) 上,
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
return ans, ans_dot
print(grad(f)(0.))
1.0
这里是方便包装版本:
@custom_jvp
def f(x):
return x / (1 + jnp.sqrt(x))
f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t)
print(grad(f)(0.))
1.0
梯度剪裁
虽然在某些情况下,我们想要表达一个数学微分计算,在其他情况下,我们甚至可能想要远离数学,来调整自动微分的计算。一个典型的例子是反向模式梯度剪裁。
对于梯度剪裁,我们可以使用jnp.clip
和一个jax.custom_vjp
仅逆模式规则:
from functools import partial
from jax import custom_vjp
@custom_vjp
def clip_gradient(lo, hi, x):
return x # identity function
def clip_gradient_fwd(lo, hi, x):
return x, (lo, hi) # save bounds as residuals
def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # use None to indicate zero cotangents for lo and hi
clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)
import matplotlib.pyplot as plt
from jax import vmap
t = jnp.linspace(0, 10, 1000)
plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t))
[<matplotlib.lines.Line2D at 0x7f43dfc210f0>]
def clip_sin(x):
x = clip_gradient(-0.75, 0.75, x)
return jnp.sin(x)
plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t))
[<matplotlib.lines.Line2D at 0x7f43ddb15fc0>]
Python 调试
另一个应用,是受开发工作流程而非数值驱动的动机,是在反向模式自动微分的后向传递中设置pdb
调试器跟踪。
在尝试追踪nan
运行时错误的来源,或者仅仔细检查传播的余切(梯度)值时,可以在反向传递中的特定点插入调试器非常有用。您可以使用jax.custom_vjp
来实现这一点。
我们将在下一节中推迟一个示例。
迭代实现的隐式函数微分
这个例子涉及到了数学中的深层问题!
另一个应用jax.custom_vjp
是对可通过jit
、vmap
等转换为 JAX 但由于某些原因不易 JAX 可区分的函数进行反向模式微分,也许是因为涉及lax.while_loop
。(无法生成 XLA HLO 程序有效计算 XLA HLO While 循环的反向模式导数,因为这将需要具有无界内存使用的程序,这在 XLA HLO 中是不可能表达的,至少不是通过通过 infeed/outfeed 的副作用交互。)
例如,考虑这个fixed_point
例程,它通过在while_loop
中迭代应用函数来计算一个不动点:
from jax.lax import while_loop
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
这是一种通过迭代应用函数(x_{t+1} = f(a, x_t))来数值解方程(x = f(a, x))以计算(x)的迭代过程,直到(x_{t+1})足够接近(x_t)。结果(x^)取决于参数(a),因此我们可以认为存在一个由方程(x = f(a, x))隐式定义的函数(a \mapsto x^(a))。
我们可以使用fixed_point
运行迭代过程以收敛,例如运行牛顿法来计算平方根,只执行加法、乘法和除法:
def newton_sqrt(a):
update = lambda a, x: 0.5 * (x + a / x)
return fixed_point(update, a, a)
print(newton_sqrt(2.))
1.4142135
我们也可以对函数进行vmap
或jit
处理:
print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.])))
[1\. 1.4142135 1.7320509 2\. ]
由于while_loop
,我们无法应用反向模式自动微分,但事实证明我们也不想这样做:我们可以利用数学结构做一些更节省内存(在这种情况下也更节省 FLOP)的事情!我们可以使用隐函数定理[Bertsekas 的《非线性规划,第二版》附录 A.25],它保证(在某些条件下)我们即将使用的数学对象的存在。本质上,我们在线性化解决方案处进行线性化,并迭代解这些线性方程以计算我们想要的导数。
再次考虑方程(x = f(a, x))和函数(x^)。我们想要评估向量-Jacobian 乘积,如(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^(a_0))。
至少在我们想要求微分的点(a_0)周围的开放邻域内,让我们假设方程(x^(a) = f(a, x^(a)))对所有(a)都成立。由于两边作为(a)的函数相等,它们的导数也必须相等,所以让我们分别对两边进行微分:
(\qquad \partial x^(a) = \partial_0 f(a, x^(a)) + \partial_1 f(a, x^(a)) \partial x^(a))。
设置(A = \partial_1 f(a_0, x^(a_0)))和(B = \partial_0 f(a_0, x^(a_0))),我们可以更简单地写出我们想要的数量为
(\qquad \partial x^(a_0) = B + A \partial x^(a_0)),
或者,通过重新排列,
(\qquad \partial x^*(a_0) = (I - A)^{-1} B)。
这意味着我们可以评估向量-Jacobian 乘积,如
(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B),
其中(w^\mathsf{T} = v^\mathsf{T} (I - A){-1}),或者等效地(w\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A),或者等效地(w\mathsf{T})是映射(u\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A)的不动点。最后一个描述使我们可以根据对fixed_point
的调用来编写fixed_point
的 VJP!此外,在展开(A)和(B)之后,我们可以看到我们只需要在((a_0, x^*(a_0)))处评估(f)的 VJP。
这里是要点:
from jax import vjp
@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
def cond_fun(carry):
x_prev, x = carry
return jnp.abs(x_prev - x) > 1e-6
def body_fun(carry):
_, x = carry
return x, f(a, x)
_, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
return x_star
def fixed_point_fwd(f, a, x_init):
x_star = fixed_point(f, a, x_init)
return x_star, (a, x_star)
def fixed_point_rev(f, res, x_star_bar):
a, x_star = res
_, vjp_a = vjp(lambda a: f(a, x_star), a)
a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
return x_star_bar + vjp_x(u)[0]
fixed_point.defvjp(fixed_point_fwd, fixed_point_rev)
print(newton_sqrt(2.))
1.4142135
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.))
0.35355338
-0.088388346
我们可以通过对 jnp.sqrt
进行微分来检查我们的答案,它使用了完全不同的实现:
print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.))
0.35355338
-0.08838835
这种方法的一个限制是参数f
不能涉及到任何参与微分的值。也就是说,你可能注意到我们在fixed_point
的参数列表中明确保留了参数a
。对于这种用例,考虑使用低级原语lax.custom_root
,它允许在闭合变量中进行带有自定义根查找函数的导数。
使用 jax.custom_jvp
和 jax.custom_vjp
API 的基本用法
使用 jax.custom_jvp
来定义前向模式(以及间接地,反向模式)规则
这里是使用 jax.custom_jvp
的典型基本示例,其中注释使用Haskell-like type signatures。
from jax import custom_jvp
import jax.numpy as jnp
# f :: a -> b
@custom_jvp
def f(x):
return jnp.sin(x)
# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
f.defjvp(f_jvp)
<function __main__.f_jvp(primals, tangents)>
from jax import jvp
print(f(3.))
y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot)
0.14112
0.14112
-0.9899925
简言之,我们从一个原始函数f
开始,它接受类型为a
的输入并产生类型为b
的输出。我们与之关联一个 JVP 规则函数f_jvp
,它接受一对输入,表示类型为a
的原始输入和类型为T a
的相应切线输入,并产生一对输出,表示类型为b
的原始输出和类型为T b
的切线输出。切线输出应该是切线输入的线性函数。
你还可以使用 f.defjvp
作为装饰器,就像这样
@custom_jvp
def f(x):
...
@f.defjvp
def f_jvp(primals, tangents):
...
尽管我们只定义了一个 JVP 规则而没有 VJP 规则,但我们可以在f
上同时使用正向和反向模式的微分。JAX 会自动将切线值上的线性计算从我们的自定义 JVP 规则转置,高效地计算出 VJP,就好像我们手工编写了规则一样。
from jax import grad
print(grad(f)(3.))
print(grad(grad(f))(3.))
-0.9899925
-0.14112
为了使自动转置工作,JVP 规则的输出切线必须是输入切线的线性函数。否则将引发转置错误。
多个参数的工作方式如下:
@custom_jvp
def f(x, y):
return x ** 2 * y
@f.defjvp
def f_jvp(primals, tangents):
x, y = primals
x_dot, y_dot = tangents
primal_out = f(x, y)
tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
return primal_out, tangent_out
print(grad(f)(2., 3.))
12.0
defjvps
便捷包装器允许我们为每个参数单独定义一个 JVP,并分别计算结果后进行求和:
@custom_jvp
def f(x):
return jnp.sin(x)
f.defjvps(lambda t, ans, x: jnp.cos(x) * t)
print(grad(f)(3.))
-0.9899925
下面是一个带有多个参数的defjvps
示例:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
lambda y_dot, primal_out, x, y: x ** 2 * y_dot)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
4.0
简而言之,使用defjvps
,您可以传递None
值来指示特定参数的 JVP 为零:
@custom_jvp
def f(x, y):
return x ** 2 * y
f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
None)
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.)) # same as above
print(grad(f, 1)(2., 3.))
12.0
12.0
0.0
使用关键字参数调用jax.custom_jvp
函数,或者编写具有默认参数的jax.custom_jvp
函数定义,只要能够根据通过标准库inspect.signature
机制检索到的函数签名映射到位置参数即可。
当您不执行微分时,函数f
的调用方式与未被jax.custom_jvp
修饰时完全一样:
@custom_jvp
def f(x):
print('called f!') # a harmless side-effect
return jnp.sin(x)
@f.defjvp
def f_jvp(primals, tangents):
print('called f_jvp!') # a harmless side-effect
x, = primals
t, = tangents
return f(x), jnp.cos(x) * t
from jax import vmap, jit
print(f(3.))
called f!
0.14112
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.))
called f!
[0\. 0.84147096 0.9092974 ]
called f!
0.14112
自定义的 JVP 规则在微分过程中被调用,无论是正向还是反向:
y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot)
called f_jvp!
called f!
-0.9899925
print(grad(f)(3.))
called f_jvp!
called f!
-0.9899925
注意,f_jvp
调用f
来计算原始输出。在高阶微分的上下文中,每个微分变换的应用将只在规则调用原始f
来计算原始输出时使用自定义的 JVP 规则。(这代表一种基本的权衡,我们不能同时利用f
的评估中间值来制定规则并且使规则在所有高阶微分顺序中应用。)
grad(grad(f))(3.)
called f_jvp!
called f_jvp!
called f!
Array(-0.14112, dtype=float32, weak_type=True)
您可以使用 Python 控制流来使用jax.custom_jvp
:
@custom_jvp
def f(x):
if x > 0:
return jnp.sin(x)
else:
return jnp.cos(x)
@f.defjvp
def f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
ans = f(x)
if x > 0:
return ans, 2 * x_dot
else:
return ans, 3 * x_dot
print(grad(f)(1.))
print(grad(f)(-1.))
2.0
3.0
使用jax.custom_vjp
来定义自定义的仅反向模式规则
虽然jax.custom_jvp
足以控制正向和通过 JAX 的自动转置控制反向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上述后两个示例问题中。我们可以通过jax.custom_vjp
来实现这一点。
from jax import custom_vjp
import jax.numpy as jnp
# f :: a -> b
@custom_vjp
def f(x):
return jnp.sin(x)
# f_fwd :: a -> (b, c)
def f_fwd(x):
return f(x), jnp.cos(x)
# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
from jax import grad
print(f(3.))
print(grad(f)(3.))
0.14112
-0.9899925
换句话说,我们再次从接受类型为a
的输入并产生类型为b
的输出的原始函数f
开始。我们将与之关联两个函数f_fwd
和f_bwd
,它们描述了如何执行反向模式自动微分的正向和反向传递。
函数f_fwd
描述了前向传播,不仅包括原始计算,还包括要保存以供后向传播使用的值。其输入签名与原始函数f
完全相同,即它接受类型为a
的原始输入。但作为输出,它产生一对值,其中第一个元素是原始输出b
,第二个元素是类型为c
的任何“残余”数据,用于后向传播时存储。(这第二个输出类似于PyTorch 的 save_for_backward 机制。)
函数f_bwd
描述了反向传播。它接受两个输入,第一个是由f_fwd
生成的类型为c
的残差数据,第二个是对应于原始函数输出的类型为CT b
的输出共切线。它生成一个类型为CT a
的输出,表示原始函数输入对应的共切线。特别地,f_bwd
的输出必须是长度等于原始函数参数个数的序列(例如元组)。
多个参数的工作方式如下:
from jax import custom_vjp
@custom_vjp
def f(x, y):
return jnp.sin(x) * y
def f_fwd(x, y):
return f(x, y), (jnp.cos(x), jnp.sin(x), y)
def f_bwd(res, g):
cos_x, sin_x, y = res
return (cos_x * g * y, sin_x * g)
f.defvjp(f_fwd, f_bwd)
print(grad(f)(2., 3.))
-1.2484405
调用带有关键字参数的jax.custom_vjp
函数,或者编写带有默认参数的jax.custom_vjp
函数定义,只要可以根据标准库inspect.signature
机制清晰地映射到位置参数即可。
与jax.custom_jvp
类似,如果没有应用微分,则不会调用由f_fwd
和f_bwd
组成的自定义 VJP 规则。如果对函数进行评估,或者使用jit
、vmap
或其他非微分变换进行转换,则只调用f
。
@custom_vjp
def f(x):
print("called f!")
return jnp.sin(x)
def f_fwd(x):
print("called f_fwd!")
return f(x), jnp.cos(x)
def f_bwd(cos_x, y_bar):
print("called f_bwd!")
return (cos_x * y_bar,)
f.defvjp(f_fwd, f_bwd)
print(f(3.))
called f!
0.14112
print(grad(f)(3.))
called f_fwd!
called f!
called f_bwd!
-0.9899925
from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
called f_fwd!
called f!
0.14112
print(f_vjp(1.))
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),)
无法在 jax.custom_vjp
函数上使用前向模式自动微分,否则会引发错误:
from jax import jvp
try:
jvp(f, (3.,), (1.,))
except TypeError as e:
print('ERROR! {}'.format(e))
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function.
如果希望同时使用前向和反向模式,请使用jax.custom_jvp
。
我们可以使用jax.custom_vjp
与pdb
一起在反向传播中插入调试器跟踪:
import pdb
@custom_vjp
def debug(x):
return x # acts like identity
def debug_fwd(x):
return x, x
def debug_bwd(x, g):
import pdb; pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)
def foo(x):
y = x ** 2
y = debug(y) # insert pdb in corresponding backward pass step
return jnp.sin(y)
jax.grad(foo)(3.)
> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q
更多特性和细节
使用list
/ tuple
/ dict
容器(和其他 pytree)
你应该期望标准的 Python 容器如列表、元组、命名元组和字典可以正常工作,以及这些容器的嵌套版本。总体而言,任何pytrees都是允许的,只要它们的结构符合类型约束。
这里有一个使用jax.custom_jvp
的构造示例:
from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])
@custom_jvp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
@f.defjvp
def f_jvp(primals, tangents):
pt, = primals
pt_dot, = tangents
ans = f(pt)
ans_dot = {'a': 2 * pt.x * pt_dot.x,
'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
return ans, ans_dot
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True))
还有一个类似的使用jax.custom_vjp
的构造示例:
@custom_vjp
def f(pt):
x, y = pt.x, pt.y
return {'a': x ** 2,
'b': (jnp.sin(x), jnp.cos(y))}
def f_fwd(pt):
return f(pt), pt
def f_bwd(pt, g):
a_bar, (b0_bar, b1_bar) = g['a'], g['b']
x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
y_bar = -jnp.sin(pt.y) * b1_bar
return (Point(x_bar, y_bar),)
f.defvjp(f_fwd, f_bwd)
def fun(pt):
dct = f(pt)
return dct['a'] + dct['b'][0]
pt = Point(1., 2.)
print(f(pt))
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))}
print(grad(fun)(pt))
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True))
处理非可微参数
一些用例,如最后的示例问题,需要将非可微参数(如函数值参数)传递给具有自定义微分规则的函数,并且这些参数也需要传递给规则本身。在fixed_point
的情况下,函数参数f
就是这样一个非可微参数。类似的情况在jax.experimental.odeint
中也会出现。
jax.custom_jvp
与nondiff_argnums
使用可选的 nondiff_argnums
参数来指示类似这些的参数给 jax.custom_jvp
。以下是一个带有 jax.custom_jvp
的例子:
from functools import partial
@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
@app.defjvp
def app_jvp(f, primals, tangents):
x, = primals
x_dot, = tangents
return f(x), 2. * x_dot
print(app(lambda x: x ** 3, 3.))
27.0
print(grad(app, 1)(lambda x: x ** 3, 3.))
2.0
注意这里的陷阱:无论这些参数在参数列表的哪个位置出现,它们都放置在相应 JVP 规则签名的起始位置。这里有另一个例子:
@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
return f(g((x)))
@app2.defjvp
def app2_jvp(f, g, primals, tangents):
x, = primals
x_dot, = tangents
return f(g(x)), 3. * x_dot
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y))
3375.0
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y))
3.0
nondiff_argnums
与 jax.custom_vjp
对于 jax.custom_vjp
也有类似的选项,类似地,非可微参数的约定是它们作为 _bwd
规则的第一个参数传递,无论它们出现在原始函数签名的哪个位置。 _fwd
规则的签名保持不变 - 它与原始函数的签名相同。以下是一个例子:
@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
return f(x)
def app_fwd(f, x):
return f(x), x
def app_bwd(f, x, g):
return (5 * g,)
app.defvjp(app_fwd, app_bwd)
print(app(lambda x: x ** 2, 4.))
16.0
print(grad(app, 1)(lambda x: x ** 2, 4.))
5.0
请参见上面的 fixed_point
以获取另一个用法示例。
对于具有整数 dtype 的数组值参数,不需要使用 nondiff_argnums
**。相反,nondiff_argnums
应仅用于不对应 JAX 类型(实质上不对应数组类型)的参数值,如 Python 可调用对象或字符串。如果 JAX 检测到由 nondiff_argnums
指示的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient
函数是不使用 nondiff_argnums
处理整数 dtype 数组参数的良好示例。