首页 > 其他分享 >JAX-中文文档-八-

JAX-中文文档-八-

时间:2024-06-21 14:24:31浏览次数:10  
标签:中文 jnp return JAX jvp 文档 print grad def

JAX 中文文档(八)

原文:jax.readthedocs.io/en/latest/

自动微分手册

原文:jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

在 Colab 中打开 在 Kaggle 中打开

alexbw@, mattjj@

JAX 拥有非常通用的自动微分系统。在这本手册中,我们将介绍许多巧妙的自动微分思想,您可以根据自己的工作进行选择。

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

key = random.key(0) 

梯度

grad开始

您可以使用grad对函数进行微分:

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0)) 
0.070650816 

grad接受一个函数并返回一个函数。如果您有一个评估数学函数 ( f ) 的 Python 函数 f,那么 grad(f) 是一个评估数学函数 ( \nabla f ) 的 Python 函数。这意味着 grad(f)(x) 表示值 ( \nabla f(x) )。

由于grad操作函数,您可以将其应用于其自身的输出以多次进行微分:

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0)) 
-0.13621868
0.25265405 

让我们看看如何在线性逻辑回归模型中使用grad计算梯度。首先,设置:

def sigmoid(x):
    return 0.5 * (jnp.tanh(x / 2) + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    return sigmoid(jnp.dot(inputs, W) + b)

# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])

# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    preds = predict(W, b, inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

# Initialize random model coefficients
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ()) 

使用argnums参数的grad函数来相对于位置参数微分函数。

# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# But we can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad) 
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245 

grad API 直接对应于 Spivak 经典著作Calculus on Manifolds(1965)中的优秀符号,也用于 Sussman 和 Wisdom 的Structure and Interpretation of Classical Mechanics(2015)及其Functional Differential Geometry(2013)。这两本书都是开放获取的。特别是参见Functional Differential Geometry的“序言”部分,以了解此符号的辩护。

当使用argnums参数时,如果f是一个用于计算数学函数 ( f ) 的 Python 函数,则 Python 表达式grad(f, i)用于评估 ( \partial_i f ) 的 Python 函数。

相对于嵌套列表、元组和字典进行微分

使用标准的 Python 容器进行微分是完全有效的,因此可以随意使用元组、列表和字典(以及任意嵌套)。

def loss2(params_dict):
    preds = predict(params_dict['W'], params_dict['b'], inputs)
    label_probs = preds * targets + (1 - preds) * (1 - targets)
    return -jnp.sum(jnp.log(label_probs))

print(grad(loss2)({'W': W, 'b': b})) 
{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)} 

您可以注册您自己的容器类型以便不仅与grad一起工作,还可以与所有 JAX 转换(jitvmap等)一起工作。

使用value_and_grad评估函数及其梯度

另一个方便的函数是value_and_grad,可以高效地计算函数值及其梯度值:

from jax import value_and_grad
loss_value, Wb_grad = value_and_grad(loss, (0, 1))(W, b)
print('loss value', loss_value)
print('loss value', loss(W, b)) 
loss value 3.0519385
loss value 3.0519385 

与数值差分进行对比

导数的一个很好的特性是它们很容易用有限差分进行检查:

# Set a step size for finite differences calculations
eps = 1e-4

# Check b_grad with scalar finite differences
b_grad_numerical = (loss(W, b + eps / 2.) - loss(W, b - eps / 2.)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check W_grad with finite differences in a random direction
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_grad_numerical = (loss(W + eps / 2. * unitvec, b) - loss(W - eps / 2. * unitvec, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec)) 
b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117 

JAX 提供了一个简单的便利函数,本质上执行相同的操作,但可以检查任何您喜欢的微分顺序:

from jax.test_util import check_grads
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives 

使用 grad-of-grad 进行 Hessian 向量乘积

使用高阶 grad 可以构建一个 Hessian 向量乘积函数。 (稍后我们将编写一个更高效的实现,该实现混合了前向和反向模式,但这个实现将纯粹使用反向模式。)

在最小化平滑凸函数的截断牛顿共轭梯度算法或研究神经网络训练目标的曲率(例如1234)中,Hessian 向量乘积函数非常有用。

对于一个标量值函数 ( f : \mathbb{R}^n \to \mathbb{R} ),具有连续的二阶导数(因此 Hessian 矩阵是对称的),点 ( x \in \mathbb{R}^n ) 处的 Hessian 被写为 (\partial² f(x))。然后,Hessian 向量乘积函数能够评估

(\qquad v \mapsto \partial² f(x) \cdot v)

对于任意 ( v \in \mathbb{R}^n )。

窍门在于不要实例化完整的 Hessian 矩阵:如果 ( n ) 很大,例如在神经网络的背景下可能是百万或十亿级别,那么可能无法存储。

幸运的是,grad 已经为我们提供了一种编写高效的 Hessian 向量乘积函数的方法。我们只需使用下面的身份证

(\qquad \partial² f (x) v = \partial [x \mapsto \partial f(x) \cdot v] = \partial g(x)),

其中 ( g(x) = \partial f(x) \cdot v ) 是一个新的标量值函数,它将 ( f ) 在 ( x ) 处的梯度与向量 ( v ) 点乘。请注意,我们只对向量值参数的标量值函数进行微分,这正是我们知道 grad 高效的地方。

在 JAX 代码中,我们可以直接写成这样:

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) 

这个例子表明,您可以自由使用词汇闭包,而 JAX 绝不会感到不安或困惑。

一旦我们看到如何计算密集的 Hessian 矩阵,我们将在几个单元格下检查此实现。我们还将编写一个更好的版本,该版本同时使用前向模式和反向模式。

使用 jacfwdjacrev 计算 Jacobians 和 Hessians

您可以使用 jacfwdjacrev 函数计算完整的 Jacobian 矩阵:

from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

J = jacfwd(f)(W)
print("jacfwd result, with shape", J.shape)
print(J)

J = jacrev(f)(W)
print("jacrev result, with shape", J.shape)
print(J) 
jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]] 

这两个函数计算相同的值(直到机器数学),但它们在实现上有所不同:jacfwd 使用前向模式自动微分,对于“高”的 Jacobian 矩阵更有效,而 jacrev 使用反向模式,对于“宽”的 Jacobian 矩阵更有效。对于接近正方形的矩阵,jacfwd 可能比 jacrev 有优势。

您还可以在容器类型中使用 jacfwdjacrev

def predict_dict(params, inputs):
    return predict(params['W'], params['b'], inputs)

J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
for k, v in J_dict.items():
    print("Jacobian from {} to logits is".format(k))
    print(v) 
Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771] 

关于前向模式和反向模式的更多细节,以及如何尽可能高效地实现 jacfwdjacrev,请继续阅读!

使用两个这些函数的复合给我们一种计算密集的 Hessian 矩阵的方法:

def hessian(f):
    return jacfwd(jacrev(f))

H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H) 
hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]] 

这种形状是合理的:如果我们从一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m) 开始,那么在点 (x \in \mathbb{R}^n) 我们期望得到以下形状

  • (f(x) \in \mathbb{R}^m),在 (x) 处的 (f) 的值,

  • (\partial f(x) \in \mathbb{R}^{m \times n}),在 (x) 处的雅可比矩阵,

  • (\partial² f(x) \in \mathbb{R}^{m \times n \times n}),在 (x) 处的 Hessian 矩阵,

以及其他一些内容。

要实现 hessian,我们可以使用 jacfwd(jacrev(f))jacrev(jacfwd(f)) 或这两者的任何组合。但是前向超过反向通常是最有效的。这是因为在内部雅可比计算中,我们通常是在不同 iating 一个函数宽雅可比(也许像损失函数 (f : \mathbb{R}^n \to \mathbb{R})),而在外部雅可比计算中,我们是在不同 iating 具有方雅可比的函数(因为 (\nabla f : \mathbb{R}^n \to \mathbb{R}^n)),这就是前向模式胜出的地方。

