首页 > 其他分享 >NLP应用 | 保存checkpoint模型

NLP应用 | 保存checkpoint模型

时间:2023-07-07 21:33:30浏览次数:36  
标签:NLP 模型 counter 保存 keep checkpoint path save

需求描述:

当我们训练模型的时候,我们要训练很多训练步数,我们想要保存训练到一定阶段的checkpoint模型参数,并把这些checkpoint模型保存到一个指定的文件夹下。在文件夹下我们最多保存keep_checkpoint_max个checkpoint模型的文件。保存到output文件夹下。每save_checkpoint_steps步去保存一次。

如果保存的checkpoint模型已经达到最大数量,那么就把最早保存的文件删除,然后在保存现在的checkpoint模型的文件。

文件名是后面是保存的第几次。

代码梳理

首先我们定义一个checkpoint模型保存的函数:

def save_checkpoint(step, epoch, model, optimizer, params):
    if dist.get_rank() == 0:
        state = {
            "step": step,
            "epoch": epoch,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict()
        }
        utils.save(state, params.output, params.keep_checkpoint_max)

我们定义了一个保存模型的函数,需要传入的参数为训练步数(step)、数据集训练次数(epoch)、模型(model)、优化器(optimizer)、参数集(params)。

我们定义了需要保存的信息字典:

state = {
    "step": step,
    "epoch": epoch,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict()
}

字典里保存了训练步数、epoch数、模型参数和优化器参数。

然后传给我们自己的一个工具类的保存函数utils.save()

接下来我们看一下工具包保存checkpoint模型的实现。

import os
import glob
import torch


def oldest_checkpoint(path):
    names = glob.glob(os.path.join(path, "*.pt"))

    if not names:
        return None

    oldest_counter = 10000000
    checkpoint_name = names[0]

    for name in names:
        counter = name.rstrip(".pt").split("-")[-1]

        if not counter.isdigit():
            continue
        else:
            counter = int(counter)

        if counter < oldest_counter:
            checkpoint_name = name
            oldest_counter = counter

    return checkpoint_name


def latest_checkpoint(path):
    names = glob.glob(os.path.join(path, "*.pt"))

    if not names:
        return None

    latest_counter = 0
    checkpoint_name = names[0]

    for name in names:
        counter = name.rstrip(".pt").split("-")[-1]

        if not counter.isdigit():
            continue
        else:
            counter = int(counter)

        if counter > latest_counter:
            checkpoint_name = name
            latest_counter = counter

    return checkpoint_name


def save(state, path, max_to_keep=None):
    checkpoints = glob.glob(os.path.join(path, "*.pt"))

    if max_to_keep and len(checkpoints) >= max_to_keep:
        checkpoint = oldest_checkpoint(path)
        os.remove(checkpoint)

    if not checkpoints:
        counter = 1
    else:
        checkpoint = latest_checkpoint(path)
        counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

    checkpoint = os.path.join(path, "model-%d.pt" % counter)
    print("Saving checkpoint: %s" % checkpoint)
    torch.save(state, checkpoint)

我们首先来看一下save()函数的实现。

def save(state, path, max_to_keep=None):
    checkpoints = glob.glob(os.path.join(path, "*.pt"))

    if max_to_keep and len(checkpoints) >= max_to_keep:
        checkpoint = oldest_checkpoint(path)
        os.remove(checkpoint)

    if not checkpoints:
        counter = 1
    else:
        checkpoint = latest_checkpoint(path)
        counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

    checkpoint = os.path.join(path, "model-%d.pt" % counter)
    print("Saving checkpoint: %s" % checkpoint)
    torch.save(state, checkpoint)

刚刚我们从save_checkpoint函数中传入到save函数三个参数,我们一个个看一下。

  • state:需要保存的信息,类型是字典类型的数据
  • path:我们在命令行输入的output路径,用来保存模型的路径
  • max_to_keep:keep_checkpoint_max,这个参数的作用就是在文件夹下我们最多保存keep_checkpoint_max个checkpoint模型的文件

save函数的流程:

第一,我们先查看一下这个文件夹下有多少.pt结尾的文件,以列表的方式保存到checkpoints变量中。

checkpoints = glob.glob(os.path.join(path, "*.pt"))

第二,如果传入了max_to_keep参数,并且文件夹中目前的checkpoint模型的文件大于或者等于最大达到保存的文件数时,我们寻找文件夹下最先保存的checkpoint模型的文件,然后删除这个文件。如果没有超过,这段不执行。

if max_to_keep and len(checkpoints) >= max_to_keep:
    checkpoint = oldest_checkpoint(path)
    os.remove(checkpoint)

第三,如果最开始文件夹下没有任何checkpoint模型的文件,文件计数(counter)加一。如果有文件的话,找到最新保存的checkpoint模型文件的文件名,提取文件名中的数字,然后加一,作为当前保存的文件的数字尾缀。

if not checkpoints:
    counter = 1
else:
    checkpoint = latest_checkpoint(path)
    counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1

第四,拼接路径和文件的名字,传给checkpoint变量。

checkpoint = os.path.join(path, "model-%d.pt" % counter)
print("Saving checkpoint: %s" % checkpoint)

第五,使用pytorch的torch.save()函数进行模型和训练相关参数的保存。

