首页 > 其他分享 >keras:fit_generator的训练过程

keras:fit_generator的训练过程

时间:2022-10-27 13:37:35浏览次数:43  
标签:generator fit keras flow epoch train steps


以​​keras分类猫狗数据(中)使用CNN分类模型​​为例,其中的部分代码如下:

#……
train_pic_gen=ImageDataGenerator(rescale=1./255,rotation_range=20,width_shift_range=0.2,
height_shift_range=0.2,shear_range=0.2,zoom_range=0.5,horizontal_flip=True,
fill_mode='nearest')
#……
train_flow=train_pic_gen.flow_from_directory(train_dir,(128,128),batch_size=32,class_mode='binary')
#……
model.fit_generator(
morph.train_flow,steps_per_epoch=100,epochs=50,verbose=1,validation_data=morph.test_flow,validation_steps=100,
callbacks=[TensorBoard(log_dir='./logs/1')]
)
#……
  • 执行​​fit_generator​​​时,由​​train_flow​​ 数据流返回32(train_flow的batch_size的参数)张经过随机变形的样本,作为一个batch训练模型,
  • 重复这一过程100(fit_generator的steps_per_epoch参数)次,一个epoch结束。一个epoch所用样本batch_size乘以steps_per_epoch。
  • 当epoch=50(fit_generator的epochs参数)时,模型训练结束。

此外,根据官方文档:

  • fit_generator的steps_per_epoch的建议值为​​样本总量除以train_flow的batch_size​​。
  • fit_generator的steps_per_epoch,如果未指定(​​None​​),则fit_generator的steps_per_epoch等于train_flow的batch_size。

源码参考:​​/Lib/site-packages/keras/engine/training.py​​​。
参考上文:
​​​keras:ImageDataGenerator的flow方法​​​
​​​keras:ImageDataGenerator的flow_from_directory方法​


标签:generator,fit,keras,flow,epoch,train,steps
From: https://blog.51cto.com/u_15847885/5800882

相关文章

  • keras离线官方文档
    keras中文文档:​​​https://keras.io/zh/​​​(官方)​​​http://keras-cn.readthedocs.io/en/latest/​​由于官方文档(更新似乎快点儿)经常访问不了,所以下载查看。步骤1......
  • keras中application模型可视化
    问题描述keras提供了模型的可视化,如下:importkerasfromkerasimportmodels,layers,Modelfromkeras.applicationsimportVGG16conv_base=VGG16(weights='imagenet',inc......
  • keras中的keras.utils.to_categorical方法
    ​​to_categorical(y,num_classes=None,dtype='float32')​​将整型标签转为onehot。y为​​int​​数组,num_classes为标签类别总数,大于max(y)(标签从0开始的)。返回:如果nu......
  • keras SegNet实现
    代码位置​​https://github.com/lsh1994/keras-segmentation​​​模型结构我这里用到了vgg16微调作为编码器,读者可以参照着自定义层对称的编解码结构。训练结果......
  • keras SegNet使用池化索引(pooling indices)
    keras中不能直接使用池化索引。最近学习到SegNet(网上许多错的,没有用池化索引),其中下采样上采样用到此部分。此处用到自定义层。完整测试代码如下。"""@author:LiShiHang@so......
  • keras FCN实现(2)
    FCN-8/FCN-16Add了底层特征。FCN-8的实现,承接​​上篇​​。代码位置:​​​https://github.com/lsh1994/keras-segmentation​​结构:训练曲线:可视化结果:......
  • anaconda 下安装tensorflow & keras
    首先,同胞们要记住,你要做什么?该怎么做?你的目标是什么?千万不要因为中间遇到的连带问题,而忘记了你要做什么?一下开始介绍:????下载:官网速度很慢,容易断线:https://www.......
  • Keras搭建CNN进行人脸识别系列(四)--为模型训练准备人脸数据
          机器学习最本质的地方就是基于海量数据统计的学习,说白了,机器学习其实就是在模拟人类儿童的学习行为。举一个简单的例子,成年人并没有主动教孩子学习语言,但随着......
  • Keras搭建CNN进行人脸识别系列(三)--利用haar级联检测器识别出人脸
    人脸识别原理        从实时视频流中识别出人脸区域,从原理上看,其依然属于机器学习的领域之一,本质上与谷歌利用深度学习识别出猫没有什么区别。程序通过大量的......
  • 使用KeyPairGenerator生成公私钥对(oppo平台需要)
    下面是oppo平台的生成公钥的要求公钥需要使用RSA算法生成,1024位,生成后使用Base64进行编码,编码后的长度是216位Base64使用了Apache的commons-codec工具包,这个......