制造过程:两个基础的自动微分函数

雅可比-向量积(JVPs,也称为前向模式自动微分)

JAX 包括前向模式和反向模式自动微分的高效和通用实现。熟悉的 grad 函数建立在反向模式之上,但要解释两种模式的区别,以及每种模式何时有用,我们需要一些数学背景。

数学中的雅可比向量积

在数学上,给定一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m),在输入点 (x \in \mathbb{R}^n) 处评估的雅可比矩阵 (\partial f(x)),通常被视为一个 (\mathbb{R}^m \times \mathbb{R}^n) 中的矩阵:

(\qquad \partial f(x) \in \mathbb{R}^{m \times n}).

但我们也可以将 (\partial f(x)) 看作是一个线性映射,它将 (f) 的定义域在点 (x) 的切空间(即另一个 (\mathbb{R}^n) 的副本)映射到 (f) 的值域在点 (f(x)) 的切空间(一个 (\mathbb{R}^m) 的副本):

(\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m).

此映射称为 (f) 在 (x) 处的推前映射。雅可比矩阵只是标准基中这个线性映射的矩阵。

如果我们不确定一个特定的输入点 (x),那么我们可以将函数 (\partial f) 视为首先接受一个输入点并返回该输入点处的雅可比线性映射:

(\qquad \partial f : \mathbb{R}^n \to \mathbb{R}^n \to \mathbb{R}^m)。

特别是,我们可以解开事物,这样给定输入点 (x \in \mathbb{R}^n) 和切向量 (v \in \mathbb{R}^n),我们得到一个输出切向量在 (\mathbb{R}^m) 中。我们称这种映射,从 ((x, v)) 对到输出切向量,为雅可比向量积,并将其写为

(\qquad (x, v) \mapsto \partial f(x) v)

JAX 代码中的雅可比向量积

回到 Python 代码中,JAX 的 jvp 函数模拟了这种转换。给定一个评估 (f) 的 Python 函数,JAX 的 jvp 是获取评估 ((x, v) \mapsto (f(x), \partial f(x) v)) 的 Python 函数的一种方法。

from jax import jvp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

key, subkey = random.split(key)
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`
y, u = jvp(f, (W,), (v,)) 

用类似 Haskell 的类型签名来说,我们可以写成

jvp  ::  (a  ->  b)  ->  a  ->  T  a  ->  (b,  T  b) 

在这里,我们使用 T a 来表示 a 的切线空间的类型。简言之,jvp 接受一个类型为 a -> b 的函数作为参数,一个类型为 a 的值,以及一个类型为 T a 的切线向量值。它返回一个由类型为 b 的值和类型为 T b 的输出切线向量组成的对。

jvp 转换后的函数的评估方式与原函数类似,但与每个类型为 a 的原始值配对时,它会沿着类型为 T a 的切线值进行推进。对于原始函数将应用的每个原始数值操作,jvp 转换后的函数会执行一个“JVP 规则”,该规则同时在这些原始值上评估原始数值,并应用其 JVP。

该评估策略对计算复杂度有一些直接影响:由于我们在进行评估时同时评估 JVP,因此我们不需要为以后存储任何内容,因此内存成本与计算深度无关。此外,jvp 转换后的函数的 FLOP 成本约为评估函数的成本的 3 倍(例如对于评估原始函数的一个单位工作,如 sin(x);一个单位用于线性化,如 cos(x);和一个单位用于将线性化函数应用于向量,如 cos_x * v)。换句话说,对于固定的原始点 (x),我们可以以大致相同的边际成本评估 (v \mapsto \partial f(x) \cdot v),如同评估 (f) 一样。

那么内存复杂度听起来非常有说服力!那为什么我们在机器学习中很少见到正向模式呢?

要回答这个问题,首先考虑如何使用 JVP 构建完整的 Jacobian 矩阵。如果我们将 JVP 应用于一个单位切线向量,它会显示出我们输入的非零条目对应的 Jacobian 矩阵的一列。因此,我们可以逐列地构建完整的 Jacobian 矩阵,获取每列的成本大约与一个函数评估相同。对于具有“高”Jacobian 的函数来说,这将是高效的,但对于“宽”Jacobian 来说则效率低下。

如果你在机器学习中进行基于梯度的优化,你可能想要最小化一个从 (\mathbb{R}^n) 中的参数到 (\mathbb{R}) 中标量损失值的损失函数。这意味着这个函数的雅可比矩阵是一个非常宽的矩阵:(\partial f(x) \in \mathbb{R}^{1 \times n}),我们通常将其视为梯度向量 (\nabla f(x) \in \mathbb{R}^n)。逐列构建这个矩阵,每次调用需要类似数量的浮点运算来评估原始函数,看起来确实效率低下!特别是对于训练神经网络,其中 (f) 是一个训练损失函数,而 (n) 可以是百万或十亿级别,这种方法根本不可扩展。

为了更好地处理这类函数,我们只需要使用反向模式。### 向量-雅可比积(VJPs,又称反向自动微分)

在前向模式中,我们得到了一个用于评估雅可比向量积的函数,然后我们可以使用它逐列构建雅可比矩阵;而反向模式则是一种获取用于评估向量-雅可比积(或等效地雅可比-转置向量积)的函数的方式,我们可以用它逐行构建雅可比矩阵。

数学中的 VJPs

再次考虑一个函数 (f : \mathbb{R}^n \to \mathbb{R}^m)。从我们对 JVP 的表示开始,对于 VJP 的表示非常简单:

(\qquad (x, v) \mapsto v \partial f(x)),

其中 (v) 是在 (x) 处 (f) 的余切空间的元素(同构于另一个 (\mathbb{R}^m) 的副本)。在严格时,我们应该将 (v) 视为一个线性映射 (v : \mathbb{R}^m \to \mathbb{R}),当我们写 (v \partial f(x)) 时,我们意味着函数复合 (v \circ \partial f(x)),其中类型之间的对应关系是因为 (\partial f(x) : \mathbb{R}^n \to \mathbb{R}^m)。但在通常情况下,我们可以将 (v) 与 (\mathbb{R}^m) 中的一个向量等同看待,并几乎可以互换使用,就像有时我们可以在“列向量”和“行向量”之间轻松切换而不加过多评论一样。

有了这个认识,我们可以将 VJP 的线性部分看作是 JVP 线性部分的转置(或共轭伴随):

(\qquad (x, v) \mapsto \partial f(x)^\mathsf{T} v).

对于给定点 (x),我们可以将签名写为

(\qquad \partial f(x)^\mathsf{T} : \mathbb{R}^m \to \mathbb{R}^n).

对应的余切空间映射通常称为[拉回](https://en.wikipedia.org/wiki/Pullback_(differential_geometry)) (f) 在 (x) 处的。对我们而言,关键在于它从类似 (f) 输出的东西到类似 (f) 输入的东西,就像我们从一个转置线性函数所期望的那样。

JAX 代码中的 VJPs

从数学切换回 Python,JAX 函数 vjp 可以接受一个用于评估 (f) 的 Python 函数,并给我们返回一个用于评估 VJP ((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))) 的 Python 函数。

from jax import vjp

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`
v = vjp_fun(u) 

类似 Haskell 类型签名的形式来说,我们可以写成

vjp  ::  (a  ->  b)  ->  a  ->  (b,  CT  b  ->  CT  a) 

在这里,我们使用 CT a 表示 a 的余切空间的类型。换句话说,vjp 接受类型为 a -> b 的函数和类型为 a 的点作为参数,并返回一个由类型为 b 的值和类型为 CT b -> CT a 的线性映射组成的对。

这很棒,因为它让我们一次一行地构建雅可比矩阵,并且评估 ((x, v) \mapsto (f(x), v^\mathsf{T} \partial f(x))) 的 FLOP 成本仅约为评估 (f) 的三倍。特别是,如果我们想要函数 (f : \mathbb{R}^n \to \mathbb{R}) 的梯度,我们可以一次性完成。这就是 grad 对基于梯度的优化非常高效的原因,即使是对于数百万或数十亿个参数的神经网络训练损失函数这样的目标。

