大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧
概述
模型训练过程中,可能会遇到故障。重新启动训练,各种资源的开销是巨大的。为此MindSpore提供了故障恢复的方案,即周期性保存模型参数,使得模型在故障发生处快速恢复并继续训练。 MindSpore以step或epoch为周期保存模型参数。模型参数保存在CheckPoint(简称ckpt)文件中。模型训练期间,发生故障,载入最新保存的模型参数,恢复在此处的状态,继续训练。
数据和模型准备
为了提供完整的体验,这里使用MNIST数据集和LeNet5网络模拟故障恢复的过程,如已准备好,可直接跳过本章节。
数据准备
下载MNIST数据集,并解压数据集到项目目录。
from download import download
url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
"notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)
模型定义
import os
import mindspore
from mindspore.common.initializer import Normal
from mindspore.dataset import MnistDataset, vision
from mindspore import nn
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, Callback
import mindspore.dataset.transforms as transforms
mindspore.set_context(mode=mindspore.GRAPH_MODE)
# 创建训练数据集
def create_dataset(data_path, batch_size=32):
train_dataset = MnistDataset(data_path, shuffle=False)
image_transfroms = [
vision.Rescale(1.0 / 255.0, 0),
vision.Resize(size=(32, 32)),
vision.HWC2CHW()
]
train_dataset = train_dataset.map(image_transfroms, input_columns='image')
train_dataset = train_dataset.map(transforms.TypeCast(mindspore.int32), input_columns='label')
train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
return train_dataset
# 加载训练数据集
data_path = "MNIST_Data/train"
train_dataset = create_dataset(data_path)
# 模拟训练过程中发生故障
class myCallback(Callback):
def __init__(self, break_epoch_num=6):
super(myCallback, self).__init__()
self.epoch_num = 0
self.break_epoch_num = break_epoch_num
def on_train_epoch_end(self, run_context):
self.epoch_num += 1
if self.epoch_num == self.break_epoch_num:
raise Exception("Some errors happen.")
class LeNet5(nn.Cell):
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode="valid")
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode="valid")
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
net = LeNet5() # 模型初始化
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") # 损失函数
optim = nn.Momentum(net.trainable_params(), 0.01, 0.9) # 优化器
model = Model(net, loss_fn=loss, optimizer=optim) # Model封装
周期性保存CheckPoint文件
配置CheckpointConfig
mindspore.train.CheckpointConfig 中可根据迭代的次数进行配置,配置迭代策略的参数如下:
-
save_checkpoint_steps :表示每隔多少个step保存一个CheckPoint文件,默认值为1。
-
keep_checkpoint_max :表示最多保存多少个CheckPoint文件,默认值为5。
在迭代策略脚本正常结束的情况下,会默认保存最后一个step的CheckPoint文件。
模型训练过程中,使用 Model.train 里面的 callbacks 参数传入保存模型的对象 ModelCheckpoint (与 mindspore.train.CheckpointConfig 配合使用),可以周期性地保存模型参数,生成CheckPoint文件。
用户自定义保存数据
CheckpointConfig 的参数 append_info 可以在CheckPoint文件中保存用户自定义信息。append_info 支持传入 epoch_num 、 step_num 和字典类型数据。epoch_num 和 step_num 可以在CheckPoint文件中保存训练过程中的epoch数和step数。 字典类型数据的 key 必须是string类型,value 必须是int、float、bool、string、Parameter或Tensor类型。
# 用户自定义保存的数据
append_info = ["epoch_num", "step_num", {"lr": 0.01, "momentum": 0.9}]
# 数据下沉模式下,默认保存最后一个step的CheckPoint文件
config_ck = CheckpointConfig(append_info=append_info)
# 保存的CheckPoint文件前缀是"lenet",文件保存在"./lenet"路径下
ckpoint_cb = ModelCheckpoint(prefix='lenet', directory='./lenet', config=config_ck)
# 模拟程序故障,默认是在第6个epoch结束后故障
my_callback = myCallback()
# 数据下沉模式下,使用Model.train进行10个epoch的训练
model.train(10, train_dataset, callbacks=[ckpoint_cb, my_callback], dataset_sink_mode=True)
自定义脚本找到最新的CheckPoint文件
程序在第6个epoch结束后发生故障。故障发生后,./lenet 目录下保存了最新生成的5个epoch的CheckPoint文件。
└── lenet
├── lenet-graph.meta # 编译后的计算图
├── lenet-2_1875.ckpt # CheckPoint文件后缀名为'.ckpt'
├── lenet-3_1875.ckpt # 文件的命名方式表示保存参数所在的epoch和step数,这里为第3个epoch的第1875个step的模型参数
├── lenet-4_1875.ckpt
├── lenet-5_1875.ckpt
└── lenet-6_1875.ckpt
用户可以使用自定义脚本找到最新保存的CheckPoint文件。
ckpt_path = "./lenet"
filenames = os.listdir(ckpt_path)
# 筛选所有的CheckPoint文件名
ckptnames = [ckpt for ckpt in filenames if ckpt.endswith(".ckpt")]
# 按照创建顺序从旧到新对CheckPoint文件名进行排序
ckptnames.sort(key=lambda ckpt: os.path.getctime(ckpt_path + "/" + ckpt))
# 获取最新的CheckPoint文件路径
ckpt_file = ckpt_path + "/" + ckptnames[-1]
恢复训练
加载CheckPoint文件
使用 load_checkpoint 和 load_param_into_net 方法加载最新保存的CheckPoint文件。
-
load_checkpoint 方法会把CheckPoint文件中的网络参数加载到字典param_dict中。
-
load_param_into_net 方法会把字典param_dict中的参数加载到网络或者优化器中,加载后网络中的参数就是CheckPoint文件中保存的。
# 将模型参数加载到param_dict中,这里加载的是训练过程中保存的模型参数和用户自定义保存的数据
param_dict = mindspore.load_checkpoint(ckpt_file)
net = LeNet5()
# 将参数加载模型中
mindspore.load_param_into_net(net, param_dict)
获取用户自定义数据
用户可以从CheckPoint文件中获取训练时的epoch数和自定义保存的数据。注意,此时获取的数据是Parameter类型。
epoch_num = int(param_dict["epoch_num"].asnumpy())
step_num = int(param_dict["step_num"].asnumpy())
lr = float(param_dict["lr"].asnumpy())
momentum = float(param_dict["momentum"].asnumpy())
设置继续训练的epoch
向 Model.train 的 initial_epoch 参数传入获取的epoch数,网络即可从该epoch继续训练。此时,Model.train 的 epoch 参数表示训练的最后一个epoch数。
model.train(10, train_dataset, callbacks=ckpoint_cb, initial_epoch=epoch_num, dataset_sink_mode=True)
训练结束
训练结束, ./lenet 目录下新生成4个CheckPoint文件。根据CheckPoint文件名可以看出,在故障发生后,模型重新在第7个epoch进行训练,并在第10个epoch结束。故障恢复成功。
└── lenet
├── lenet-graph.meta
├── lenet-2_1875.ckpt
├── lenet-3_1875.ckpt
├── lenet-4_1875.ckpt
├── lenet-5_1875.ckpt
├── lenet-6_1875.ckpt
├── lenet-1-7_1875.ckpt
├── lenet-1-8_1875.ckpt
├── lenet-1-9_1875.ckpt
├── lenet-1-10_1875.ckpt
└── lenet-1-graph.meta
标签:进阶,--,self,epoch,CheckPoint,ckpt,num,lenet,MindSpore
From: https://blog.csdn.net/weixin_42553583/article/details/143034883