首页 > 其他分享 > 180205 Keras 回调函数 Callback 举例

180205 Keras回调函数Callback举例

时间：2023-03-03 10:32:35　浏览次数：40
标签:loss plt Keras 180205 Callback train test model self


  • 调用LambdaCallback
  • 180205 Keras回调函数Callback举例_自定义

  • 调用History
  • 180205 Keras回调函数Callback举例_H2_02

  • 自定义 Callback 类 + 调用 TensorBoard 的程序结果
runfile('F:/180204/NoisyLabelCode/noisy_labels27Code/mnist-mlp.py', wdir='F:/180204/NoisyLabelCode/noisy_labels27Code')
60000 train samples
10000 test samples
On_train_begin
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
H1 (Dense) (None, 500) 392500
_________________________________________________________________
DO1 (Dropout) (None, 500) 0
_________________________________________________________________
H2 (Dense) (None, 300) 150300
_________________________________________________________________
DO2 (Dropout) (None, 300) 0
_________________________________________________________________
output (Dense) (None, 10) 3010
=================================================================
Total params: 545,810
Trainable params: 545,810
Non-trainable params: 0
_________________________________________________________________
None
0 0.44240659469 <class 'int'> <class 'numpy.float64'>
E:\Anaconda\envs\py3\envs\tf-gpu\lib\site-packages\matplotlib\axes\_axes.py:545: UserWarning: No labelled objects found. Use label='...' kwarg on individual plots.
warnings.warn("No labelled objects found. "
1 0.186225538524 <class 'int'> <class 'numpy.float64'>
2 0.142999434288 <class 'int'> <class 'numpy.float64'>
3 0.119528411309 <class 'int'> <class 'numpy.float64'>
4 0.102810428786 <class 'int'> <class 'numpy.float64'>
5 0.0908560588837 <class 'int'> <class 'numpy.float64'>
6 0.0802998322467 <class 'int'> <class 'numpy.float64'>
7 0.0760480070591 <class 'int'> <class 'numpy.float64'>
8 0.0702064224124 <class 'int'> <class 'numpy.float64'>
9 0.0658400574287 <class 'int'> <class 'numpy.float64'>
10 0.0599815090001 <class 'int'> <class 'numpy.float64'>
11 0.0561602519502 <class 'int'> <class 'numpy.float64'>
12 0.0545255301515 <class 'int'> <class 'numpy.float64'>
13 0.0524513412038 <class 'int'> <class 'numpy.float64'>
14 0.0493695429226 <class 'int'> <class 'numpy.float64'>
15 0.0493934159478 <class 'int'> <class 'numpy.float64'>
16 0.0447554209352 <class 'int'> <class 'numpy.float64'>
17 0.042964329419 <class 'int'> <class 'numpy.float64'>
18 0.0409662023197 <class 'int'> <class 'numpy.float64'>
19 0.0423117034843 <class 'int'> <class 'numpy.float64'>
20 0.0399761411309 <class 'int'> <class 'numpy.float64'>
21 0.0395201882392 <class 'int'> <class 'numpy.float64'>
Test loss: 0.0590692019651
Test accuracy: 0.9852
  • 源代码
# -*- coding: utf-8 -*-
from __future__ import print_function
"""
Created on Mon Feb 5 15:31:25 2018

@author: brucelau
"""

'''Trains a simple deep NN on the MNIST dataset.'''

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.callbacks import EarlyStopping
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.callbacks import LambdaCallback
import numpy as np

# Training hyper-parameters shared by the rest of the script.
batch_size, num_classes, epochs = 256, 10, 50
DROPOUT = 0.5     # rate for the Dropout layer after each hidden Dense layer
opt = 'adam'      # optimizer identifier handed to model.compile
patience = 4      # value passed to EarlyStopping(patience=...)

# Load MNIST (already split into train and test sets), flatten each 28x28
# image to a 784-vector, scale pixel values to [0, 1], and one-hot encode
# the labels.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Use -1 for the sample dimension instead of hard-coding 60000 / 10000,
# so the preprocessing also works on subsets of the data.
x_train = x_train.reshape(-1, 784).astype('float32') / 255
x_test = x_test.reshape(-1, 784).astype('float32') / 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# MLP: 784 -> 500 -> 300 -> num_classes, with dropout after each hidden layer.
model = Sequential()

# The name scopes are presumably meant to group each layer's ops in the
# TensorBoard graph view.
with tf.name_scope('layer-1'):
    model.add(Dense(500, activation='relu', input_shape=(784,), name='H1'))
    model.add(Dropout(DROPOUT, name='DO1'))

with tf.name_scope('layer-2'):
    model.add(Dense(300, activation='relu', name='H2'))
    model.add(Dropout(DROPOUT, name='DO2'))

with tf.name_scope('output'):
    model.add(Dense(num_classes, activation='softmax', name='output'))

model.compile(
    optimizer=opt,
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

#%%
# callback on_train_begin as a Class
class Mylogger(keras.callbacks.Callback):
    """Callback that announces the start of training and prints the model summary."""

    def on_train_begin(self, logs=None):
        """Print a start marker followed by the layer summary of ``self.model``.

        ``Model.summary()`` prints the table itself and returns None. It is
        therefore called directly, not wrapped in ``print`` (wrapping it would
        emit a stray ``None`` line after the table).
        """
        print('On_train_begin')
        self.model.summary()

# callback loss-show: at each epoch end, print the epoch index, the training
# loss, and the Python types of both values.
def _print_epoch_loss(epoch, logs):
    """Print ``epoch``, ``logs['loss']`` and their types on one line."""
    train_loss = logs['loss']
    print(epoch, train_loss, type(epoch), type(train_loss))


show_loss_callback = LambdaCallback(on_epoch_end=_print_epoch_loss)

# callback loss-plot
def vis(e, l):
    """Add one (epoch, value) point to matplotlib figure 1.

    In this script it is driven by ``plot_loss_callback`` with the epoch's
    training loss.

    Args:
        e: epoch index (x coordinate).
        l: training loss for that epoch (y coordinate).
    """
    plt.figure(1)
    plt.scatter(e, l)
    plt.xlabel('epochs')
    # The plotted value is logs['loss'], so label the axis as loss.
    plt.ylabel('train-loss')
    # No plt.legend() here: the figure has a single unlabelled series, so
    # legend() only emitted "No labelled objects found" warnings.
    plt.title('The training process')

def _plot_epoch_loss(epoch, logs):
    """Forward this epoch's training loss to ``vis`` for plotting."""
    vis(epoch, logs['loss'])


plot_loss_callback = LambdaCallback(on_epoch_end=_plot_epoch_loss)

# recording loss history
class LossHistory(keras.callbacks.Callback):
    """Record per-epoch training/validation loss and plot the curves afterwards.

    Attributes (reset in ``on_train_begin``):
        losses: ``logs['loss']`` collected at the end of each epoch.
        val_losses: ``logs['val_loss']`` collected at the end of each epoch;
            stays empty when training runs without validation data.
    """

    def on_train_begin(self, logs=None):
        self.losses = []
        self.val_losses = []

    def on_epoch_end(self, epoch, logs=None):
        # None default instead of a shared mutable ``{}``; also guards
        # against logs=None.
        logs = logs or {}
        self.losses.append(logs['loss'])
        # 'val_loss' is only reported when fit() gets validation data;
        # skip it instead of raising KeyError.
        if 'val_loss' in logs:
            self.val_losses.append(logs['val_loss'])

    def vis_losss(self):
        """Plot the recorded training and validation loss curves in figure 2."""
        plt.figure(2)
        plt.plot(np.arange(len(self.losses)), self.losses, label='losses')
        plt.plot(np.arange(len(self.val_losses)), self.val_losses,
                 label='val_losses')
        plt.xlabel('epochs')
        # Both curves are losses, not accuracies.
        plt.ylabel('loss')
        plt.legend()
        plt.title('The training process')

    # Correctly spelled alias; ``vis_losss`` is kept for existing callers.
    vis_losses = vis_losss
#%%
# Recorder for per-epoch loss / val_loss, plotted after training.
history = LossHistory()

# callback tensorboard: logs go to ./Graph, with histograms disabled
# (histogram_freq=0) and graph/image writing enabled.
_tensorboard_options = {
    'log_dir': './Graph',
    'histogram_freq': 0,
    'write_graph': True,
    'write_images': True,
}
tbCallBack = keras.callbacks.TensorBoard(**_tensorboard_options)

#%%

# Train with all callbacks, evaluate on the test set, then plot loss curves.
#
# NOTE(review): the test set is also passed as validation_data, so
# EarlyStopping (which monitors 'val_loss' by default) picks the stopping
# epoch using test data. The reported test score is therefore optimistically
# biased. Consider validation_split=... or a separate hold-out set.
model_history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
    validation_data=(x_test, y_test),
    callbacks=[
        Mylogger(),
        tbCallBack,
        EarlyStopping(patience=patience, mode='min', verbose=0),
        show_loss_callback,
        plot_loss_callback,
        history,
    ],
)

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
history.vis_losss()
# Outside an interactive IDE (e.g. a plain `python mnist-mlp.py` run) the
# figures drawn above are never displayed unless shown explicitly.
plt.show()


标签:loss,plt,Keras,180205,Callback,train,test,model,self
From: https://blog.51cto.com/guokliu/6098075

相关文章