这里有一个成本:虽然 FLOP 友好,但内存随计算深度而增加。而且,该实现在传统上比前向模式更为复杂,但 JAX 对此有一些窍门(这是未来笔记本的故事!)。

关于反向模式的工作原理,可以查看2017 年深度学习暑期学校的教程视频

使用 VJPs 的矢量值梯度

如果你对使用矢量值梯度(如 tf.gradients)感兴趣:

from jax import vjp

def vgrad(f, x):
  y, vjp_fn = vjp(f, x)
  return vjp_fn(jnp.ones(y.shape))[0]

print(vgrad(lambda x: 3*x**2, jnp.ones((2, 2)))) 
[[6\. 6.]
 [6\. 6.]] 

使用前向和反向模式的黑塞矢量积

在前面的部分中,我们仅使用反向模式实现了一个黑塞-矢量积函数(假设具有连续二阶导数):

def hvp(f, x, v):
    return grad(lambda x: jnp.vdot(grad(f)(x), v))(x) 

这是高效的,但我们甚至可以更好地节省一些内存,通过使用前向模式和反向模式。

从数学上讲,给定一个要区分的函数 (f : \mathbb{R}^n \to \mathbb{R}),要线性化函数的一个点 (x \in \mathbb{R}^n),以及一个向量 (v \in \mathbb{R}^n),我们想要的黑塞-矢量积函数是

((x, v) \mapsto \partial² f(x) v)

考虑助手函数 (g : \mathbb{R}^n \to \mathbb{R}^n) 定义为 (f) 的导数(或梯度),即 (g(x) = \partial f(x))。我们所需的只是它的 JVP,因为这将给我们

((x, v) \mapsto \partial g(x) v = \partial² f(x) v).

我们几乎可以直接将其转换为代码:

from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1] 

更好的是,由于我们不需要直接调用 jnp.dot,这个 hvp 函数可以处理任何形状的数组以及任意的容器类型(如嵌套列表/字典/元组中存储的向量),甚至与jax.numpy 没有任何依赖。

这是如何使用它的示例:

def f(X):
  return jnp.sum(jnp.tanh(X)**2)

key, subkey1, subkey2 = random.split(key, 3)
X = random.normal(subkey1, (30, 40))
V = random.normal(subkey2, (30, 40))

ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, 2)

print(jnp.allclose(ans1, ans2, 1e-4, 1e-4)) 
True 

另一种你可能考虑写这个的方法是使用反向-前向模式:

# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
  g = lambda primals: jvp(f, primals, tangents)[1]
  return grad(g)(primals) 

不过,这不是很好,因为前向模式的开销比反向模式小,由于外部区分算子要区分比内部更大的计算,将前向模式保持在外部是最好的:

# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
  x, = primals
  v, = tangents
  return grad(lambda x: jnp.vdot(grad(f)(x), v))(x)

print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2) 
Forward over reverse
4.74 ms ± 157 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
9.46 ms ± 5.05 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
14.3 ms ± 7.71 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
57.7 ms ± 1.32 ms per loop (mean ± std. dev. of 3 runs, 10 loops each) 

组成 VJP、JVP 和 vmap

雅可比-矩阵和矩阵-雅可比乘积

现在我们有jvpvjp变换,它们为我们提供了推送或拉回单个向量的函数,我们可以使用 JAX 的vmap 变换一次推送和拉回整个基。特别是,我们可以用它来快速编写矩阵-雅可比和雅可比-矩阵乘积。

# Isolate the function from the weight matrix to the predictions
f = lambda W: predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    return jnp.vstack([vjp_fun(mi) for mi in M])

# Now, use vmap to build a computation that does a single fast matrix-matrix
# multiply, rather than an outer loop over vector-matrix multiplies.
def vmap_mjp(f, x, M):
    y, vjp_fun = vjp(f, x)
    outs, = vmap(vjp_fun)(M)
    return outs

key = random.key(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical' 
Non-vmapped Matrix-Jacobian product
168 ms ± 260 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
6.39 ms ± 49.3 μs per loop (mean ± std. dev. of 3 runs, 10 loops each) 
/tmp/ipykernel_1379/3769736790.py:8: DeprecationWarning: vstack requires ndarray or scalar arguments, got <class 'tuple'> at position 0\. In a future JAX release this will be an error.
  return jnp.vstack([vjp_fun(mi) for mi in M]) 
def loop_jmp(f, W, M):
    # jvp immediately returns the primal and tangent values as a tuple,
    # so we'll compute and select the tangents in a list comprehension
    return jnp.vstack([jvp(f, (W,), (mi,))[1] for mi in M])

def vmap_jmp(f, W, M):
    _jvp = lambda s: jvp(f, (W,), (s,))[1]
    return vmap(_jvp)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical' 
Non-vmapped Jacobian-Matrix product
290 ms ± 437 μs per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
3.29 ms ± 22.5 μs per loop (mean ± std. dev. of 3 runs, 10 loops each) 

jacfwdjacrev的实现

现在我们已经看到了快速的雅可比-矩阵和矩阵-雅可比乘积,写出jacfwdjacrev并不难。我们只需使用相同的技术一次推送或拉回整个标准基(等同于单位矩阵)。

from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # Use vmap to do a matrix-Jacobian product.
        # Here, the matrix is the Euclidean basis, so we get all
        # entries in the Jacobian at once. 
        J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
        return J
    return jacfun

assert jnp.allclose(builtin_jacrev(f)(W), our_jacrev(f)(W)), 'Incorrect reverse-mode Jacobian results!' 
from jax import jacfwd as builtin_jacfwd

def our_jacfwd(f):
    def jacfun(x):
        _jvp = lambda s: jvp(f, (x,), (s,))[1]
        Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
        return jnp.transpose(Jt)
    return jacfun

assert jnp.allclose(builtin_jacfwd(f)(W), our_jacfwd(f)(W)), 'Incorrect forward-mode Jacobian results!' 

有趣的是,Autograd做不到这一点。我们在 Autograd 中反向模式jacobian实现必须逐个向量地拉回,使用外层循环map。逐个向量地通过计算远不及使用vmap一次将所有内容批处理高效。

另一件 Autograd 做不到的事情是jit。有趣的是,无论您在要进行微分的函数中使用多少 Python 动态性,我们总是可以在计算的线性部分上使用jit。例如:

def f(x):
    try:
        if x < 3:
            return 2 * x ** 3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x

y, f_vjp = vjp(f, 4.)
print(jit(f_vjp)(1.)) 
(Array(3.1415927, dtype=float32, weak_type=True),) 

复数和微分

JAX 在复数和微分方面表现出色。为了支持全纯和非全纯微分,理解 JVP 和 VJP 很有帮助。

考虑一个复到复的函数 (f: \mathbb{C} \to \mathbb{C}) 并将其与相应的函数 (g: \mathbb{R}² \to \mathbb{R}²) 对应起来,

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return u(x, y) + v(x, y) * 1j

def g(x, y):
  return (u(x, y), v(x, y)) 

也就是说,我们分解了 (f(z) = u(x, y) + v(x, y) i) 其中 (z = x + y i),并将 (\mathbb{C}) 与 (\mathbb{R}²) 对应起来得到了 (g)。

由于 (g) 只涉及实数输入和输出,我们已经知道如何为它编写雅可比-向量积,例如给定切向量 ((c, d) \in \mathbb{R}²),

(\begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \ d \end{bmatrix}).

要获得应用于切向量 (c + di \in \mathbb{C}) 的原始函数 (f) 的 JVP,我们只需使用相同的定义,并将结果标识为另一个复数,

(\partial f(x + y i)(c + d i) = \begin{matrix} \begin{bmatrix} 1 & i \end{bmatrix} \ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} c \ d \end{bmatrix}).

