以keras分类猫狗数据(中)使用CNN分类模型为例,其中的部分代码如下:
#……
train_pic_gen=ImageDataGenerator(rescale=1./255,rotation_range=20,width_shift_range=0.2,
height_shift_range=0.2,shear_range=0.2,zoom_range=0.5,horizontal_flip=True,
fill_mode='nearest')
#……
train_flow=train_pic_gen.flow_from_directory(train_dir,(128,128),batch_size=32,class_mode='binary')
#……
model.fit_generator(
morph.train_flow,steps_per_epoch=100,epochs=50,verbose=1,validation_data=morph.test_flow,validation_steps=100,
callbacks=[TensorBoard(log_dir='./logs/1')]
)
#……
- 执行
fit_generator
时,由train_flow
数据流返回32(train_flow的batch_size的参数)张经过随机变形的样本,作为一个batch训练模型, - 重复这一过程100(fit_generator的steps_per_epoch参数)次,一个epoch结束。一个epoch所用样本batch_size乘以steps_per_epoch。
- 当epoch=50(fit_generator的epochs参数)时,模型训练结束。
此外,根据官方文档:
- fit_generator的steps_per_epoch的建议值为
样本总量除以train_flow的batch_size
。 - fit_generator的steps_per_epoch,如果未指定(
None
),则fit_generator的steps_per_epoch等于train_flow的batch_size。
源码参考:/Lib/site-packages/keras/engine/training.py
。
参考上文:
keras:ImageDataGenerator的flow方法
keras:ImageDataGenerator的flow_from_directory方法