JAX 中文文档(六)
高级教程
本节包含更高级主题的示例和教程,如多核计算、自定义操作及更深入的应用
示例
-
使用 tensorflow/datasets 进行简单神经网络训练
-
使用 PyTorch 数据加载进行简单神经网络训练
-
贝叶斯推断的自动批处理
并行计算
-
在多主机和多进程环境中使用 JAX
-
分布式数组和自动并行化
-
带有
shard_map
的 SPMD 多设备并行性 -
API 规范
-
集合教程
-
玩具示例
-
多主机/多进程环境中的分布式数据加载
-
带有
xmap
的命名轴和易于修改的并行性
自动微分
-
自动微分食谱
-
为可转换为 JAX 的 Python 函数编写自定义导数规则
-
使用
jax.checkpoint
(又名jax.remat
)控制自动微分的保存值
JAX 内部机制
-
JAX 原语的工作原理
-
在 JAX 中编写自定义 Jaxpr 解释器
-
使用 C++ 和 CUDA 为 GPU 的自定义操作
-
检查正确性
深入探讨
- JAX 中的广义卷积
训练一个简单的神经网络,使用 tensorflow/datasets
进行数据加载
原文:
jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html
从 neural_network_and_data_loading.ipynb
衍生
让我们结合我们在快速入门中展示的所有内容来训练一个简单的神经网络。我们将首先使用 JAX 在 MNIST 上指定和训练一个简单的 MLP 进行计算。我们将使用 tensorflow/datasets
数据加载 API 来加载图像和标签(因为它非常出色,世界上不需要再另外一种数据加载库