首页 > 其他分享 > prettytensor 卷积神经网络 MNIST 识别例子

prettytensor 卷积神经网络 MNIST 识别例子

时间：2022-10-26 21:35:02　浏览次数：54
标签:卷积 pred tf num test prettytensor true mnist cls


代码

import prettytensor as pt
import time
from datetime import timedelta
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix

from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST from 'data/MNIST/' with one-hot encoded label rows.
data = input_data.read_data_sets('data/MNIST/', one_hot=True)
# Integer class id per test image, recovered from the one-hot label rows.
data.test.cls = data.test.labels.argmax(axis=1)

# Image geometry: square img_size x img_size images with a single channel,
# fed to the graph flattened to img_size_flat floats per image.
img_size = 28
img_size_flat = img_size * img_size
img_shape = (img_size, img_size)
num_channels = 1
num_classes = 10

# Flat image batch, reshaped to (batch, height, width, channels) for the
# convolution layers, then wrapped so PrettyTensor layers can be chained.
x = tf.placeholder(tf.float32, shape=[None, img_size_flat], name='x')
x_image = tf.reshape(x, [-1, img_size, img_size, num_channels])
x_pretty = pt.wrap(x_image)

# One-hot labels; the width is tied to num_classes rather than a literal 10.
y_true = tf.placeholder(tf.float32, shape=[None, num_classes], name='y_true')
# `axis` replaces the deprecated `dimension` keyword of tf.argmax.
y_true_cls = tf.argmax(y_true, axis=1)

# Network: conv(5x5, 16) -> max-pool(2) -> conv(5x5, 36) -> max-pool(2)
# -> flatten -> dense(128) -> softmax classifier. Inside this scope
# tf.nn.relu is the default activation_fn for the layers built here.
# The classifier call yields the prediction tensor and the loss tensor.
with pt.defaults_scope(activation_fn=tf.nn.relu):
    y_pred, loss = (
        x_pretty
        .conv2d(kernel=5, depth=16, name='layer_conv1')
        .max_pool(kernel=2, stride=2)
        .conv2d(kernel=5, depth=36, name='layer_conv2')
        .max_pool(kernel=2, stride=2)
        .flatten()
        .fully_connected(size=128, name='layer_fc1')
        .softmax_classifier(num_classes=num_classes, labels=y_true)
    )

# Adam optimizer minimizing the classifier loss.
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

# Predicted class = index of the largest output per row. `axis` replaces
# the deprecated `dimension` keyword of tf.argmax.
y_pred_cls = tf.argmax(y_pred, axis=1)
correct_prediction = tf.equal(y_pred_cls, y_true_cls)
# Fraction of rows in the fed batch whose predicted class matches the label.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

session = tf.Session()
session.run(tf.global_variables_initializer())
train_batch_size = 64

# Running step count, advanced by optimize() across successive calls.
total_iterations = 0
def optimize(num_iterations):
    """Run ``num_iterations`` optimizer steps on training mini-batches.

    Step numbering continues from the module-level ``total_iterations``
    counter, so repeated calls pick up where the previous one stopped.
    Every 100th step the accuracy on the batch just trained on is printed,
    and the total wall-clock time is printed at the end.
    """
    global total_iterations
    start_time = time.time()

    first_step = total_iterations
    for step in range(first_step, first_step + num_iterations):
        x_batch, y_true_batch = data.train.next_batch(train_batch_size)
        feed_dict_train = {x: x_batch, y_true: y_true_batch}
        session.run(optimizer, feed_dict=feed_dict_train)

        # Accuracy is measured on the training batch, not on held-out data.
        if step % 100 == 0:
            acc = session.run(accuracy, feed_dict=feed_dict_train)
            msg = "Optimization Iteration: {0:>6}, Training Accuracy: {1:>6.1%}"
            print(msg.format(step + 1, acc))

    total_iterations += num_iterations

    elapsed = time.time() - start_time
    print("Time usage: " + str(timedelta(seconds=int(round(elapsed)))))

# Number of test images sent through the network per session.run call
# when evaluating the test set.
test_batch_size = 256

def print_test_accuracy():
    """Classify every test image, print the accuracy, and plot diagnostics.

    The test images are run through the network in chunks of
    ``test_batch_size``. The overall accuracy is printed, then the
    confusion matrix and a grid of misclassified examples are shown.
    """
    num_test = len(data.test.images)
    # Builtin ``int`` instead of ``np.int``: that alias was deprecated in
    # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
    cls_pred = np.zeros(shape=num_test, dtype=int)

    for start in range(0, num_test, test_batch_size):
        end = min(start + test_batch_size, num_test)
        feed_dict = {x: data.test.images[start:end, :],
                     y_true: data.test.labels[start:end, :]}
        cls_pred[start:end] = session.run(y_pred_cls, feed_dict=feed_dict)

    cls_true = data.test.cls
    correct = (cls_true == cls_pred)
    correct_sum = correct.sum()
    acc = float(correct_sum) / num_test
    msg = "Accuracy on Test-Set: {0:.1%} ({1} / {2})"
    print(msg.format(acc, correct_sum, num_test))

    plot_confusion_matrix(cls_pred)

    plot_example_errors(cls_pred, correct)


def plot_confusion_matrix(cls_pred):
    """Print the test-set confusion matrix and display it as an image.

    Args:
        cls_pred: predicted class id for every test image, aligned with
            ``data.test.cls``.
    """
    cm = confusion_matrix(y_true=data.test.cls, y_pred=cls_pred)
    print(cm)

    ticks = np.arange(num_classes)
    class_ids = range(num_classes)

    plt.matshow(cm)
    plt.colorbar()
    plt.xticks(ticks, class_ids)
    plt.yticks(ticks, class_ids)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()


def plot_example_errors(cls_pred, correct):
    """Show up to nine misclassified test images in a 3x3 grid.

    Args:
        cls_pred: predicted class id for every test image (indexable with a
            boolean mask, e.g. a NumPy array).
        correct: boolean array, True where the prediction matched the label.
    """
    # Mask selecting the misclassified test images.
    incorrect = np.logical_not(correct)
    images = data.test.images[incorrect]
    cls_pred = cls_pred[incorrect]
    cls_true = data.test.cls[incorrect]

    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)
    for i, ax in enumerate(axes.flat):
        # With fewer than nine errors, unconditional images[i] raised
        # IndexError; leave the surplus panels empty instead.
        if i < len(images):
            ax.imshow(images[i].reshape(img_shape), cmap='binary')
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])
            ax.set_xlabel(xlabel)

        ax.set_xticks([])
        ax.set_yticks([])

    plt.show()

# Script entry: train for 5000 mini-batch steps, then evaluate on the test set.
optimize(5000)
print_test_accuracy()

效果

prettytensor 卷积神经网络mnist识别例子_2d


prettytensor 卷积神经网络mnist识别例子_2d_02


prettytensor 卷积神经网络mnist识别例子_2d_03


标签:卷积,pred,tf,num,test,prettytensor,true,mnist,cls
From: https://blog.51cto.com/u_15847885/5798579

相关文章