首页 > 其他分享 >昇思MindSpore进阶教程--故障恢复

昇思MindSpore进阶教程--故障恢复

时间:2024-10-18 10:18:56浏览次数:8  
标签:进阶 -- self epoch CheckPoint ckpt num lenet MindSpore

大家好,我是刘明,明志科技创始人,华为昇思MindSpore布道师。
技术上主攻前端开发、鸿蒙开发和AI算法研究。
努力为大家带来持续的技术分享,如果你也喜欢我的文章,就点个关注吧

概述

模型训练过程中,可能会遇到故障。重新启动训练,各种资源的开销是巨大的。为此MindSpore提供了故障恢复的方案,即周期性保存模型参数,使得模型在故障发生处快速恢复并继续训练。 MindSpore以step或epoch为周期保存模型参数。模型参数保存在CheckPoint(简称ckpt)文件中。模型训练期间,发生故障,载入最新保存的模型参数,恢复在此处的状态,继续训练。

数据和模型准备

为了提供完整的体验,这里使用MNIST数据集和LeNet5网络模拟故障恢复的过程,如已准备好,可直接跳过本章节。

数据准备

下载MNIST数据集,并解压数据集到项目目录。

from download import download

url = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
      "notebook/datasets/MNIST_Data.zip"
path = download(url, "./", kind="zip", replace=True)

模型定义

import os

import mindspore
from mindspore.common.initializer import Normal
from mindspore.dataset import MnistDataset, vision
from mindspore import nn
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, Callback
import mindspore.dataset.transforms as transforms

mindspore.set_context(mode=mindspore.GRAPH_MODE)


# 创建训练数据集
def create_dataset(data_path, batch_size=32):
    train_dataset = MnistDataset(data_path, shuffle=False)
    image_transfroms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Resize(size=(32, 32)),
        vision.HWC2CHW()
    ]
    train_dataset = train_dataset.map(image_transfroms, input_columns='image')
    train_dataset = train_dataset.map(transforms.TypeCast(mindspore.int32), input_columns='label')
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
    return train_dataset


# 加载训练数据集
data_path = "MNIST_Data/train"
train_dataset = create_dataset(data_path)

# 模拟训练过程中发生故障
class myCallback(Callback):
    def __init__(self, break_epoch_num=6):
        super(myCallback, self).__init__()
        self.epoch_num = 0
        self.break_epoch_num = break_epoch_num

    def on_train_epoch_end(self, run_context):
        self.epoch_num += 1
        if self.epoch_num == self.break_epoch_num:
            raise Exception("Some errors happen.")


