首页 > 其他分享 >CNN入门级实战教程

CNN入门级实战教程

时间:2024-03-18 18:31:41浏览次数:28  
标签:教程 train 模型 28 入门级 test CNN model mnist

本教程将使用Keras构建一个简单的的卷积神经网络(Convolutional Neural Network,CNN)来对手写数字进行识别。使用的数据集为MNIST数据集,一个包含手写数字图像的经典数据集。

0. 环境设置

确保你已经安装了所需的库,可以通过以下方式安装:

pip install keras
conda install matplotlib
conda install numpy

1. 数据准备

首先,我们需要加载MNIST数据集,该数据集包含60,000个训练图像和10,000个测试图像。

# numberCNNtrain.py
​
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.utils import to_categorical
​
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 数据预处理
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255
​
# 将标签转换为one-hot编码
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

如果这一行出现了报错的话。

(x_train, y_train), (x_test, y_test) = mnist.load_data()

报错信息如下:Exception: URL fetch failure on https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz: None -- [WinError 10054] 远程主机强迫关闭了一个现有的连接。
说明你连接不上googleapis。不过可以直接将数据集下载,并将mnist.npz文件保存到 C:\Users\Administrator.keras\datasets目录下。这里提供一个网盘链接可以下载mnist.npz数据集。

百度网盘 请输入提取码

2. 构建CNN模型

接下来,我们构建一个简单的CNN模型。该模型包含两个卷积层(Conv2D),两个最大池化层(MaxPooling2D),以及两个全连接层(Dense)。