这就是我们对复到复函数 (f) 的 JVP 的定义!注意,无论 (f) 是否全纯,JVP 都是明确的。

这里是一个检查:

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # tangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_dot = c + d * 1j

  # check jvp
  _, ans = jvp(fun, (z,), (z_dot,))
  expected = (grad(u, 0)(x, y) * c +
              grad(u, 1)(x, y) * d +
              grad(v, 0)(x, y) * c * 1j+
              grad(v, 1)(x, y) * d * 1j)
  print(jnp.allclose(ans, expected)) 
check(0)
check(1)
check(2) 
True
True
True 

那么 VJP 呢?我们做了类似的事情:对于余切向量 (c + di \in \mathbb{C}),我们将 (f) 的 VJP 定义为

((c + di)^* ; \partial f(x + y i) = \begin{matrix} \begin{bmatrix} c & -d \end{bmatrix} \ ~ \end{matrix} \begin{bmatrix} \partial_0 u(x, y) & \partial_1 u(x, y) \ \partial_0 v(x, y) & \partial_1 v(x, y) \end{bmatrix} \begin{bmatrix} 1 \ -i \end{bmatrix}).

为什么要有负号?这些只是为了处理复共轭,以及我们正在处理余切向量的事实。

这里是 VJP 规则的检查:

def check(seed):
  key = random.key(seed)

  # random coeffs for u and v
  key, subkey = random.split(key)
  a, b, c, d = random.uniform(subkey, (4,))

  def fun(z):
    x, y = jnp.real(z), jnp.imag(z)
    return u(x, y) + v(x, y) * 1j

  def u(x, y):
    return a * x + b * y

  def v(x, y):
    return c * x + d * y

  # primal point
  key, subkey = random.split(key)
  x, y = random.uniform(subkey, (2,))
  z = x + y * 1j

  # cotangent vector
  key, subkey = random.split(key)
  c, d = random.uniform(subkey, (2,))
  z_bar = jnp.array(c + d * 1j)  # for dtype control

  # check vjp
  _, fun_vjp = vjp(fun, z)
  ans, = fun_vjp(z_bar)
  expected = (grad(u, 0)(x, y) * c +
              grad(v, 0)(x, y) * (-d) +
              grad(u, 1)(x, y) * c * (-1j) +
              grad(v, 1)(x, y) * (-d) * (-1j))
  assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5) 
check(0)
check(1)
check(2) 

方便的包装器如gradjacfwdjacrev有什么作用?

对于(\mathbb{R} \to \mathbb{R})函数,回想我们定义grad(f)(x)vjp(f, x)1,这是因为将 VJP 应用于1.0值会显示梯度(即雅可比矩阵或导数)。对于(\mathbb{C} \to \mathbb{R})函数,我们可以做同样的事情:我们仍然可以使用1.0作为余切向量,而我们得到的只是总结完整雅可比矩阵的一个复数结果:

def f(z):
  x, y = jnp.real(z), jnp.imag(z)
  return x**2 + y**2

z = 3. + 4j
grad(f)(z) 
Array(6.-8.j, dtype=complex64) 

对于一般的(\mathbb{C} \to \mathbb{C})函数,雅可比矩阵有 4 个实值自由度(如上面的 2x2 雅可比矩阵),因此我们不能希望在一个复数中表示所有这些自由度。但对于全纯函数,我们可以!全纯函数恰好是一个(\mathbb{C} \to \mathbb{C})函数,其导数可以表示为一个单一的复数。(柯西-黎曼方程确保上述 2x2 雅可比矩阵在复平面内的作用具有复数乘法下的比例和旋转矩阵的特殊形式。)我们可以使用一个vjp调用并带有1.0的余切向量来揭示那一个复数。

因为这仅适用于全纯函数,为了使用这个技巧,我们需要向 JAX 保证我们的函数是全纯的;否则,在复数输出函数上使用grad时,JAX 会引发错误:

def f(z):
  return jnp.sin(z)

z = 3. + 4j
grad(f, holomorphic=True)(z) 
Array(-27.034945-3.8511531j, dtype=complex64, weak_type=True) 

holomorphic=True的承诺仅仅是在输出是复数值时禁用错误。当函数不是全纯时,我们仍然可以写holomorphic=True,但得到的答案将不表示完整的雅可比矩阵。相反,它将是在我们只丢弃输出的虚部的函数的雅可比矩阵。

def f(z):
  return jnp.conjugate(z)

z = 3. + 4j
grad(f, holomorphic=True)(z)  # f is not actually holomorphic! 
Array(1.-0.j, dtype=complex64, weak_type=True) 

在这里grad的工作有一些有用的结论:

  1. 我们可以在全纯的(\mathbb{C} \to \mathbb{C})函数上使用grad

  2. 我们可以使用grad来优化(\mathbb{C} \to \mathbb{R})函数,例如复参数x的实值损失函数,通过朝着grad(f)(x)的共轭方向迈出步伐。

  3. 如果我们有一个(\mathbb{R} \to \mathbb{R})的函数,它恰好在内部使用一些复数运算(其中一些必须是非全纯的,例如在卷积中使用的 FFT),那么grad仍然有效,并且我们得到与仅使用实数值的实现相同的结果。

在任何情况下,JVPs 和 VJPs 都是明确的。如果我们想计算非全纯函数(\mathbb{C} \to \mathbb{C})的完整 Jacobian 矩阵,我们可以用 JVPs 或 VJPs 来做到!

你应该期望复数在 JAX 中的任何地方都能正常工作。这里是通过复杂矩阵的 Cholesky 分解进行微分:

A = jnp.array([[5.,    2.+3j,    5j],
              [2.-3j,   7.,  1.+7j],
              [-5j,  1.-7j,    12.]])

def f(X):
    L = jnp.linalg.cholesky(X)
    return jnp.sum((L - jnp.sin(L))**2)

grad(f, holomorphic=True)(A) 
Array([[-0.7534186  +0.j       , -3.0509028 -10.940544j ,
         5.9896846  +3.5423026j],
       [-3.0509028 +10.940544j , -8.904491   +0.j       ,
        -5.1351523  -6.559373j ],
       [ 5.9896846  -3.5423026j, -5.1351523  +6.559373j ,
         0.01320427 +0.j       ]], dtype=complex64) 

更高级的自动微分

在这本笔记本中,我们通过一些简单的,然后逐渐复杂的应用中,使用 JAX 中的自动微分。我们希望现在您感觉在 JAX 中进行导数运算既简单又强大。

还有很多其他自动微分的技巧和功能。我们没有涵盖的主题,但希望在“高级自动微分手册”中进行涵盖:

  • 高斯-牛顿向量乘积,一次线性化

  • 自定义的 VJPs 和 JVPs

  • 在固定点处高效地求导

  • 使用随机的 Hessian-vector products 来估计 Hessian 的迹。

  • 仅使用反向模式自动微分的前向模式自动微分。

  • 对自定义数据类型进行导数计算。

  • 检查点(二项式检查点用于高效的反向模式,而不是模型快照)。

  • 优化 VJPs 通过 Jacobian 预积累。

JAX 可转换的 Python 函数的自定义导数规则

原文:jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

在 Colab 中打开 在 Kaggle 中打开

mattjj@ Mar 19 2020, last updated Oct 14 2020

JAX 中定义微分规则的两种方式:

  1. 使用 jax.custom_jvpjax.custom_vjp 来为已经可转换为 JAX 的 Python 函数定义自定义微分规则;以及

  2. 定义新的 core.Primitive 实例及其所有转换规则,例如调用来自其他系统(如求解器、模拟器或一般数值计算系统)的函数。

本笔记本讨论的是 #1. 要了解关于 #2 的信息,请参阅关于添加原语的笔记本

