需求描述:
当我们训练模型的时候,我们要训练很多训练步数,我们想要保存训练到一定阶段的checkpoint模型参数,并把这些checkpoint模型保存到一个指定的文件夹下。在文件夹下我们最多保存keep_checkpoint_max
个checkpoint模型的文件。保存到output
文件夹下。每save_checkpoint_steps
步去保存一次。
如果保存的checkpoint模型已经达到最大数量,那么就把最早保存的文件删除,然后在保存现在的checkpoint模型的文件。
文件名是后面是保存的第几次。
代码梳理
首先我们定义一个checkpoint模型保存的函数:
def save_checkpoint(step, epoch, model, optimizer, params):
if dist.get_rank() == 0:
state = {
"step": step,
"epoch": epoch,
"model": model.state_dict(),
"optimizer": optimizer.state_dict()
}
utils.save(state, params.output, params.keep_checkpoint_max)
我们定义了一个保存模型的函数,需要传入的参数为训练步数(step)、数据集训练次数(epoch)、模型(model)、优化器(optimizer)、参数集(params)。
我们定义了需要保存的信息字典:
state = {
"step": step,
"epoch": epoch,
"model": model.state_dict(),
"optimizer": optimizer.state_dict()
}
字典里保存了训练步数、epoch数、模型参数和优化器参数。
然后传给我们自己的一个工具类的保存函数utils.save()
。
接下来我们看一下工具包保存checkpoint模型的实现。
import os
import glob
import torch
def oldest_checkpoint(path):
names = glob.glob(os.path.join(path, "*.pt"))
if not names:
return None
oldest_counter = 10000000
checkpoint_name = names[0]
for name in names:
counter = name.rstrip(".pt").split("-")[-1]
if not counter.isdigit():
continue
else:
counter = int(counter)
if counter < oldest_counter:
checkpoint_name = name
oldest_counter = counter
return checkpoint_name
def latest_checkpoint(path):
names = glob.glob(os.path.join(path, "*.pt"))
if not names:
return None
latest_counter = 0
checkpoint_name = names[0]
for name in names:
counter = name.rstrip(".pt").split("-")[-1]
if not counter.isdigit():
continue
else:
counter = int(counter)
if counter > latest_counter:
checkpoint_name = name
latest_counter = counter
return checkpoint_name
def save(state, path, max_to_keep=None):
checkpoints = glob.glob(os.path.join(path, "*.pt"))
if max_to_keep and len(checkpoints) >= max_to_keep:
checkpoint = oldest_checkpoint(path)
os.remove(checkpoint)
if not checkpoints:
counter = 1
else:
checkpoint = latest_checkpoint(path)
counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1
checkpoint = os.path.join(path, "model-%d.pt" % counter)
print("Saving checkpoint: %s" % checkpoint)
torch.save(state, checkpoint)
我们首先来看一下save()
函数的实现。
def save(state, path, max_to_keep=None):
checkpoints = glob.glob(os.path.join(path, "*.pt"))
if max_to_keep and len(checkpoints) >= max_to_keep:
checkpoint = oldest_checkpoint(path)
os.remove(checkpoint)
if not checkpoints:
counter = 1
else:
checkpoint = latest_checkpoint(path)
counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1
checkpoint = os.path.join(path, "model-%d.pt" % counter)
print("Saving checkpoint: %s" % checkpoint)
torch.save(state, checkpoint)
刚刚我们从save_checkpoint
函数中传入到save
函数三个参数,我们一个个看一下。
- state:需要保存的信息,类型是字典类型的数据
- path:我们在命令行输入的output路径,用来保存模型的路径
- max_to_keep:keep_checkpoint_max,这个参数的作用就是在文件夹下我们最多保存
keep_checkpoint_max
个checkpoint模型的文件
save
函数的流程:
第一,我们先查看一下这个文件夹下有多少.pt
结尾的文件,以列表的方式保存到checkpoints
变量中。
checkpoints = glob.glob(os.path.join(path, "*.pt"))
第二,如果传入了max_to_keep
参数,并且文件夹中目前的checkpoint模型的文件大于或者等于最大达到保存的文件数时,我们寻找文件夹下最先保存的checkpoint模型的文件,然后删除这个文件。如果没有超过,这段不执行。
if max_to_keep and len(checkpoints) >= max_to_keep:
checkpoint = oldest_checkpoint(path)
os.remove(checkpoint)
第三,如果最开始文件夹下没有任何checkpoint模型的文件,文件计数(counter)加一。如果有文件的话,找到最新保存的checkpoint模型文件的文件名,提取文件名中的数字,然后加一,作为当前保存的文件的数字尾缀。
if not checkpoints:
counter = 1
else:
checkpoint = latest_checkpoint(path)
counter = int(checkpoint.rstrip(".pt").split("-")[-1]) + 1
第四,拼接路径和文件的名字,传给checkpoint
变量。
checkpoint = os.path.join(path, "model-%d.pt" % counter)
print("Saving checkpoint: %s" % checkpoint)
第五,使用pytorch的torch.save()
函数进行模型和训练相关参数的保存。
torch.save(state, checkpoint)
上面save
函数调用了两个函数:
- oldest_checkpoint:返回最早保存的checkpoint模型文件的文件名
- latest_checkpoint:返回最新保存的checkpoint模型文件的文件名
自己仔细看代码实现逻辑十分好懂,自己看一下吧。
到这里我们已经知道模型如何保存的实现了,上面需求描述的也大都实现了,但是缺一个训练多少步进行调用这个函数,在训练的过程中,如下代码所示:
if step % params.save_checkpoint_steps == 0:
save_checkpoint(step, epoch, model, optimizer, params)
代码的意思是,当训练步数对多少步数保存一次的参数(save_checkpoint_steps
)进行取余,如果为零,表示save_checkpoint_steps
步训练到了,需要保存了,然后执行我们实现的save_checkpoint
函数对模型的checkpoint进行保存。
代码中用到的参数来源:
- 一部分是执行命令的时候用户传入的
- 一部分是代码设置的默认参数,这些参数也可以在命令行进行指定
标签:NLP,模型,counter,保存,keep,checkpoint,path,save From: https://www.cnblogs.com/zhangxuegold/p/17536111.html总结:
- 我们这样是为了需要像文中的需求进行具体的代码解决方法,这些代码实现是正确的,只需要用户在自己的项目中把这些代码设计到合适的位置。我仅在文中进行了保存checkpoint文件的思路梳理。