# numberCNNtrain.py
​
# 构建CNN模型
model = Sequential()
model.add(Conv2D(64, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(128, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(10, activation='softmax'))

卷积核个数(Filters):

较小的卷积核个数: 导致模型较少的参数和更快的训练速度,但降低模型的表达能力。

较大的卷积核个数: 增加模型的表达能力,但会增加模型的参数数量和训练时间。

卷积核尺寸(Kernel Size):

小尺寸的卷积核(如3x3): 通常用于较浅的层,可以捕捉更局部的特征。

大尺寸的卷积核(如5x5或7x7): 通常用于较深的层,可以捕捉更全局的特征。

3. 编译和训练模型

在构建模型后,我们需要编译它,并通过训练数据对其进行训练。

# numberCNNtrain.py
​
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
​
# 训练模型
model.fit(x_train, y_train, batch_size=128, epochs=5, validation_data=(x_test, y_test))

关于优化器、损失函数、评估指标等,可以进一步查阅Keras文档以了解更多可用参数和配置。

关于超参数:

batch_size(批量大小): 定义每次迭代中用于更新模型权重的样本数量。较大的批量大小可能会加快训练速度,但也可能导致内存不足。较小的批量大小可以提高模型的泛化性,但训练可能变慢。

epochs(训练轮数): 定义整个训练数据集被迭代的次数。增加 epochs 可以提高模型的性能,但也可能导致过拟合。

4. 评估和保存模型

训练完成后,我们对模型进行评估,并将其保存到文件中。

# numberCNNtrain.py
​
# 评估模型
test_loss, test_acc = model.evaluate(x_test, y_test)
print('Test accuracy:', test_acc)
​
# 保存模型
model.save('mnist_cnn_model.h5')

开始运行!运行结束后,会生成mnist_cnn_model.h5,这个就是我们训练的模型。

——————分割线——————

5. 模型预测和可视化

现在,我们可以使用训练好的模型mnist_cnn_model.h5进行预测并可视化结果。

# numberCNNpredict.py

import numpy as np
from keras.datasets import mnist
from keras.models import load_model
import matplotlib.pyplot as plt

# 加载模型
model = load_model('mnist_cnn_model.h5')

# 加载测试数据集
(_, _), (x_test, y_test) = mnist.load_data()
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1).astype('float32') / 255

# 使用模型进行预测
predictions = model.predict(x_test)

# 获取预测的标签
predicted_labels = np.argmax(predictions, axis=1)

# 可视化预测结果
plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(x_test[i].reshape(28, 28), cmap=plt.cm.binary)
    plt.xlabel(f'Predicted: {predicted_labels[i]}')
plt.show()

可视化:

6.

标签:教程,train,模型,28,入门级,test,CNN,model,mnist
From: https://blog.csdn.net/qq_56633600/article/details/136689290

相关文章

  • Python教程:如何获取颜色的RGB值
    简介在许多计算机图形和图像处理应用中,颜色的RGB值是至关重要的信息。Python作为一种多功能的编程语言,提供了丰富的工具和库,可以轻松地获取颜色的RGB值。本文将介绍如何使用Python获取颜色的RGB值,以及一些实际应用的示例。使用PIL工具获取颜色的RGB值PIL(PythonImagingLibrar......
  • 【教程】为什么要为 App 应用加固 ?如何为 App 应用加固 ?
    ​一:为什么要为App应用加固来看下腾讯开放平台官方的解释说明若应用不做任何安全防护,极易被病毒植入、广告替换、支付渠道篡改、钓鱼、信息劫持等,严重侵害开发者的利益。App加固后,可以对应用进行安全防护,防止应用分发后,被反编译、调试、盗版、破解、二次打包等威胁,维护开发......
  • 【教程】苹果iOS证书制作教程
     摘要苹果iOS证书是上架苹果APP必不可少的签名工具,本文将详细介绍注册账号、创建APPID、制作p12证书、发布mobileprovision证书等步骤,帮助开发者顺利完成证书制作过程。引言在开发与发布苹果APP的过程中,制作苹果iOS证书是至关重要的一环。只有正确的证书签名,才能确保APP在苹......
  • 【Java入门教程】第五讲:if-else控制语句
    现实世界是复杂多变的,同一个程序我们需要根据不同的场景做出不同的反应。在Java编程中,if-else 语句就是这样一种工具,它允许程序根据不同的条件执行不同的代码块。一、基础语法if-else 语句的基本语法结构如下:if(condition){//代码块1:当条件为true时执行}else......
  • 【黑马MySQL】MySQL的下载&安装&启停&配置环境变量【一条龙教程】
    前言大家好吖,欢迎来到YY滴MySQL系列,热烈欢迎!本章主要内容面向接触过C++Linux的老铁主要内容含:欢迎订阅YY滴C++专栏!更多干货持续更新!以下是传送门!YY的《C++》专栏YY的《C++11》专栏YY的《Linux》专栏YY的《数据结构》专栏YY的《C语言基础》专栏YY的《初学者易......
  • Tomcat安装与配置详细教程:从入门到精通
    Tomcat安装与配置详细教程:从入门到精通简介:本教程旨在为广大开发者提供一份Tomcat服务器的安装与配置指南。通过本教程的学习,您将能够掌握Tomcat服务器的安装步骤、环境变量的配置方法,以及验证Tomcat配置是否成功的技巧。同时,我们还将简要介绍JavaJDK的安装与配置,为Tomca......
  • CentOS安装JDK17教程(完整版)
    JDK17是JavaDevelopmentKit(Java开发工具包)的第17个长期支持(LTS)版本,由Oracle公司于2021年9月发布。作为Java语言的主要发行版,JDK17带来了许多新特性、增强功能和优化。但是我们在Linux环境下使用yum安装时,发现不能直接安装JDK17,使用:yumsearchjava|grep......
  • 一键制作iOS上架App Store描述文件教程
     摘要本篇博文详细介绍了在iOS上架过程中所需的基础项目,包括IOS生产环境证书、APPID包名制作以及APP的描述文件。通过使用appuploader进行证书制作和上传IPA到AppStore,能够快速掌握真机测试和上架流程。引言在iOS应用开发过程中,正确制作描述文件对于应用的上架至关重要。本......
  • docker菜鸟教程
    Docker是一个开源的应用容器引擎,允许开发者将应用程序及其依赖打包到一个可移植的容器中,然后发布到任何流行的Linux机器或Windows机器上,也可以实现虚拟化。容器是完全使用沙箱机制,相互之间不会有任何接口,因此不会相互影响。Docker的基本使用步骤如下:安装Docker。根据......
  • C++实名认证接口教程-好集成的身份证实名认证接口-三要素认证
    现如今,随着实名制的实施,各行各业都将进行人员身份的核查,如家政、保洁、物流、金融、电商等,身份证实名认证接口主要是验证个人用户提交的姓名、人像和身份证号码信息,和公安数据库内对应的数据是否匹配一致,可以验证个人身份证信息的真伪。以下是C++语言调用翔云身份证实名......