关于 JAX 自动微分 API 的介绍,请参阅自动微分手册。本笔记本假定读者已对jax.jvpjax.grad,以及 JVPs 和 VJPs 的数学含义有一定了解。

TL;DR

使用 jax.custom_jvp 进行自定义 JVPs

import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot
  return primal_out, tangent_out 
from jax import jvp, grad

print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 
# Equivalent alternative using the defjvps convenience wrapper

@custom_jvp
def f(x, y):
  return jnp.sin(x) * y

f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y,
          lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) 
print(f(2., 3.))
y, y_dot = jvp(f, (2., 3.), (1., 0.))
print(y)
print(y_dot)
print(grad(f)(2., 3.)) 
2.7278922
2.7278922
-1.2484405
-1.2484405 

使用 jax.custom_vjp 进行自定义 VJPs

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
# Returns primal output and residuals to be used in backward pass by f_bwd.
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res # Gets residuals computed in f_fwd
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd) 
print(grad(f)(2., 3.)) 
-1.2484405 

示例问题

要了解 jax.custom_jvpjax.custom_vjp 所解决的问题,我们可以看几个例子。有关 jax.custom_jvpjax.custom_vjp API 的更详细介绍在下一节中。

数值稳定性

jax.custom_jvp 的一个应用是提高微分的数值稳定性。

假设我们想编写一个名为 log1pexp 的函数,用于计算 (x \mapsto \log ( 1 + e^x ))。我们可以使用 jax.numpy 来写:

import jax.numpy as jnp

def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp(3.) 
Array(3.0485873, dtype=float32, weak_type=True) 

因为它是用 jax.numpy 编写的,所以它是 JAX 可转换的:

from jax import jit, grad, vmap

print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

但这里存在一个数值稳定性问题:

print(grad(log1pexp)(100.)) 
nan 

那似乎不对!毕竟,(x \mapsto \log (1 + e^x)) 的导数是 (x \mapsto \frac{e^x}{1 + e^x}),因此对于大的 (x) 值,我们期望值约为 1。

通过查看梯度计算的 jaxpr,我们可以更深入地了解发生了什么:

from jax import make_jaxpr

make_jaxpr(grad(log1pexp))(100.) 
{ lambda ; a:f32[]. let
    b:f32[] = exp a
    c:f32[] = add 1.0 b
    _:f32[] = log c
    d:f32[] = div 1.0 c
    e:f32[] = mul d b
  in (e,) } 

通过分析 jaxpr 如何评估,我们可以看到最后一行涉及的值相乘会导致浮点数计算四舍五入为 0 和 (\infty),这从未是一个好主意。也就是说,我们实际上在评估大数值的情况下,计算的是 lambda x: (1 / (1 + jnp.exp(x))) * jnp.exp(x),这实际上会变成 0. * jnp.inf

而不是生成这样大和小的值,希望浮点数能够提供的取消,我们宁愿将导数函数表达为一个更稳定的数值程序。特别地,我们可以编写一个程序,更接近地评估相等的数学表达式 (1 - \frac{1}{1 + e^x}),看不到取消。

这个问题很有趣,因为即使我们的log1pexp的定义已经可以进行 JAX 微分(并且可以使用jitvmap等转换),我们对应用标准自动微分规则到组成log1pexp并组合结果的结果并不满意。相反,我们想要指定整个函数log1pexp如何作为一个单位进行微分,从而更好地安排这些指数。

这是关于 Python 函数的自定义导数规则的一个应用,这些函数已经可以使用 JAX 进行转换:指定如何对复合函数进行微分,同时仍然使用其原始的 Python 定义进行其他转换(如jitvmap等)。

这里是使用jax.custom_jvp的解决方案:

from jax import custom_jvp

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = log1pexp(x)
  ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
  return ans, ans_dot 
print(grad(log1pexp)(100.)) 
1.0 
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

这里是一个defjvps方便包装,来表达同样的事情:

@custom_jvp
def log1pexp(x):
  return jnp.log(1. + jnp.exp(x))

log1pexp.defjvps(lambda t, ans, x: (1 - 1/(1 + jnp.exp(x))) * t) 
print(grad(log1pexp)(100.))
print(jit(log1pexp)(3.))
print(jit(grad(log1pexp))(3.))
print(vmap(jit(grad(log1pexp)))(jnp.arange(3.))) 
1.0
3.0485873
0.95257413
[0.5       0.7310586 0.8807971] 

强制执行微分约定

一个相关的应用是强制执行微分约定,也许在边界处。

考虑函数 (f : \mathbb{R}+ \to \mathbb{R}+),其中 (f(x) = \frac{x}{1 + \sqrt{x}}),其中我们取 (\mathbb{R}_+ = [0, \infty))。我们可以像这样实现 (f) 的程序:

def f(x):
  return x / (1 + jnp.sqrt(x)) 

作为在(\mathbb{R})上的数学函数(完整的实数线),(f) 在零点是不可微的(因为从左侧定义导数的极限不存在)。相应地,自动微分产生一个nan值:

print(grad(f)(0.)) 
nan 

但是数学上,如果我们将 (f) 视为 (\mathbb{R}_+) 上的函数,则它在 0 处是可微的 [Rudin 的《数学分析原理》定义 5.1,或 Tao 的《分析 I》第 3 版定义 10.1.1 和例子 10.1.6]。或者,我们可能会说,作为一个惯例,我们希望考虑从右边的方向导数。因此,对于 Python 函数grad(f)0.0处返回 1.0 是有意义的值。默认情况下,JAX 对微分的机制假设所有函数在(\mathbb{R})上定义,因此这里并不会产生1.0

我们可以使用自定义的 JVP 规则!特别地,我们可以定义 JVP 规则,关于导数函数 (x \mapsto \frac{\sqrt{x} + 2}{2(\sqrt{x} + 1)²}) 在 (\mathbb{R}_+) 上,

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  ans_dot = ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * x_dot
  return ans, ans_dot 
print(grad(f)(0.)) 
1.0 

这里是方便包装版本:

@custom_jvp
def f(x):
  return x / (1 + jnp.sqrt(x))

f.defjvps(lambda t, ans, x: ((jnp.sqrt(x) + 2) / (2 * (jnp.sqrt(x) + 1)**2)) * t) 
print(grad(f)(0.)) 
1.0 

梯度剪裁

虽然在某些情况下,我们想要表达一个数学微分计算,在其他情况下,我们甚至可能想要远离数学,来调整自动微分的计算。一个典型的例子是反向模式梯度剪裁。

对于梯度剪裁,我们可以使用jnp.clip和一个jax.custom_vjp仅逆模式规则:

from functools import partial
from jax import custom_vjp

@custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity function

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save bounds as residuals

def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # use None to indicate zero cotangents for lo and hi

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd) 
import matplotlib.pyplot as plt
from jax import vmap

t = jnp.linspace(0, 10, 1000)

