首页 > 其他分享 >手写数字图片识别——DL 入门案例

手写数字图片识别——DL 入门案例

时间:2024-04-03 09:22:52浏览次数:318  
标签:DL 入门 img keras predict 模型 digit np 手写

Deep Learning Demo of Primary

下面介绍一个入门案例,如何使用TensorFlow和Keras构建一个CNN模型进行手写数字识别,以及如何使用该模型对自己的图像进行预测。尽管这是一个相对简单的任务,但它涵盖了深度学习基本流程,包括:

  • 数据准备
  • 模型构建
  • 模型训练
  • 模型预测

输入:

import tensorflow as tf
from tensorflow import keras
import numpy as np
from PIL import Image

# 加载MNIST数据集(用于训练模型)
# 这部分代码加载了MNIST数据集,这是一个广泛使用的手写数字图像数据集,包含60,000个训练样本和10,000个测试样本。
# 我们将像素值除以255.0,将它们归一化到0-1的范围内,这是神经网络输入的标准做法。
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 规范化像素值
train_images, test_images = train_images / 255.0, test_images / 255.0


# 构建CNN模型
# 这部分代码构建了一个卷积神经网络(CNN)模型。我们使用Keras的Sequential API,它允许我们按顺序堆叠不同的层。
# 我们添加了两个卷积层和两个最大池化层,用于从图像中提取特征。
# 然后,我们添加了一个展平层,将特征映射到一个一维向量。
# 最后,我们添加了两个全连接层,第一个具有128个神经元,第二个具有10个神经元,用于对手写数字进行分类。
# 最后一层使用softmax激活函数输出每个数字的概率。
model = keras.Sequential([
    keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Conv2D(64, (3, 3), activation='relu'),
    keras.layers.MaxPooling2D((2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

# 编译模型
# 这部分代码构建了一个卷积神经网络(CNN)模型。我们使用Keras的Sequential API,它允许我们按顺序堆叠不同的层。
# 我们添加了两个卷积层和两个最大池化层,用于从图像中提取特征。然后,我们添加了一个展平层,将特征映射到一个一维向量。
# 最后,我们添加了两个全连接层,第一个具有128个神经元,第二个具有10个神经元,用于对手写数字进行分类。
# 最后一层使用softmax激活函数输出每个数字的概率。
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# 训练模型
model.fit(train_images, train_labels, epochs=5)

# 函数:预测手写数字图像
# 在构建模型之后,我们需要编译它。
# 我们指定了使用Adam优化器,稀疏分类交台熵损失函数(适用于整数标签),并监控准确率指标。
# 然后,我们使用model.fit函数在训练数据上训练模型,迭代5个epoch。
def predict_digit(img_path):
    # 加载图像
    img = Image.open(img_path).convert('L')
    img = img.resize((28, 28))
    img_array = np.array(img) / 255.0
    img_array = np.expand_dims(img_array, axis=-1)
    img_array = np.expand_dims(img_array, axis=0)

    # 进行预测
    predictions = model.predict(img_array)
    predicted_digit = np.argmax(predictions)

    return predicted_digit

# 测试
# 这个 predict_digit 函数用于预测手写数字图像。它接受一个图像文件路径作为输入。
# 首先,它使用PIL库加载图像,将其转换为灰度模式,并调整大小为28x28像素。
# 然后,它将图像转换为NumPy数组,并进行与训练数据相同的归一化处理。
# 由于CNN模型需要一个4D张量作为输入(batch_size, height, width, channels),
# 我们需要使用 np.expand_dims 在最后两个维度上扩展数组形状。
#
# 接下来,我们使用训练好的模型的 predict 方法对预处理后的图像数据进行预测,得到一个包含10个概率值的列表,每个值对应一个数字(0-9)的概率。
# 我们使用 np.argmax 找到概率值最大的索引,即模型预测的数字。
# 最后,函数返回预测的数字。
digit = predict_digit('image-8.png')
print(f'预测的数字是: {digit}')

输出:
预测的数字是: 8

但是完全不知道程序都做了什么...,那就学习它的流程吧。

Process:

  1. 首先,我们加载内置的MNIST数据集,并将像素值归一化到0-1之间。
  2. 然后,我们使用Keras的Sequential API构建一个CNN模型。该模型包含两个卷积层、两个最大池化层、一个展平层和两个全连接层。
    最后一层使用softmax激活函数输出10个数字的概率。
  3. 我们使用稀疏分类交叉熵损失函数和Adam优化器编译模型。
  4. 接下来,我们使用训练数据train_images和train_labels训练模型5个epoch。
  5. 我们定义了一个predict_digit函数,用于预测手写数字图像。这个函数接受一个图像文件路径作为输入。
  6. 在predict_digit函数中,我们首先使用Pillow库加载图像,并将其转换为灰度模式和28x28大小。
    然后,我们将图像数据转换为Numpy数组,并进行相同的归一化处理。
    由于模型的输入维度为(批次大小, 高度, 宽度, 通道数),我们需要使用np.expand_dims在最后两个维度上扩展数组形状。
  7. 接下来,我们使用训练好的模型的predict方法对预处理后的图像数据进行预测,得到一个包含10个概率值的列表,每个值对应一个数字(0-9)的概率。
    我们使用np.argmax找到概率值最大的索引,即模型预测的数字。
  8. 最后,我们调用predict_digit函数,传入你自己的图像文件路径,并打印预测结果。

标签:DL,入门,img,keras,predict,模型,digit,np,手写
From: https://www.cnblogs.com/mysticbinary/p/18110725

相关文章

  • Java登陆第三十六天——VUE3响应式入门、setup语法糖
    当浏览器接收到服务端返回的页面后,浏览器会把页面解析成DOM树,DOM树中各个元素会相应的显示在浏览器上。VUE提供的响应式数据可以在页面不刷新的情况下更新数据。响应式数据App.vue<script>//等价于setup语法糖。固定的写法,不会改。exportdefault{setup(){letsum......
  • 深入理解ThreadLocal原理
    目录1-什么是ThreadLocal?2-ThreadLocal的作用?ThreadLocal实现线程间资源隔离ThreadLocal实现线程内资源共享3-ThreadLocal原理3-1ThreadLocalMap3-2ThreadLocalMap的扩容......
  • py基础入门(一篇足够)
    python笔记来自b站中,孙兴华老师的课程笔记!目录看起来多,只是为了让有基础的兄弟选择查看,内容其实一点都不多!可以翻着看一下,有基础的感觉只用看目录就可以重温python基础,不用浪费太多时间了。文章目录python笔记安装python的安装pytharm简介什么是python?基本语法......
  • ThreadLocal源码解析
    方法三个主要方法:getsetremove讲三个方法前,现需要知道Thread,ThreadLocal,ThreadLocalMap三个之间的关系,首先ThreadLocalMap虽然是ThreadLocal中定义的静态内部类,但实际的ThreadLocalMap实例是作为Thread对象的一个字段存在的。这样设计的目的是允许每个线程存储自己......
  • MIT 6.S081入门lab10 mmap
    MIT6.S081入门lab10mmap一、参考资料阅读与总结1.JournalingtheLinuxext2fsFilesystem文件系统可靠性:磁盘崩溃前数据的稳定性;故障模式的可预测性;操作的原子性-论文核心:将日志事务系统加入Linux的文件系统中;事务系统的要求:元数据的更新;事务系统的顺序性;数据块写入磁......
  • Container容器:未来的最终解:Docker(入门导览)
    容器容器:可以无视机器、系统限制的时刻使用任何的软件或程序的虚拟机-容器解释:[什么是容器?|IBM备注:Docker本身并不是容器,它是创建容器的工具,是应用容器引擎优势:docker虚拟机内存轻量占用大设备几乎支持所有电子设备主要PC主机镜像复用可以打包到官方仓库,云端下载需要点......
  • docker入门
    Docker是一种容器化平台,可以让开发者打包自己的应用程序及其依赖项,并以容器的形式进行交付。以下是Docker的入门指南:安装Docker:首先,你需要在你的操作系统上安装Docker。Docker可以在各种操作系统上运行,包括Linux、macOS和Windows。你可以从Docker官方网站下载......
  • 自然语言处理基础知识入门(二) Word2vec模型,层次softmax,负采样算法详解
    文章目录前言一、Word2vec模型1.1什么是Word2vec模型?1.2Word2vec模型是如何训练?1.3Word2vec最简单版本整体过程1.4Word2vec详细过程1.5CBOW整体过程1.6Skip-gram整体过程二、优化算法2.1层次softmax2.1.1哈夫曼树2.1.2算法详细逻辑2.2负采样策略总结......
  • 手写简易操作系统(二十)--实现堆内存管理
    前情提要前面我们实现了0x80中断,并实现了两个中断调用,getpid和write,其中write还由于没有实现文件系统,是个残血版,这一节我们实现堆内存管理。一、arena在计算机科学中,“arena”内存管理通常指的是一种内存分配和管理技术,它通常用于动态内存分配和释放。在这种管理......
  • .NET Emit 入门教程:第六部分:IL 指令:3:详解 ILGenerator 指令方法:参数加载指令
    前言:在上一篇中,我们介绍了ILGenerator辅助方法。本篇,将详细介绍指令方法,并详细介绍指令的相关用法。在接下来的教程,关于IL指令部分,会将指令分为以下几个分类进行讲解:1、参数加载指令:ld开头的指令,单词为:loadargument2、参数存储指令:st开头的指令,单词为:store3、创建实......