torch.save(state, checkpoint)

上面save函数调用了两个函数:

  • oldest_checkpoint:返回最早保存的checkpoint模型文件的文件名
  • latest_checkpoint:返回最新保存的checkpoint模型文件的文件名

自己仔细看代码实现逻辑十分好懂,自己看一下吧。

到这里我们已经知道模型如何保存的实现了,上面需求描述的也大都实现了,但是缺一个训练多少步进行调用这个函数,在训练的过程中,如下代码所示:

if step % params.save_checkpoint_steps == 0:
    save_checkpoint(step, epoch, model, optimizer, params)

代码的意思是,当训练步数对多少步数保存一次的参数(save_checkpoint_steps)进行取余,如果为零,表示save_checkpoint_steps步训练到了,需要保存了,然后执行我们实现的save_checkpoint函数对模型的checkpoint进行保存。

代码中用到的参数来源:

  • 一部分是执行命令的时候用户传入的
  • 一部分是代码设置的默认参数,这些参数也可以在命令行进行指定

总结:

  • 我们这样是为了需要像文中的需求进行具体的代码解决方法,这些代码实现是正确的,只需要用户在自己的项目中把这些代码设计到合适的位置。我仅在文中进行了保存checkpoint文件的思路梳理。

标签:NLP,模型,counter,保存,keep,checkpoint,path,save
From: https://www.cnblogs.com/zhangxuegold/p/17536111.html

相关文章

  • 可视化模型地址
    https://github.com/zhangti0708/bigdata-exampleshttps://github.com/iGaoWei/BigDataView......
  • AI重塑千行百业,华为云发布盘古大模型3.0和昇腾AI云服务
    【中国,东莞,2023年7月7日】华为开发者大会2023(Cloud)7月7日在中国东莞正式揭开帷幕,并同时在全球10余个国家、中国30多个城市设有分会场,邀请全球开发者共聚一堂,就AI浪潮之下的产业新机会和技术新实践开展交流分享。在7日下午举行的大会主题演讲中,华为常务董事、华为云CEO张平安重磅......
  • 我用numpy实现了GPT-2,GPT-2源码,GPT-2模型加速推理,并且可以在树莓派上运行,读了不少hung
     之前分别用numpy实现了mlp,cnn,lstm和bert模型,这周顺带搞一下GPT-2,纯numpy实现,最重要的是可在树莓派上或其他不能安装pytorch的板子上运行,生成数据gpt-2的mask-multi-headed-self-attention我现在才彻底的明白它是真的牛逼,比bert的multi-headed-self-attention牛的不是一点半点,......
  • 道德与社会问题简报 #4: 文生图模型中的偏见
    简而言之:我们需要更好的方法来评估文生图模型中的偏见介绍文本到图像(TTI)生成现在非常流行,成千上万的TTI模型被上传到HuggingFaceHub。每种模态都可能受到不同来源的偏见影响,这就引出了一个问题:我们如何发现这些模型中的偏见?在当前的博客文章中,我们分享了我们对TT......
  • 大模型复现实践记录-在linux环境4090GPU(24G)
    chatglm-6btiger-7b......
  • 【视频】决策树模型原理和R语言预测心脏病实例
    全文链接:https://tecdat.cn/?p=33128原文出处:拓端数据部落公众号分析师:YudongWan决策树模型简介决策树模型是一种非参数的有监督学习方法,它能够从一系列有特征和标签的数据中总结出决策规则,并用树状图的结构来呈现这些规则,以解决分类和回归问题。与传统的线性回归模型不同,决......
  • ARIMA模型,ARIMAX模型预测冰淇淋消费时间序列数据|附代码数据
    全文下载链接:http://tecdat.cn/?p=22511最近我们被客户要求撰写关于ARIMAX的研究报告,包括一些图形和统计输出。标准的ARIMA(移动平均自回归模型)模型允许只根据预测变量的过去值进行预测。该模型假定一个变量的未来的值线性地取决于其过去的值,以及过去(随机)影响的值。ARIMAX模型......
  • 基础大模型能像人类一样标注数据吗?
    自从ChatGPT出现以来,我们见证了大语言模型(LLM)领域前所未有的发展,尤其是对话类模型,经过微调以后可以根据给出的提示语(prompt)来完成相关要求和命令。然而,直到如今我们也无法对比这些大模型的性能,因为缺乏一个统一的基准,难以严谨地去测试它们各自的性能。评测我们发给它们......
  • Qt+opencv dnn模块调用tensorflow模型
    参考网址(1条消息)Qt+opencvdnn模块调用tensorflow模型_vsqt调用tensorflow_街道口扛把子的博客-CSDN博客代码地址:GitHub-Whu-wxy/Simple_Qt_opencv_dnn:UsingdeeplearningmodelwithopencvinQt修改运行后的代码如下:#include<QCoreApplication>#include<opencv2\o......
  • python基础day39 生产者消费者模型和线程相关
    如何查看进程的id号进程都有几个属性:进程名、进程id号(pid--->processid)每个进程都有一个唯一的id号,通过这个id号就能找到这个进程importosimporttimedeftask():print("task中的子进程号:",os.getpid())print("主进程中的进程号:",os.getppid())#parent......