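In our previous post, "A New TensorFlow Lite Example App: Board Game", we showed how to use TensorFlow and TensorFlow Agents to train a reinforcement learning (RL) agent to play a simple board game called "Plane Strike". We also converted the trained model to TensorFlow Lite and deployed it in a fully functional Android app. In this post, we demonstrate a new path: we train the same RL agent with Flax/JAX and deploy it in the same Android app we built earlier. The complete source code is available in the tensorflow/examples repository for your reference.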
- Flax: https://flax.readthedocs.io/
- JAX: https://jax.readthedocs.io/
- tensorflow/examples: https://github.com/tensorflow/examples/blob/master/lite/examples/reinforcement_learning/ml/tf_and_jax/training_jax.py
△ Demo of the "Plane Strike" game

Background: JAX and TensorFlow
JAX is a NumPy-like library developed by Google Research for high-performance computing. JAX uses XLA to compile programs into optimized code for GPUs and TPUs.
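The sketch below makes the "NumPy-like" description concrete. It uses only core JAX APIs (`jax.numpy`, `jax.jit`, `jax.grad`). `sum_of_squares` is a hypothetical function chosen for illustration and is not part of the Plane Strike code.

```python
import jax
import jax.numpy as jnp


def sum_of_squares(x):
    # Written like ordinary NumPy code, but traceable by JAX.
    return jnp.sum(x ** 2)


# jax.jit traces the function once and compiles it with XLA for the
# available backend (CPU, GPU or TPU).
fast_sum_of_squares = jax.jit(sum_of_squares)

# jax.grad returns a new function computing d(sum x_i^2)/dx = 2x.
grad_sum_of_squares = jax.grad(sum_of_squares)

x = jnp.array([1.0, 2.0, 3.0])
value = fast_sum_of_squares(x)
gradient = grad_sum_of_squares(x)
print(value)     # 14.0
print(gradient)  # [2. 4. 6.]
```

The same two transformations, `jax.jit` and `jax.grad`, are what the training code later in this post relies on.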
- JAX: https://github.com/google/jax
- XLA: https://tensorflow.google.cn/xla
- TPU: https://cloud.google.com/tpu
- Flax: https://github.com/google/flax
- PaLM: https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html
- Imagen: https://imagen.research.google/
- JAX 101 tutorial: https://jax.readthedocs.io/en/latest/jax-101/index.html
- Flax getting-started example: https://flax.readthedocs.io/en/latest/getting_started.html
- TFX: https://tensorflow.google.cn/tfx
- TensorBoard: https://tensorboard.dev/
- TensorFlow Lite: https://tensorflow.google.cn/lite
- TensorFlow.js: https://tensorflow.google.cn/js
- SavedModel: https://tensorflow.google.cn/guide/saved_model
- The video "Serving JAX models with TensorFlow Serving" shows how to deploy JAX models with TensorFlow Serving.
- The post "JAX on the Web with TensorFlow.js" explains in detail how to convert a JAX model to TFJS and run it in a web app:
https://blog.tensorflow.org/2022/08/jax-on-web-with-tensorflowjs.html
- This post shows how to convert a Flax/JAX model to TFLite and run it in a native Android app.
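Now back to the board game. To implement the RL agent, we reuse the same OpenAI gym environment as before. This time, we train the same policy gradient model with Flax/JAX. Recall that, mathematically, the policy gradient is defined as: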
- OpenAI gym: https://github.com/tensorflow/examples/tree/master/lite/examples/reinforcement_learning/ml/tf_and_jax/gym_planestrike/gym_planestrike/envs
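- T: the number of time steps in an episode, which can differ from episode to episode
- s_t: the state at time step t
- a_t: the action chosen at time step t, given state s_t
- π_θ: the policy, parameterized by θ
- R(·): the reward collected under the given policy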
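Written with the symbols above, the standard REINFORCE form of the policy gradient is shown below. This assumes the textbook formulation, where τ denotes the trajectory (s_0, a_0, ..., s_{T-1}, a_{T-1}) of one game sampled from π_θ.

```latex
\nabla_\theta J(\theta)
  = \mathbb{E}_{\tau \sim \pi_\theta}\left[
      \sum_{t=0}^{T-1} \nabla_\theta \log \pi_\theta(a_t \mid s_t)\, R(\tau)
    \right]
```

In the training code that follows, the reward term is supplied per move (the `rewards` array). `compute_loss` minimizes the surrogate loss -mean_t[log π_θ(a_t | s_t) · r_t]. Its gradient is this estimate negated and scaled by 1/T, so gradient descent on the loss performs gradient ascent on J(θ).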
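import flax.linen as nn
import jax.numpy as jnp

import common  # game helpers (board size, self-play) from the example code


class PolicyGradient(nn.Module):
    """Neural network to predict the next strike position."""

    @nn.compact
    def __call__(self, x):
        dtype = jnp.float32
        # Flatten each board into a vector of BOARD_SIZE**2 cells.
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(
            features=2 * common.BOARD_SIZE**2, name='hidden1', dtype=dtype)(x)
        x = nn.relu(x)
        x = nn.Dense(features=common.BOARD_SIZE**2, name='hidden2', dtype=dtype)(x)
        x = nn.relu(x)
        x = nn.Dense(features=common.BOARD_SIZE**2, name='logits', dtype=dtype)(x)
        # Softmax turns the scores into a probability of striking each cell.
        policy_probabilities = nn.softmax(x)
        return policy_probabilities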
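The Flax module above is a three-layer MLP. As a dependency-light sketch, the same forward pass can be written with plain `jax.numpy` instead of Flax. This sketch assumes an 8x8 board (`BOARD_SIZE = 8`). `init_params` and `policy_forward` are hypothetical helpers, and the He-style initialization is chosen for simplicity rather than copied from Flax's defaults.

```python
import jax
import jax.numpy as jnp

BOARD_SIZE = 8  # assumption: Plane Strike is played on an 8x8 board


def init_params(key, board_size=BOARD_SIZE):
    """Hypothetical helper: He-style random init for the three dense layers."""
    n = board_size ** 2
    layer_sizes = [(n, 2 * n), (2 * n, n), (n, n)]  # hidden1, hidden2, logits
    keys = jax.random.split(key, len(layer_sizes))
    params = []
    for k, (fan_in, fan_out) in zip(keys, layer_sizes):
        w = jax.random.normal(k, (fan_in, fan_out)) * jnp.sqrt(2.0 / fan_in)
        params.append((w, jnp.zeros(fan_out)))
    return params


@jax.jit
def policy_forward(params, boards):
    """Same computation as PolicyGradient.__call__, without Flax."""
    x = boards.reshape((boards.shape[0], -1))  # [batch, BOARD_SIZE**2]
    (w1, b1), (w2, b2), (w3, b3) = params
    x = jax.nn.relu(x @ w1 + b1)               # hidden1
    x = jax.nn.relu(x @ w2 + b2)               # hidden2
    return jax.nn.softmax(x @ w3 + b3)         # one probability per cell


params = init_params(jax.random.PRNGKey(0))
boards = jax.random.uniform(jax.random.PRNGKey(1), (2, BOARD_SIZE, BOARD_SIZE))
probs = policy_forward(params, boards)
print(probs.shape)  # (2, 64)
```

The output is a softmax over all cells, so each row sums to 1. The same holds for the Flax version, whose output `run_inference` later calls `logits` even though it holds probabilities.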
import functools

# One training iteration: play a game with the current policy, turn the
# outcomes into rewards, then update the policy parameters.
predict_fn = functools.partial(run_inference, params)
board_log, action_log, result_log = common.play_game(predict_fn)
rewards = common.compute_rewards(result_log)
optimizer, params, opt_state = train_step(optimizer, params, opt_state,
                                          board_log, action_log, rewards)

In the train_step() method, we first compute the loss from the trajectory. We then compute the gradients with jax.grad(). Finally, we update the model parameters with Optax, a gradient processing and optimization library for JAX.

import jax
import jax.numpy as jnp
import optax


def compute_loss(logits, labels, rewards):
    # Despite the name, `logits` are the softmax probabilities from PolicyGradient.
    one_hot_labels = jax.nn.one_hot(labels, num_classes=common.BOARD_SIZE**2)
    # Negative reward-weighted log-likelihood of the actions actually taken.
    loss = -jnp.mean(
        jnp.sum(one_hot_labels * jnp.log(logits), axis=-1) * jnp.asarray(rewards))
    return loss


def train_step(model_optimizer, params, opt_state, game_board_log,
               predicted_action_log, action_result_log):
    """Run one training step."""

    def loss_fn(model_params):
        logits = run_inference(model_params, game_board_log)
        loss = compute_loss(logits, predicted_action_log, action_result_log)
        return loss

    def compute_grads(params):
        return jax.grad(loss_fn)(params)

    grads = compute_grads(params)
    # Optax turns the raw gradients into updates using the optimizer's rule.
    updates, opt_state = model_optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return model_optimizer, params, opt_state


@jax.jit
def run_inference(model_params, board):
    logits = PolicyGradient().apply({'params': model_params}, board)
    return logits
- Optax: https://github.com/deepmind/optax
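The training step above can be reduced to a toy version that needs only core JAX. This sketch makes two simplifications. It replaces the Flax model with a single linear layer over a hypothetical 4-cell board, and it replaces Optax with a hand-written SGD update via `jax.tree_util.tree_map`. It also computes log-probabilities with `jax.nn.log_softmax`, a numerically safer equivalent of applying `jnp.log` to softmax outputs.

```python
import jax
import jax.numpy as jnp

NUM_CELLS = 4  # hypothetical 4-cell board, small enough to reason about


def reinforce_loss(params, boards, actions, rewards):
    """Same form as compute_loss: -mean_t[log pi(a_t | s_t) * r_t]."""
    logits = boards @ params["w"] + params["b"]    # [T, NUM_CELLS]
    log_probs = jax.nn.log_softmax(logits)         # log pi(. | s_t)
    one_hot = jax.nn.one_hot(actions, NUM_CELLS)   # picks out a_t
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=-1) * rewards)


@jax.jit
def sgd_step(params, boards, actions, rewards, learning_rate=0.1):
    """One plain-SGD update (stand-in for Optax, not the post's optimizer)."""
    grads = jax.grad(reinforce_loss)(params, boards, actions, rewards)
    # Apply p <- p - lr * g to every leaf of the parameter pytree.
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, params, grads)


# Toy trajectory: two time steps, one-hot "boards", both moves rewarded.
params = {"w": jnp.zeros((NUM_CELLS, NUM_CELLS)), "b": jnp.zeros(NUM_CELLS)}
boards = jnp.eye(NUM_CELLS)[:2]
actions = jnp.array([1, 2])
rewards = jnp.array([1.0, 1.0])

loss_before = reinforce_loss(params, boards, actions, rewards)
params = sgd_step(params, boards, actions, rewards)
loss_after = reinforce_loss(params, boards, actions, rewards)
print(loss_before, loss_after)
```

With zero-initialized parameters the policy is uniform, so the initial loss is log 4 when every reward is 1. One SGD step raises the probability of the rewarded actions and lowers the loss. A negative reward would push the other way.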
import os

import tensorflow as tf
from jax.experimental import jax2tf

# Convert to tflite model
model = PolicyGradient()
jax_predict_fn = lambda input: model.apply({'params': params}, input)

# Wrap the JAX function as a TF function with a fixed input signature
# (a batch of one BOARD_SIZE x BOARD_SIZE board). enable_xla=False makes
# jax2tf emit plain TF ops instead of XLA-specific ops that TFLite cannot run.
tf_predict = tf.function(
    jax2tf.convert(jax_predict_fn, enable_xla=False),
    input_signature=[
        tf.TensorSpec(
            shape=[1, common.BOARD_SIZE, common.BOARD_SIZE],
            dtype=tf.float32,
            name='input')
    ],
    autograph=False,
)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)
tflite_model = converter.convert()

# Save the model
with open(os.path.join(modeldir, 'planestrike.tflite'), 'wb') as f:
    f.write(tflite_model)
- jax2tf
We can call the model and get its predictions with exactly the same Java code as before.
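// Convert the current board state into the ByteBuffer used as model input (boardData).
convertBoardStateToByteBuffer(board);
// Run inference; the model outputs one probability per board cell.
tflite.run(boardData, outputProbArrays);
float[] probArray = outputProbArrays[0];

// Pick the untried cell with the highest predicted probability.
int agentStrikePosition = -1;
float maxProb = 0;
for (int i = 0; i < probArray.length; i++) {
    int x = i / Constants.BOARD_SIZE;
    int y = i % Constants.BOARD_SIZE;
    if (board[x][y] == BoardCellStatus.UNTRIED && probArray[i] > maxProb) {
        agentStrikePosition = i;
        maxProb = probArray[i];
    }
}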
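The Java loop above picks the most probable cell among those not yet tried. For readers who stay in Python, the same selection can be sketched with `jax.numpy` masking. `pick_strike_position` and the boolean `untried_mask` are hypothetical names, with the mask standing in for `BoardCellStatus.UNTRIED`. An 8x8 board is assumed.

```python
import jax.numpy as jnp

BOARD_SIZE = 8  # assumption: 8x8 board


def pick_strike_position(prob_array, untried_mask):
    """Return the flat index of the most probable untried cell, or -1.

    prob_array:   shape [BOARD_SIZE * BOARD_SIZE], model output for one board.
    untried_mask: same shape, True where the cell has not been struck yet.
    """
    # Like the Java loop (maxProb starts at 0 and uses '>'), only cells
    # with a strictly positive probability are eligible.
    eligible = untried_mask & (prob_array > 0)
    if not bool(eligible.any()):
        return -1
    masked = jnp.where(eligible, prob_array, -jnp.inf)
    # argmax returns the first maximum, matching the strict '>' in Java.
    return int(jnp.argmax(masked))


probs = jnp.linspace(0.0, 1.0, BOARD_SIZE * BOARD_SIZE)  # best cell is 63
untried = jnp.ones(BOARD_SIZE * BOARD_SIZE, dtype=bool).at[63].set(False)
position = pick_strike_position(probs, untried)
row, col = divmod(position, BOARD_SIZE)
print(position, row, col)  # 62 7 6
```

`divmod(position, BOARD_SIZE)` maps the flat index back to board coordinates. It matches `x = i / BOARD_SIZE` and `y = i % BOARD_SIZE` in the Java code.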