前言
本文适用对象:任何接触过 TensorFlow, Pytorch, Keras 并且已经开始了解或尝鲜 Jax 的人群。如果是没有接触过任何深度学习框架的人群,这篇文章可能不适合你。在开始学习之前,你应该对 PyTorch 或 TensorFlow 有一定的了解。Jax 可能是一个比较难学的库,但值得一学。为什么使用 Jax 的理由这里就不多赘述了。就我个人而言,tf 或者 torch 自定义损失函数训练的速度实在是不太满意,即使这过程中换用了 numba 仍然差强人意,加之我是 tf 的旧党,所以不难预料的投入了 Jax 的怀抱。
正文开始
如果是经常用 Tensorflow Keras 或者 Pytorch Lightning 的炼丹师,一定会喜欢 fit 这个方法。所以本文以实现一个简单且常用的 fit 方法来快速上手 Jax,而且实现的这个 fit 方法基本上可以复用在很多项目中。另外再次强调,这篇文章可能不适合入门,但是很适合快速上工(从删库到跑路)。
本文实现的 fit 方法需要安装如下依赖,如果你已经使用过 Jax ,基本以下依赖库想必都已经了解了。
- jax (jax, jaxlib)
- flax 定义你的模型
- optax 优化器,学习率,损失函数
- orbax 用于保存 checkpoints
- tqdm 显示进度条,用过多解释了
- tensorboardX 一个三方的 tensorboard 的 python 库,用于输出一些训练过程日志
本文也是用的这个经典组合:Jax + Flax + Optax + Orbax,硬件加速 + 网络结构 + 损失函数 + 保存储存点
快速上手
先看看一个训练模型的模板,但只需要修改脚本中的三个关键代码部分。
import jax, flax, optax, orbax
from fit import lr_schedule, TrainState
# 准备你自己的数据集
train_ds, test_ds = your_dataset()
# 学习率
lr_fn = lr_schedule(
base_lr=1e-3,
steps_per_epoch=len(train_ds),
epochs=100,
warmup_epochs=5,
)
# key 1: 你的模型
model = YourModel()
# 初始化 key 和你的模型
key = jax.random.PRNGKey(0)
x = jnp.ones((1, 28, 28, 1)) # MNIST 示例输入大小
# 注意这里 train=True, 区别模型的训练和评价模式
var = model.init(key, x, train=True)
# 固定模板,直接复制就能用
state = TrainState.create(
apply_fn=model.apply,
params=var['params'],
batch_stats=var['batch_stats'],
tx=optax.inject_hyperparams(optax.adam)(lr_fn),
)
# 你的训练函数,详情参考下个章节
@jax.jit
def train_step():
# key 2: 你的损失函数
def loss_fn():
...
return state, loss_dict, opt_state
# 你的评价函数
@jax.jit
def eval_step():
# key 3: 你的评价函数
...
return acc
# 一些必要的参数,epoches 之类
fit(state, train_ds, test_ds,
train_step=train_step,
eval_step=eval_step,
eval_freq=1,
num_epochs=10,
log_name='mnist',
)
使用方法
让我们从一个简单的例子开始,在 MNIST 数据集上训练一个模型。首先,在训练脚本中导入 fit 模块。
from fit import *
在训练之前,你需要定义模型、损失函数和评估函数。让我们从模型开始。
模型
下面是一个非常简单的模型示例。setup 函数用来定义模型结构,__call__ 函数定义模型的前向传播。
class Model(nn.Module):
def setup(self):
self.conv1 = nn.Conv(features=16, kernel_size=(3, 3))
self.dense1 = nn.Dense(features=10)
# train=False 用于评价模式
# 如果你使用了 dropout 或者 batch normalization 层
# 我打赌你会用到它
@nn.compact
def __call__(self, x, train=False):
# 简单的 conv + bn + relu + 全连接层
x = self.conv1(x)
x = nn.BatchNorm(use_running_average=not train)(x)
x = nn.relu(x)
# dropout 层
x = nn.Dropout(rate=0.5)(x, deterministic=not train)
# 展平
x = x.reshape((x.shape[0], -1))
x = self.dense1(x)
return x
接下来,你只需要考虑两件事:损失函数和评估函数。下面的 train_step 函数是训练模型的一个通用模板。state 对象是一个基于 TrainState 的对象的改进,其中不仅包含了模型参数、 Batch 状态和其他必要信息。batch 对应的是输入数据,opt_state 是优化器状态。
不要担心面对这个复杂的 train_step 函数,它只是一个模板。你可以复制并粘贴到你的脚本中,只需修改 loss_fn 函数即可。
@jax.jit
def train_step(state: TrainState, batch, opt_state):
x, y = batch
def loss_fn(params):
logits, updates = state.apply_fn({
'params': params,
'batch_stats': state.batch_stats
}, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
loss_dict = {'loss': loss}
return loss, (loss_dict, updates)
# gradient and update
(_, (loss_dict, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
# update optimizer state
_, opt_state = state.tx.update(grads, opt_state)
return state, loss_dict, opt_state
损失函数 is all you need
让我们把重点放在 loss_fn 函数上。让我们从伪 pytorch 风格的代码开始,这对理解 Jax 中 train_step 方法里的 loss_fn 函数很有帮助。
def loss_fn():
pred_y = model(x, train=False)
loss = criterion(pred_y, true_y)
return loss
很简单对吧?让我们继续。
def loss_fn(params):
pred_y, updates = state.apply_fn({'params': params}, x, train=True)
loss = criterion(pred_y, y_true)
# 方便将一些你需要的日志记录展示在 tensorboard 里
loss_dict = {'loss': loss}
return loss, (loss_dict, updates)
接下来,让我们为 loss_fn 函数添加更多细节,例如 batch state 和 dropout key。这是 train_step 函数中 loss_fn 函数的完整版本。
def loss_fn(params):
logits, updates = state.apply_fn({
'params': params,
'batch_stats': state.batch_stats
}, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
loss_dict = {'loss': loss}
return loss, (loss_dict, updates)
既然这篇文章标题叫上工指南,mnist 这个例子显然就不太合适了,毕竟谁家上工不是得整一堆损失函数辅助配合了用的,所以需要在 jax.value_and_grad 函数中开启 has_aux=True 然后写成以下这个样子,并且为了将 log 输出到 tensorboard 里,所以这里用一个字典返回。
@jax.jit
def train_step(state: TrainState, batch, opt_state):
x, y = batch
def loss_fn(params):
logits, updates = state.apply_fn({
'params': params,
'batch_stats': state.batch_stats
}, x, train=True, mutable=['batch_stats'], rngs={'dropout': key})
loss_1 = optax.softmax_cross_entropy(logits, jax.nn.one_hot(y, 10)).mean()
loss_2 = jnp.mean(jnp.square(logits - y))
loss = loss_1 + loss_2
loss_dict = {'loss': loss, 'loss_1': loss_1, 'loss_2': loss_2}
return loss, (loss_dict, updates)
# gradient and update
(_, (loss_dict, updates)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
state = state.apply_gradients(grads=grads, batch_stats=updates['batch_stats'])
# update optimizer state
_, opt_state = state.tx.update(grads, opt_state)
return state, loss_dict, opt_state
需要注意的是,损失函数应返回总损失值和要记录到 tensorboard 的字典。
评价函数
现在,让我们继续使用伪 pytorch 风格代码来看看评估函数。
def eval_step():
true_x, true_y = data
model.eval()
pred_y = model(true_x)
# 你的计算准确度的函数
acc = metric(pred_y, true_y)
return acc
在 pytorch 中,可以使用 model.eval() 函数将模型切换到评估模式。因为在训练和评估模式中,BN 层和 Dropout 层的行为不同。在 Jax 中,你需要在 apply_fn 函数中设置 train=False 参数。需要注意的是,如果使用 BN 层和Dropout 层,模型结构在训练和评估模式下应该是不同的,请参阅模型部分的 __call__ 函数。
与 train_step 函数类似,只需要传入 state 对象和 batch 对象。
@jax.jit
def eval_step(state: TrainState, batch):
x, y = batch
logits = state.apply_fn({
'params': state.params,
'batch_stats': state.batch_stats,
}, x, train=False)
acc = jnp.equal(jnp.argmax(logits, -1), y).mean()
return acc
数据准备
TensorFlow Datasets
ds = tfds.load("mnist", split="train", as_supervised=True)
train_ds = ds.take(50000).map(lambda x, y: (x / 255, y))
Torchvision Datasets
ds = torchvision.datasets.MNIST(
root="data", train=True, download=True,
transform=torchvision.transforms.ToTensor()
)
train_ds = torch.utils.data.DataLoader(ds, batch_size=32, shuffle=True)
学习率
顺便说一下,lr_schedule 用于创建学习率函数,这是 TrainState 对象所必需的。当然,你也可以配置你偏好的 lr_schedule 或者直接用默认的 lr_schedule 。
lr_fn = lr_schedule(base_lr=1e-3,
steps_per_epoch=len(train_ds),
epochs=100,
warmup_epochs=5,
)
此外,你还可以定义自己的链式更新,详情请查看 optax 库。
state = TrainState.create(
apply_fn=model.apply,
params=var['params'],
batch_stats=var['batch_stats'],
# 链式组合
tx=optax.chain(
optax.clip_by_global_norm(1.0),
optax.adam(lr_fn)),
)
最后调用fit函数开始训练。
fit(state, train_ds, test_ds,
train_step=train_step,
eval_step=eval_step,
# evaluate the model every N epochs (default 1)
eval_freq=1,
num_epochs=10,
# log name for tensorboard
log_name='mnist',
)
可视化训练过程
可以打开 Tensorboard 查看训练过程或检查任何损失和准确度指标。
Q&A
什么是 @jax.jit 装饰器?
这是一个将函数编译为单个静态函数的装饰器,可以在 GPU 或 TPU 上执行,如果你想加快训练过程,尤其是你自己的损失函数和评估函数,可以添加 @jax.jit 装饰器。
什么是 batch state 和 dropout key?
Batch State 用于存储批处理归一化统计数据,而 Dropout Key 用于生成 Dropout 层的随机掩码。
完整的代码在我的 github 上⬇️,欢迎 fork 和 star。
标签:loss,Jax,fit,batch,state,train,上工,fn,函数 From: https://blog.51cto.com/u_16989134/11883344