plt.plot(jnp.sin(t))
plt.plot(vmap(grad(jnp.sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43dfc210f0>] 

../_images/deaae0f99458d9656c1888a740e8fddef86e7a2a68deda903918a80e0b7597be.png

def clip_sin(x):
  x = clip_gradient(-0.75, 0.75, x)
  return jnp.sin(x)

plt.plot(clip_sin(t))
plt.plot(vmap(grad(clip_sin))(t)) 
[<matplotlib.lines.Line2D at 0x7f43ddb15fc0>] 

../_images/8b564451d0b054ab4486979b76183ae8af108a4d106652226651c16843285dde.png

Python 调试

另一个应用,是受开发工作流程而非数值驱动的动机,是在反向模式自动微分的后向传递中设置pdb调试器跟踪。

在尝试追踪nan运行时错误的来源,或者仅仔细检查传播的余切(梯度)值时,可以在反向传递中的特定点插入调试器非常有用。您可以使用jax.custom_vjp来实现这一点。

我们将在下一节中推迟一个示例。

迭代实现的隐式函数微分

这个例子涉及到了数学中的深层问题!

另一个应用jax.custom_vjp是对可通过jitvmap等转换为 JAX 但由于某些原因不易 JAX 可区分的函数进行反向模式微分,也许是因为涉及lax.while_loop。(无法生成 XLA HLO 程序有效计算 XLA HLO While 循环的反向模式导数,因为这将需要具有无界内存使用的程序,这在 XLA HLO 中是不可能表达的,至少不是通过通过 infeed/outfeed 的副作用交互。)

例如,考虑这个fixed_point例程,它通过在while_loop中迭代应用函数来计算一个不动点:

from jax.lax import while_loop

def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star 

这是一种通过迭代应用函数(x_{t+1} = f(a, x_t))来数值解方程(x = f(a, x))以计算(x)的迭代过程,直到(x_{t+1})足够接近(x_t)。结果(x^)取决于参数(a),因此我们可以认为存在一个由方程(x = f(a, x))隐式定义的函数(a \mapsto x^(a))。

我们可以使用fixed_point运行迭代过程以收敛,例如运行牛顿法来计算平方根,只执行加法、乘法和除法:

def newton_sqrt(a):
  update = lambda a, x: 0.5 * (x + a / x)
  return fixed_point(update, a, a) 
print(newton_sqrt(2.)) 
1.4142135 

我们也可以对函数进行vmapjit处理:

print(jit(vmap(newton_sqrt))(jnp.array([1., 2., 3., 4.]))) 
[1\.        1.4142135 1.7320509 2\.       ] 

由于while_loop,我们无法应用反向模式自动微分,但事实证明我们也不想这样做:我们可以利用数学结构做一些更节省内存(在这种情况下也更节省 FLOP)的事情!我们可以使用隐函数定理[Bertsekas 的《非线性规划,第二版》附录 A.25],它保证(在某些条件下)我们即将使用的数学对象的存在。本质上,我们在线性化解决方案处进行线性化,并迭代解这些线性方程以计算我们想要的导数。

再次考虑方程(x = f(a, x))和函数(x^)。我们想要评估向量-Jacobian 乘积,如(v^\mathsf{T} \mapsto v^\mathsf{T} \partial x^(a_0))。

至少在我们想要求微分的点(a_0)周围的开放邻域内,让我们假设方程(x^(a) = f(a, x^(a)))对所有(a)都成立。由于两边作为(a)的函数相等,它们的导数也必须相等,所以让我们分别对两边进行微分:

(\qquad \partial x^(a) = \partial_0 f(a, x^(a)) + \partial_1 f(a, x^(a)) \partial x^(a))。

设置(A = \partial_1 f(a_0, x^(a_0)))和(B = \partial_0 f(a_0, x^(a_0))),我们可以更简单地写出我们想要的数量为

(\qquad \partial x^(a_0) = B + A \partial x^(a_0)),

或者,通过重新排列,

(\qquad \partial x^*(a_0) = (I - A)^{-1} B)。

这意味着我们可以评估向量-Jacobian 乘积,如

(\qquad v^\mathsf{T} \partial x^*(a_0) = v^\mathsf{T} (I - A)^{-1} B = w^\mathsf{T} B),

其中(w^\mathsf{T} = v^\mathsf{T} (I - A){-1}),或者等效地(w\mathsf{T} = v^\mathsf{T} + w^\mathsf{T} A),或者等效地(w\mathsf{T})是映射(u\mathsf{T} \mapsto v^\mathsf{T} + u^\mathsf{T} A)的不动点。最后一个描述使我们可以根据对fixed_point的调用来编写fixed_point的 VJP!此外,在展开(A)和(B)之后,我们可以看到我们只需要在((a_0, x^*(a_0)))处评估(f)的 VJP。

这里是要点:

from jax import vjp

@partial(custom_vjp, nondiff_argnums=(0,))
def fixed_point(f, a, x_guess):
  def cond_fun(carry):
    x_prev, x = carry
    return jnp.abs(x_prev - x) > 1e-6

  def body_fun(carry):
    _, x = carry
    return x, f(a, x)

  _, x_star = while_loop(cond_fun, body_fun, (x_guess, f(a, x_guess)))
  return x_star

def fixed_point_fwd(f, a, x_init):
  x_star = fixed_point(f, a, x_init)
  return x_star, (a, x_star)

def fixed_point_rev(f, res, x_star_bar):
  a, x_star = res
  _, vjp_a = vjp(lambda a: f(a, x_star), a)
  a_bar, = vjp_a(fixed_point(partial(rev_iter, f),
                             (a, x_star, x_star_bar),
                             x_star_bar))
  return a_bar, jnp.zeros_like(x_star)

def rev_iter(f, packed, u):
  a, x_star, x_star_bar = packed
  _, vjp_x = vjp(lambda x: f(a, x), x_star)
  return x_star_bar + vjp_x(u)[0]

fixed_point.defvjp(fixed_point_fwd, fixed_point_rev) 
print(newton_sqrt(2.)) 
1.4142135 
print(grad(newton_sqrt)(2.))
print(grad(grad(newton_sqrt))(2.)) 
0.35355338
-0.088388346 

我们可以通过对 jnp.sqrt 进行微分来检查我们的答案,它使用了完全不同的实现:

print(grad(jnp.sqrt)(2.))
print(grad(grad(jnp.sqrt))(2.)) 
0.35355338
-0.08838835 

这种方法的一个限制是参数f不能涉及到任何参与微分的值。也就是说,你可能注意到我们在fixed_point的参数列表中明确保留了参数a。对于这种用例,考虑使用低级原语lax.custom_root,它允许在闭合变量中进行带有自定义根查找函数的导数。

使用 jax.custom_jvpjax.custom_vjp API 的基本用法

使用 jax.custom_jvp 来定义前向模式(以及间接地,反向模式)规则

这里是使用 jax.custom_jvp 的典型基本示例,其中注释使用Haskell-like type signatures

from jax import custom_jvp
import jax.numpy as jnp

# f :: a -> b
@custom_jvp
def f(x):
  return jnp.sin(x)

# f_jvp :: (a, T a) -> (b, T b)
def f_jvp(primals, tangents):
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t

f.defjvp(f_jvp) 
<function __main__.f_jvp(primals, tangents)> 
from jax import jvp

print(f(3.))

y, y_dot = jvp(f, (3.,), (1.,))
print(y)
print(y_dot) 
0.14112
0.14112
-0.9899925 

简言之,我们从一个原始函数f开始,它接受类型为a的输入并产生类型为b的输出。我们与之关联一个 JVP 规则函数f_jvp,它接受一对输入,表示类型为a的原始输入和类型为T a的相应切线输入,并产生一对输出,表示类型为b的原始输出和类型为T b的切线输出。切线输出应该是切线输入的线性函数。

你还可以使用 f.defjvp 作为装饰器,就像这样

@custom_jvp
def f(x):
  ...

@f.defjvp
def f_jvp(primals, tangents):
  ... 

尽管我们只定义了一个 JVP 规则而没有 VJP 规则,但我们可以在f上同时使用正向和反向模式的微分。JAX 会自动将切线值上的线性计算从我们的自定义 JVP 规则转置,高效地计算出 VJP,就好像我们手工编写了规则一样。

from jax import grad

print(grad(f)(3.))
print(grad(grad(f))(3.)) 
-0.9899925
-0.14112 

为了使自动转置工作,JVP 规则的输出切线必须是输入切线的线性函数。否则将引发转置错误。

多个参数的工作方式如下:

@custom_jvp
def f(x, y):
  return x ** 2 * y

@f.defjvp
def f_jvp(primals, tangents):
  x, y = primals
  x_dot, y_dot = tangents
  primal_out = f(x, y)
  tangent_out = 2 * x * y * x_dot + x ** 2 * y_dot
  return primal_out, tangent_out 
print(grad(f)(2., 3.)) 
12.0 

defjvps便捷包装器允许我们为每个参数单独定义一个 JVP,并分别计算结果后进行求和:

@custom_jvp
def f(x):
  return jnp.sin(x)

f.defjvps(lambda t, ans, x: jnp.cos(x) * t) 
print(grad(f)(3.)) 
-0.9899925 

下面是一个带有多个参数的defjvps示例:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          lambda y_dot, primal_out, x, y: x ** 2 * y_dot) 
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.)) 
12.0
12.0
4.0 

