首页 > 其他分享 >CNN实现手写数字识别

CNN实现手写数字识别

时间:2022-11-16 19:36:58浏览次数:52  
标签:keras 28 tf add CNN 手写 识别 model accuracy


手写数字识别一致是一个机器学习里面常见的案例,今天通过CNN来实现一个手写数字识别来介绍一个机器学习的流程。

数据预处理

from keras import datasets
(x_train, y_train), (x_test, y_test) = datasets.mnist.load_data()
x_train = x_train.reshape((60000, 28, 28, 1))
# 归一化,0-255不太方便神经网络进行计算,因此将范围缩小到0—1
x_train = x_train.astype('float32') / 255
x_test = x_test.reshape((10000, 28, 28, 1))
x_test = x_test.astype('float32') / 255

构建模型

from keras import models,layers
from keras import backend as K
K.clear_session()
#初始化模型,可以通过add往里面加层
model = models.Sequential()
model.add(layers.Conv2D(32, (3, 3), input_shape=(28, 28, 1)))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(64, (3, 3) ))
model.add(layers.MaxPooling2D((2, 2)))
model.add(layers.Conv2D(128, (3, 3)))
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))
model.add(layers.Dense(10, activation='softmax'))
#查看模型结构
model.summary()

CNN实现手写数字识别_keras

模型训练

model.compile() 作用:
设置优化器、损失函数和准确率评测标准。

optimizer:
1.“sgd” 或者 tf.optimizers.SGD(lr = 学习率, decay = 学习率衰减率,momentum = 动量参数)

2.“adagrad” 或者 tf.keras.optimizers.Adagrad(lr = 学习率, decay = 学习率衰减率)

3.“adadelta” 或者 tf.keras.optimizers.Adadelta(lr = 学习率,decay = 学习率衰减率)

4.“adam” 或者 tf.keras.optimizers.Adam(lr = 学习率, decay = 学习率衰减率)
loss:
1.“mse” 或者 “mean squared error” 或 tf.keras.losses.MeanSquaredError()
2.“sparse_categorical_crossentropy” 或 tf.keras.losses.SparseCatagoricalCrossentropy(from_logits = False)
Metrics:
1.“accuracy” :
2.“sparse_accuracy":
3.“sparse_categorical_accuracy” :

CNN实现手写数字识别_深度学习_02

model.compile(optimizer='rmsprop',
loss='sparse_categorical_crossentropy', # 注意此处loss形式针对未作Onehot的分类标签
metrics=['accuracy'])
history = model.fit(x_train, y_train, epochs=2,
batch_size=64,validation_data =(x_test,y_test))
import pandas as pd
import matplotlib.pyplot as plt
dfhistory = pd.DataFrame(history.history)
dfhistory.index = range(1,len(dfhistory) + 1)
dfhistory.index.name = 'epoch'
dfhistory.to_csv('hitory_metrics',sep = '\t')
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc, 'b', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()

CNN实现手写数字识别_keras_03

model.save('minst_model.h5')

我们在网上随便找一张图片实验

CNN实现手写数字识别_深度学习_04

from PIL import Image
import numpy as np
def produceImage(file_in, width, height, file_out):
image = Image.open(file_in)
resized_image = image.resize((width, height), Image.ANTIALIAS)
resized_image.save(file_out)
if __name__ == '__main__':
file_in = r'image\2.png'
width = 28
height = 28
file_out = r'image\2_1.png'
produceImage(file_in, width, height, file_out)
# 把图像转化为黑白的
im = Image.open(r'image\2_1.png')
L = im.convert("L")
L.save(r'image\2_1.png')

通过以上代码进行裁切得到

CNN实现手写数字识别_2d_05


我们通过刚刚保存的模型去实验

import tensorflow as tf
from PIL import Image
import numpy as np
im_4 = Image.open(r'image\2_1.png')
im_4 = np.reshape(im_4, [1,28,28,1])
#调用模型
new_model =tf.keras.models.load_model('minst_model.h5')
#进行预测
pe_4 = new_model.predict(im_4)
#把最大的坐标找到,因为new_model.predict返回的是[0,0,1,0,0,0,0,0,0,0]这种格式,
#所以需要转换为我们熟悉的格式
pe_4 = tf.argmax(pe_4 ,1)

with tf.Session() as sess:
print(sess.run(pe_4))

输出结果:
2


标签:keras,28,tf,add,CNN,手写,识别,model,accuracy
From: https://blog.51cto.com/u_15876949/5857156

相关文章

  • 基于 CNN-GRU 的菇房多点温湿度预测方法研究 学习记录
    本篇文章主要为学习其模型思想。引言卷积神经网络(CNN)作为在图像处理、计算机视觉等领域被广泛应用的模型,其特殊的网络结构通过共享权重的特性可以很好地处理高维稀疏特......
  • 2018年辽宁省电子设计大赛D题手势识别装置
    一转眼,两年过去了。距离这个比赛已经好久。我此时(2020年5月28日)已然大四,马上要念研究生了。现在回头看这篇我刚开始接触写的文章,还没有学会markdown,而且认识也比较粗浅。大......
  • 教你手写webpack常用loader
    前言webpack作为目前主流的前端构建工具,我们几乎每天都需要与它打交道。个人认为一个好的开源产品壮大的原因应该包括核心开发者的稳定输出以及对应生态的繁荣。对于生态......
  • 【tensorflow2.6】图片数据建模流程:猫狗分类,83.6%识别率
    目标:识别猫和狗一、猫狗数据集数据集下载:公众号,回复:猫狗数据集训练数据集(每一张图片都有dog和cat标签):测试集(图片没有标签):二、训练环境kaggletenslrflow2.6三、数据处理impo......
  • 经典CNN设计演变的关键总结:从VGGNet到EfficientNet
    卷积神经网络设计史上的主要里程碑:模块化、多路径、因式分解、压缩、可扩展一般来说,分类问题是计算机视觉模型的基础,它可以延申解决更复杂的视觉问题,例如:目标检测的任务包......
  • 实验3 手写字体识别【机器学习】
    推荐​​python实现手写数字识别(小白入门)​​原文​​MNISTHandwrittenDigitRecognitioninPyTorch​​​翻译用PyTorch实现MNIST手写数字识别(非常详细)mnist.gz/mnis......
  • Qt 在Mac上无法识别编译器
    由于Mac系统更新,导致我之前的Xcode不能用了,然后我就把Xcode卸载了,结果悲剧了,Qt无法使用了,提示无法识别Apple的clang编译器。使用Qt前,必须先安装Xcode!!!!关闭QtCreator,在终......
  • 深度学习基础课:用全连接层识别手写数字(上)
    大家好~我开设了“深度学习基础班”的线上课程,带领同学从0开始学习全连接和卷积神经网络,进行数学推导,并且实现可以运行的Demo程序线上课程资料:本节课录像回放加QQ群,获得......
  • JAVA 调佣百度ai识别身份证和车牌号
    识别身份证和车牌号的方法:packagefunction;importcom.baidu.aip.ocr.AipOcr;importorg.json.JSONObject;importjava.util.HashMap;/***图像识别sdk*/p......
  • JAVA 调佣百度ai识别动植物
    项目结构:    调用sdk分别实现动物识别和植物识别类:packagefounction;importutil.AuthService;importutil.Base64Util;importutil.FileUtil;importut......