class LeNet5(nn.Cell):
    def __init__(self, num_class=10, num_channel=1):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode="valid")
        self.conv2 = nn.Conv2d(6, 16, 5, pad_mode="valid")
        self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
        self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
        self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
        self.relu = nn.ReLU()
        self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten()

    def construct(self, x):
        x = self.max_pool2d(self.relu(self.conv1(x)))
        x = self.max_pool2d(self.relu(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = LeNet5()  # 模型初始化
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")  # 损失函数
optim = nn.Momentum(net.trainable_params(), 0.01, 0.9)  # 优化器
model = Model(net, loss_fn=loss, optimizer=optim)  # Model封装

周期性保存CheckPoint文件

配置CheckpointConfig

mindspore.train.CheckpointConfig 中可根据迭代的次数进行配置,配置迭代策略的参数如下:

  • save_checkpoint_steps :表示每隔多少个step保存一个CheckPoint文件,默认值为1。

  • keep_checkpoint_max :表示最多保存多少个CheckPoint文件,默认值为5。

在迭代策略脚本正常结束的情况下,会默认保存最后一个step的CheckPoint文件。

模型训练过程中,使用 Model.train 里面的 callbacks 参数传入保存模型的对象 ModelCheckpoint (与 mindspore.train.CheckpointConfig 配合使用),可以周期性地保存模型参数,生成CheckPoint文件。

用户自定义保存数据

CheckpointConfig 的参数 append_info 可以在CheckPoint文件中保存用户自定义信息。append_info 支持传入 epoch_num 、 step_num 和字典类型数据。epoch_num 和 step_num 可以在CheckPoint文件中保存训练过程中的epoch数和step数。 字典类型数据的 key 必须是string类型,value 必须是int、float、bool、string、Parameter或Tensor类型。

# 用户自定义保存的数据
append_info = ["epoch_num", "step_num", {"lr": 0.01, "momentum": 0.9}]
# 数据下沉模式下,默认保存最后一个step的CheckPoint文件
config_ck = CheckpointConfig(append_info=append_info)
# 保存的CheckPoint文件前缀是"lenet",文件保存在"./lenet"路径下
ckpoint_cb = ModelCheckpoint(prefix='lenet', directory='./lenet', config=config_ck)

# 模拟程序故障,默认是在第6个epoch结束后故障
my_callback = myCallback()

# 数据下沉模式下,使用Model.train进行10个epoch的训练
model.train(10, train_dataset, callbacks=[ckpoint_cb, my_callback], dataset_sink_mode=True)

自定义脚本找到最新的CheckPoint文件

程序在第6个epoch结束后发生故障。故障发生后,./lenet 目录下保存了最新生成的5个epoch的CheckPoint文件。

└── lenet
     ├── lenet-graph.meta  # 编译后的计算图
     ├── lenet-2_1875.ckpt  # CheckPoint文件后缀名为'.ckpt'
     ├── lenet-3_1875.ckpt  # 文件的命名方式表示保存参数所在的epoch和step数,这里为第3个epoch的第1875个step的模型参数
     ├── lenet-4_1875.ckpt
     ├── lenet-5_1875.ckpt
     └── lenet-6_1875.ckpt

用户可以使用自定义脚本找到最新保存的CheckPoint文件。

ckpt_path = "./lenet"
filenames = os.listdir(ckpt_path)
# 筛选所有的CheckPoint文件名
ckptnames = [ckpt for ckpt in filenames if ckpt.endswith(".ckpt")]
# 按照创建顺序从旧到新对CheckPoint文件名进行排序
ckptnames.sort(key=lambda ckpt: os.path.getctime(ckpt_path + "/" + ckpt))
# 获取最新的CheckPoint文件路径
ckpt_file = ckpt_path + "/" + ckptnames[-1]

恢复训练

加载CheckPoint文件

使用 load_checkpoint 和 load_param_into_net 方法加载最新保存的CheckPoint文件。

  • load_checkpoint 方法会把CheckPoint文件中的网络参数加载到字典param_dict中。

  • load_param_into_net 方法会把字典param_dict中的参数加载到网络或者优化器中,加载后网络中的参数就是CheckPoint文件中保存的。

# 将模型参数加载到param_dict中,这里加载的是训练过程中保存的模型参数和用户自定义保存的数据
param_dict = mindspore.load_checkpoint(ckpt_file)
net = LeNet5()
# 将参数加载模型中
mindspore.load_param_into_net(net, param_dict)

获取用户自定义数据

用户可以从CheckPoint文件中获取训练时的epoch数和自定义保存的数据。注意,此时获取的数据是Parameter类型。

epoch_num = int(param_dict["epoch_num"].asnumpy())
step_num = int(param_dict["step_num"].asnumpy())
lr = float(param_dict["lr"].asnumpy())
momentum = float(param_dict["momentum"].asnumpy())

设置继续训练的epoch

向 Model.train 的 initial_epoch 参数传入获取的epoch数,网络即可从该epoch继续训练。此时,Model.train 的 epoch 参数表示训练的最后一个epoch数。

model.train(10, train_dataset, callbacks=ckpoint_cb, initial_epoch=epoch_num, dataset_sink_mode=True)

训练结束

训练结束, ./lenet 目录下新生成4个CheckPoint文件。根据CheckPoint文件名可以看出,在故障发生后,模型重新在第7个epoch进行训练,并在第10个epoch结束。故障恢复成功。

└── lenet
     ├── lenet-graph.meta
     ├── lenet-2_1875.ckpt
     ├── lenet-3_1875.ckpt
     ├── lenet-4_1875.ckpt
     ├── lenet-5_1875.ckpt
     ├── lenet-6_1875.ckpt
     ├── lenet-1-7_1875.ckpt
     ├── lenet-1-8_1875.ckpt
     ├── lenet-1-9_1875.ckpt
     ├── lenet-1-10_1875.ckpt
     └── lenet-1-graph.meta

标签:进阶,--,self,epoch,CheckPoint,ckpt,num,lenet,MindSpore
From: https://blog.csdn.net/weixin_42553583/article/details/143034883

相关文章

  • π TIKI派::TikTok公会邀约系统:你的主播管理利器!
    嘿,大家好,今天我要跟你们分享一个超级实用的工具——πTIKI派TikTok公会邀约系统!这个系统不仅可以让老板们一键分派主播,还能让员工随时随地通过手机轻松管理并认领主播,极大提高了工作效率,真是太方便了!......
  • 软考中级(软件设计师)必备知识解读——第二章:​程序设计语言
    第二章程序设计语言程序设计语言的基本概念解释器:翻译源程序时不生产独立的目标程序。解释程序和源程序要参与到程序的运行过程中。编译器:翻译时将源程序翻译成独立保存的目标程序。机器上运行的是与源程序等价的目标程序,源程序和编译程序都不再参与目标程序的运行......
  • 软考中级(软件设计师)必备知识解读——第五章:软件工程
    第五章软件工程软件过程1.能力成熟度模型(CMM)CMM将软件过程改进分为以下5个成熟度级别:1)初始级(最低成熟度)软件过程的特点是杂乱无章,有时甚至很混乱,几乎没有明确定义的步骤,项目的成功完全依赖个人的努力和英雄式核心人物的作用。2)可重复级建立了基本的项目管理过程和......
  • docker-certbot-dnspod 使用 Docker 申请、续期免费证书
    项目地址https://github.com/chenlongqiang/docker-certbot-dnspod背景近期免费证书有效期从1年缩短到3个月,避免经常要上云平台手动申请,所以想找个工具可以简单的申请、续期证书。通过了解,发现Certbot工具,但官方没提供Dnspod插件,于是找了Python3的封装并打包成......
  • Linux中文件的读写过程
    文件的读取过程在Linux系统中,读取文件的过程主要由操作系统内核通过文件系统与存储设备的交互来完成。以下是文件读取过程的详细步骤:1.系统调用阶段当用户程序(如cat、less)请求读取文件时,会调用系统调用(如open()或read())来请求访问文件。这些调用会传递文件路径等参数给内......
  • GEE 教程:Landsat TOA数据计算地表温度(LST)
    目录简介函数expression(expression, map)Arguments:Returns: ImagereduceRegion(reducer, geometry, scale, crs, crsTransform, bestEffort, maxPixels, tileScale)Arguments:Returns: Dictionary代码结果简介地表温度(LandSurfaceTemperature,LST)指......
  • 类与对象基础练习_学生登记
    要求定义一个Student类,包含以下要求:实例字段:name(名字),age(年龄),id(学生编号,自动生成,从1000开始)静态字段:studentCount(静态变量,用于统计总学生数),nextId(下一个学生编号)构造器:接受名字和年龄作为参数,在构造器中应进行总学生数的统计,并为学生分配唯一的id实例方法:introduce(),当调用该......
  • 算法与数据结构——桶排序
    桶排序前面的快速排序、归并排序、堆排序等都是属于“基于比较的排序算法”,它们通过比较元素间的大小来实现排序。此类排序算法的时间复杂度无法超越O(nlogn)。下面介绍几种“非比较排序算法”,它们的时间复杂度可以达到线性阶。桶排序(bucketsort)是分治策略的一个典型应用。它通......
  • 智能高效,智慧监管:EasyCVR视频汇聚平台助力煤矿构建一体化视频监控系统
    随着物联网、大数据、云计算等技术的快速发展,智慧化转型已成为煤矿行业提升生产效率、保障安全的重要途径。煤矿生产环境复杂多变,存在高温、低氧、多尘、黑暗等不利因素,给传统的人工巡检和管理方式带来了极大的挑战。EasyCVR视频汇聚平台作为智慧煤矿建设的重要组成部分,凭借其强大......
  • vue,xlsx,xlsx-style,file-saver,生成Excel并导出,cptable报错,合并单元格 样式缺失
    一,安装依赖 二,导入依赖import*asXLSXfrom'xlsx';import*asXLSX_STYLEfrom'xlsx-style'import{saveAs}from'file-saver';三,解决引入xlsx-style./cptable模块找不到问题Thisrelativemodulewasnotfound:*./cptablein./node_modules......