首页 > 其他分享 >3、ModelCheckPoint

3、ModelCheckPoint

时间:2022-10-24 21:46:43浏览次数:53  
标签:loss ModelCheckPoint 缓存 val 模型 保存 save

1、导包

1 from tensorflow.keras.callbacks import ModelCheckpoint

2、介绍

  在训练机器学习模型时,经常需要缓存模型。

  ModelCheckpoint是Pytorch Lightning中的一个Callback,它就是用于模型缓存的。

  它会监视某个指标,每次指标达到最好的时候,它就缓存当前模型。

  在每个epoch结束作为回调函数,保存模型。

3、参数介绍

3.1、monitor='val_loss', 我们想要监视的指标 ,val_acc或val_loss。

3.2、dirpath='my/path/', 模型缓存目录

3.3、verbose: 详细信息模式,0 或者1。 0为不打印输出信息,1为打印

3.4、save_best_only: True,将只保存在验证集上性能最好的模型mode: {auto, min, max} 的其中之一。是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 

对于val_acc,模式就会是max;而对于val_loss,模式就需要是min。在auto模式中,方式会自动从被监测的数据的名字中判断出来。

3.5、save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。

3.6、period: 每个检查点之间的间隔(训练轮数)。

标签:loss,ModelCheckPoint,缓存,val,模型,保存,save
From: https://www.cnblogs.com/xiaoliang-333/p/16823083.html

相关文章