简而言之,使用defjvps,您可以传递None值来指示特定参数的 JVP 为零:

@custom_jvp
def f(x, y):
  return x ** 2 * y

f.defjvps(lambda x_dot, primal_out, x, y: 2 * x * y * x_dot,
          None) 
print(grad(f)(2., 3.))
print(grad(f, 0)(2., 3.))  # same as above
print(grad(f, 1)(2., 3.)) 
12.0
12.0
0.0 

使用关键字参数调用jax.custom_jvp函数,或者编写具有默认参数的jax.custom_jvp函数定义,只要能够根据通过标准库inspect.signature机制检索到的函数签名映射到位置参数即可。

当您不执行微分时,函数f的调用方式与未被jax.custom_jvp修饰时完全一样:

@custom_jvp
def f(x):
  print('called f!')  # a harmless side-effect
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  print('called f_jvp!')  # a harmless side-effect
  x, = primals
  t, = tangents
  return f(x), jnp.cos(x) * t 
from jax import vmap, jit

print(f(3.)) 
called f!
0.14112 
print(vmap(f)(jnp.arange(3.)))
print(jit(f)(3.)) 
called f!
[0\.         0.84147096 0.9092974 ]
called f!
0.14112 

自定义的 JVP 规则在微分过程中被调用,无论是正向还是反向:

y, y_dot = jvp(f, (3.,), (1.,))
print(y_dot) 
called f_jvp!
called f!
-0.9899925 
print(grad(f)(3.)) 
called f_jvp!
called f!
-0.9899925 

注意,f_jvp调用f来计算原始输出。在高阶微分的上下文中,每个微分变换的应用将只在规则调用原始f来计算原始输出时使用自定义的 JVP 规则。(这代表一种基本的权衡,我们不能同时利用f的评估中间值来制定规则并且使规则在所有高阶微分顺序中应用。)

grad(grad(f))(3.) 
called f_jvp!
called f_jvp!
called f! 
Array(-0.14112, dtype=float32, weak_type=True) 

您可以使用 Python 控制流来使用jax.custom_jvp

@custom_jvp
def f(x):
  if x > 0:
    return jnp.sin(x)
  else:
    return jnp.cos(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  ans = f(x)
  if x > 0:
    return ans, 2 * x_dot
  else:
    return ans, 3 * x_dot 
print(grad(f)(1.))
print(grad(f)(-1.)) 
2.0
3.0 

使用jax.custom_vjp来定义自定义的仅反向模式规则

虽然jax.custom_jvp足以控制正向和通过 JAX 的自动转置控制反向模式微分行为,但在某些情况下,我们可能希望直接控制 VJP 规则,例如在上述后两个示例问题中。我们可以通过jax.custom_vjp来实现这一点。

from jax import custom_vjp
import jax.numpy as jnp

# f :: a -> b
@custom_vjp
def f(x):
  return jnp.sin(x)

# f_fwd :: a -> (b, c)
def f_fwd(x):
  return f(x), jnp.cos(x)

# f_bwd :: (c, CT b) -> CT a
def f_bwd(cos_x, y_bar):
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd) 
from jax import grad

print(f(3.))
print(grad(f)(3.)) 
0.14112
-0.9899925 

换句话说,我们再次从接受类型为a的输入并产生类型为b的输出的原始函数f开始。我们将与之关联两个函数f_fwdf_bwd,它们描述了如何执行反向模式自动微分的正向和反向传递。

函数f_fwd描述了前向传播,不仅包括原始计算,还包括要保存以供后向传播使用的值。其输入签名与原始函数f完全相同,即它接受类型为a的原始输入。但作为输出,它产生一对值,其中第一个元素是原始输出b,第二个元素是类型为c的任何“残余”数据,用于后向传播时存储。(这第二个输出类似于PyTorch 的 save_for_backward 机制。)

函数f_bwd描述了反向传播。它接受两个输入,第一个是由f_fwd生成的类型为c的残差数据,第二个是对应于原始函数输出的类型为CT b的输出共切线。它生成一个类型为CT a的输出,表示原始函数输入对应的共切线。特别地,f_bwd的输出必须是长度等于原始函数参数个数的序列(例如元组)。

多个参数的工作方式如下:

from jax import custom_vjp

@custom_vjp
def f(x, y):
  return jnp.sin(x) * y

def f_fwd(x, y):
  return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
  cos_x, sin_x, y = res
  return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd) 
print(grad(f)(2., 3.)) 
-1.2484405 

调用带有关键字参数的jax.custom_vjp函数,或者编写带有默认参数的jax.custom_vjp函数定义,只要可以根据标准库inspect.signature机制清晰地映射到位置参数即可。

jax.custom_jvp类似,如果没有应用微分,则不会调用由f_fwdf_bwd组成的自定义 VJP 规则。如果对函数进行评估,或者使用jitvmap或其他非微分变换进行转换,则只调用f

@custom_vjp
def f(x):
  print("called f!")
  return jnp.sin(x)

def f_fwd(x):
  print("called f_fwd!")
  return f(x), jnp.cos(x)

def f_bwd(cos_x, y_bar):
  print("called f_bwd!")
  return (cos_x * y_bar,)

f.defvjp(f_fwd, f_bwd) 
print(f(3.)) 
called f!
0.14112 
print(grad(f)(3.)) 
called f_fwd!
called f!
called f_bwd!
-0.9899925 
from jax import vjp

y, f_vjp = vjp(f, 3.)
print(y) 
called f_fwd!
called f!
0.14112 
print(f_vjp(1.)) 
called f_bwd!
(Array(-0.9899925, dtype=float32, weak_type=True),) 

无法在 jax.custom_vjp 函数上使用前向模式自动微分,否则会引发错误:

from jax import jvp

try:
  jvp(f, (3.,), (1.,))
except TypeError as e:
  print('ERROR! {}'.format(e)) 
called f_fwd!
called f!
ERROR! can't apply forward-mode autodiff (jvp) to a custom_vjp function. 

如果希望同时使用前向和反向模式,请使用jax.custom_jvp

我们可以使用jax.custom_vjppdb一起在反向传播中插入调试器跟踪:

import pdb

@custom_vjp
def debug(x):
  return x  # acts like identity

def debug_fwd(x):
  return x, x

def debug_bwd(x, g):
  import pdb; pdb.set_trace()
  return g

debug.defvjp(debug_fwd, debug_bwd) 
def foo(x):
  y = x ** 2
  y = debug(y)  # insert pdb in corresponding backward pass step
  return jnp.sin(y) 
jax.grad(foo)(3.)

> <ipython-input-113-b19a2dc1abf7>(12)debug_bwd()
-> return g
(Pdb) p x
Array(9., dtype=float32)
(Pdb) p g
Array(-0.91113025, dtype=float32)
(Pdb) q 

更多特性和细节

使用list / tuple / dict容器(和其他 pytree)

你应该期望标准的 Python 容器如列表、元组、命名元组和字典可以正常工作,以及这些容器的嵌套版本。总体而言,任何pytrees都是允许的,只要它们的结构符合类型约束。

这里有一个使用jax.custom_jvp的构造示例:

from collections import namedtuple
Point = namedtuple("Point", ["x", "y"])

@custom_jvp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

@f.defjvp
def f_jvp(primals, tangents):
  pt, = primals
  pt_dot, =  tangents
  ans = f(pt)
  ans_dot = {'a': 2 * pt.x * pt_dot.x,
             'b': (jnp.cos(pt.x) * pt_dot.x, -jnp.sin(pt.y) * pt_dot.y)}
  return ans, ans_dot

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0] 
pt = Point(1., 2.)

