# %%
import jax
import jax.numpy as jnp
import numpy as np
def loss(params, r):
    # params = (lambda_a, lambda_s); the loss is the largest positive entry of r - lambda_a + lambda_s
    lambda_a, lambda_s = params
    return jnp.maximum(r - lambda_a + lambda_s, 0).max()
loss_grad = jax.grad(loss)  # gradient w.r.t. params, returned as a (grad_a, grad_s) tuple
params = (np.array([-.50, .51, .51]), np.random.randn(3))  # example inputs, same as the test below
r = 1
grad_a, grad_s = loss_grad(params, r)
print(grad_a, grad_s)
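#%%
# jax splits the gradient of a max evenly among tied maxima; the quick check
# below (my illustration, not from the original post) shows why the manual
# Jacobian further down normalizes its indicator by the number of ties.
print(jax.grad(lambda x: x.max())(jnp.array([1., 1., 0.])))  # expect [0.5 0.5 0. ]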
#%%
def jac(params, r):
    # hand-written gradient of loss: put weight 1 on the maximal entry of v
    # when it is positive, split evenly across ties (jax's reduce_max rule);
    # assumes v.max() > 0, otherwise g_b.sum() == 0 and the division gives nan
    lambda_a, lambda_s = params
    v = r - lambda_a + lambda_s
    g_b = np.logical_and(v > 0, v == v.max()).astype(float)
    g_b /= g_b.sum()
    return -g_b, g_b  # (d loss / d lambda_a, d loss / d lambda_s)
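#%%
# In formulas, with v = r - lambda_a + lambda_s and m ties at the maximum:
#   d loss / d lambda_s[i] = 1/m  if v[i] == max(v) > 0, else 0
#   d loss / d lambda_a[i] = -(d loss / d lambda_s[i])
# since loss = max(maximum(v, 0)) and lambda_a enters v with a minus sign.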
#%% Test
print("data1")
params = (np.array([-.50, .51, .51]), np.random.randn(3))
r = 1
grad_a, grad_s = loss_grad(params, r)
print(params)
print(grad_a, grad_s)
ga, gs = jac(params, r)
print(np.allclose(grad_a, ga), np.allclose(grad_s, gs))
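#%% Tie case
# Extra sanity check (my own addition, not from the original post): two tied
# maxima in v, so both jax.grad and jac should put weight 0.5 on each tied index.
params = (np.array([0., 0., 1.]), np.zeros(3))  # v = 1 - lambda_a + lambda_s = [1., 1., 0.]
r = 1
grad_a, grad_s = loss_grad(params, r)
ga, gs = jac(params, r)
print(np.allclose(grad_a, ga), np.allclose(grad_s, gs))  # expect: True True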
From: https://www.cnblogs.com/bregman/p/17397225.html