print(f(pt)) 
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))} 
print(grad(fun)(pt)) 
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(0., dtype=float32, weak_type=True)) 

还有一个类似的使用jax.custom_vjp的构造示例:

@custom_vjp
def f(pt):
  x, y = pt.x, pt.y
  return {'a': x ** 2,
          'b': (jnp.sin(x), jnp.cos(y))}

def f_fwd(pt):
  return f(pt), pt

def f_bwd(pt, g):
  a_bar, (b0_bar, b1_bar) = g['a'], g['b']
  x_bar = 2 * pt.x * a_bar + jnp.cos(pt.x) * b0_bar
  y_bar = -jnp.sin(pt.y) * b1_bar
  return (Point(x_bar, y_bar),)

f.defvjp(f_fwd, f_bwd)

def fun(pt):
  dct = f(pt)
  return dct['a'] + dct['b'][0] 
pt = Point(1., 2.)

print(f(pt)) 
{'a': 1.0, 'b': (Array(0.84147096, dtype=float32, weak_type=True), Array(-0.41614684, dtype=float32, weak_type=True))} 
print(grad(fun)(pt)) 
Point(x=Array(2.5403023, dtype=float32, weak_type=True), y=Array(-0., dtype=float32, weak_type=True)) 

处理非可微参数

一些用例,如最后的示例问题,需要将非可微参数(如函数值参数)传递给具有自定义微分规则的函数,并且这些参数也需要传递给规则本身。在fixed_point的情况下,函数参数f就是这样一个非可微参数。类似的情况在jax.experimental.odeint中也会出现。

jax.custom_jvpnondiff_argnums

使用可选的 nondiff_argnums 参数来指示类似这些的参数给 jax.custom_jvp。以下是一个带有 jax.custom_jvp 的例子:

from functools import partial

@partial(custom_jvp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

@app.defjvp
def app_jvp(f, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(x), 2. * x_dot 
print(app(lambda x: x ** 3, 3.)) 
27.0 
print(grad(app, 1)(lambda x: x ** 3, 3.)) 
2.0 

注意这里的陷阱:无论这些参数在参数列表的哪个位置出现,它们都放置在相应 JVP 规则签名的起始位置。这里有另一个例子:

@partial(custom_jvp, nondiff_argnums=(0, 2))
def app2(f, x, g):
  return f(g((x)))

@app2.defjvp
def app2_jvp(f, g, primals, tangents):
  x, = primals
  x_dot, = tangents
  return f(g(x)), 3. * x_dot 
print(app2(lambda x: x ** 3, 3., lambda y: 5 * y)) 
3375.0 
print(grad(app2, 1)(lambda x: x ** 3, 3., lambda y: 5 * y)) 
3.0 

nondiff_argnumsjax.custom_vjp

对于 jax.custom_vjp 也有类似的选项,类似地,非可微参数的约定是它们作为 _bwd 规则的第一个参数传递,无论它们出现在原始函数签名的哪个位置。 _fwd 规则的签名保持不变 - 它与原始函数的签名相同。以下是一个例子:

@partial(custom_vjp, nondiff_argnums=(0,))
def app(f, x):
  return f(x)

def app_fwd(f, x):
  return f(x), x

def app_bwd(f, x, g):
  return (5 * g,)

app.defvjp(app_fwd, app_bwd) 
print(app(lambda x: x ** 2, 4.)) 
16.0 
print(grad(app, 1)(lambda x: x ** 2, 4.)) 
5.0 

请参见上面的 fixed_point 以获取另一个用法示例。

对于具有整数 dtype 的数组值参数,不需要使用 nondiff_argnums **。相反,nondiff_argnums 应仅用于不对应 JAX 类型(实质上不对应数组类型)的参数值,如 Python 可调用对象或字符串。如果 JAX 检测到由 nondiff_argnums 指示的参数包含 JAX Tracer,则会引发错误。上面的 clip_gradient 函数是不使用 nondiff_argnums 处理整数 dtype 数组参数的良好示例。

标签:中文,jnp,return,JAX,jvp,文档,print,grad,def
From: https://www.cnblogs.com/apachecn/p/18260390

相关文章

  • prometheus 中文说明告警指标
    https://blog.51cto.com/qiangsh/1977449主机和硬件监控可用内存指标主机中可用内存容量不足10%-alert:HostOutOfMemoryexpr:node_memory_MemAvailable_bytes/node_memory_MemTotal_bytes*100<10for:5mlabels:severity:warningannotations:......
  • 软件开发项目全套文档资料参考(规格说明书、详细设计、测试计划、验收报告)
     前言:在软件开发过程中,文档资料是非常关键的一部分,它们帮助团队成员理解项目需求、设计、实施、测试、验收等各个环节,确保项目的顺利进行。以下是各个阶段的文档资料概述:软件项目管理部分文档清单: 工作安排任务书,可行性分析报告,立项申请审批表,产品需求规格说明书,需求调研......
  • 中文检测插件
    大家都知道,做出海应用,尤其是在一些对中国不友好的国家做业务。全面去中文化至关重要。对于开发而言,在代码层如果只靠人为控制这个变量,尤其艰难。所以给大家安利一个我们自研的中文检测插件,他能在您开发过程中时刻检测您的输入是否含有中文。大家先看下效果。如果您有需要,烦......
  • 搜索硬编码中文
    老项目中常常有直接在代码里或者xml布局中硬编码中文的,在后期业务扩展做国际化翻译时,这就是一个巨大的坑,因为我们需要知道哪里硬编码了,然后提取到strings.xml中刚好我最近在弄这个,如何找到代码中所有的硬编码就是核心问题,下面记录下我的步骤 1.首先写好正则,直接百度也行^((?!......
  • MestReNova14.0中文版安装教程
    MestReNova14是一款专业级的核磁共振(NMR)与质谱(MS)数据分析软件,专注于化合物结构解析和验证。该软件以卓越的谱图处理能力和智能化算法为核心,提供自定义参数调整、自动峰识别、精准积分、耦合常数计算等功能。支持多种仪器数据格式导入,可高效处理一维至四维NMR谱图以及各类质谱数据......
  • PDF英语文档怎么翻译成中文?
    外语文献是我们学习和工作中经常遇到的难题,其中包含许多重要工作信息,精确地理解和翻译非常重要。但并不是所有格式的文件都能直接编辑和翻译。例如PDF格式的文件就无法直接进行编辑,当我们需要翻译PDF格式的外语文档时,应该使用什么工具呢?本篇文章就为你提供几个快速翻译PDF文件的方......
  • mac苹果窗口辅助工具:Magnet for mac 2.14.0中文免激活版
    Magnet是一款针对MacOS系统的窗口管理工具软件。它能够帮助用户更加高效地管理和组织桌面上的窗口,通过简单的快捷键操作,可以将窗口自动调整到指定的位置和大小,实现多窗口快速布局。Magnet还支持多显示器环境下的窗口管理,可以让用户更加轻松地在多屏幕之间切换和布局窗口。......
  • 032java jsp ssm大学生第二课堂成绩单系统学生思想道德技术修养文体活动管理(源码+数据
     项目技术:SSM+Maven等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/101G......
  • 026java jsp ssm网络硬盘系统网站系统(源码+数据库+文档)
     项目技术:SSM+Maven等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/101G......
  • 027java jsp ssm洗衣店管理系统(源码+数据库+文档)
     项目技术:SSM+Maven等等组成,B/S模式+Maven管理等等。环境需要1.运行环境:最好是javajdk1.8,我们在这个平台上运行的。其他版本理论上也可以。2.IDE环境:IDEA,Eclipse,Myeclipse都可以。推荐IDEA;3.tomcat环境:Tomcat7.x,8.x,9.x版本均可4.硬件环境:windows7/8/101G......