首页 > 其他分享 >手写数字识别-使用TensorFlow构建和训练一个简单的神经网络

手写数字识别-使用TensorFlow构建和训练一个简单的神经网络

时间:2024-07-05 14:55:27浏览次数:23  
标签:layers add images 神经网络 test TensorFlow model 手写 history

下面是一个具体的Python代码示例,展示如何使用TensorFlow实现一个简单的神经网络来解决手写数字识别问题(使用MNIST数据集)。以下是一个完整的Python代码示例,展示如何使用TensorFlow构建和训练一个简单的神经网络来进行手写数字识别。

MNIST数据集的训练集有60000个样本:

Python代码

import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt
import json
import os

# 加载MNIST数据集
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 预处理数据
train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255

# 构建神经网络模型
def create_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))

    # 编译模型
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# 训练模型并保存
def train_and_save_model():
    model = create_model()
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
    model.save('mnist_model.h5')

    # 保存训练历史记录
    with open('training_history.json', 'w') as f:
        json.dump(history.history, f)

# 加载模型和历史记录
def load_model_and_history():
    model = tf.keras.models.load_model('mnist_model.h5')

    with open('training_history.json', 'r') as f:
        history = json.load(f)
    
    return model, history

# 评估模型
def evaluate_model(model):
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print("Test accuracy: {}".format(test_acc))

# 可视化训练过程
def plot_training_history(history):
    plt.plot(history['accuracy'], label='accuracy')
    plt.plot(history['val_accuracy'], label='val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()

# 检查是否已经存在模型和历史记录
if not os.path.exists('mnist_model.h5') or not os.path.exists('training_history.json'):
    train_and_save_model()

model, training_history = load_model_and_history()
evaluate_model(model)
plot_training_history(training_history)

代码解释

  1. 加载MNIST数据集

    mnist = tf.keras.datasets.mnist
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    
  2. 预处理数据

    • 将图像数据调整为 (28, 28, 1) 的形状。
    • 将像素值标准化为 [0, 1] 之间。
    train_images = train_images.reshape((60000, 28, 28, 1)).astype('float32') / 255
    test_images = test_images.reshape((10000, 28, 28, 1)).astype('float32') / 255
    
  3. 构建神经网络模型

    • 使用 Sequential 模型,按顺序添加层。
    • 添加卷积层、池化层、全连接层。
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10, activation='softmax'))
    
  4. 编译模型

    • 使用 adam 优化器,损失函数为 sparse_categorical_crossentropy,评估指标为 accuracy
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    
  5. 训练模型

    • 训练模型5个epochs,并使用验证数据集评估模型性能。
    history = model.fit(train_images, train_labels, epochs=5, validation_data=(test_images, test_labels))
    
  6. 评估模型

    • 在测试集上评估模型性能,并打印测试准确率。
    test_loss, test_acc = model.evaluate(test_images, test_labels)
    print(f"Test accuracy: {test_acc}")
    
  7. 可视化训练过程

    • 绘制训练和验证准确率随epoch变化的曲线。
    plt.plot(history.history['accuracy'], label='accuracy')
    plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.ylim([0, 1])
    plt.legend(loc='lower right')
    plt.show()
    

通过这个修正后的示例,应该可以正常运行并训练一个简单的神经网络模型来进行手写数字识别。

标签:layers,add,images,神经网络,test,TensorFlow,model,手写,history
From: https://www.cnblogs.com/WG11/p/18285804

相关文章

  • 脉冲神经网络(Spiking Neural Network,SNN)相关论文最新推荐(一)
    用稀疏代理梯度直接训练时态脉冲神经网络论文链接:www.sciencedirect.comBenchmarkingArtificialNeuralNetworkArchitecturesforHigh-PerformanceSpikingNeuralNetworks论文链接:www.mdpi.comHierarchicalspikingneuralnetworkauditoryfeaturebaseddry-typet......
  • 基于CNN的蒙牛评论情感分析(用的Tensorflow,因为不会Pytorch)
    目录一、数据来源二、导相应入的库 三、数据预处理四、模型的构建五、预测函数六、总结一、数据来源 蒙牛评论数据集:共有2000多条数据,其中一列为label,一列为review,label这一列已经分好差评和好评,差评为0,好评为1,好评和差评占比为1比1.二、导相应入的库#导入数......
  • 实现第一个神经网络
    PyTorch包含创建和实现神经网络的特殊功能。在本节实验中,将创建一个简单的神经网络,其中一个隐藏层开发一个输出单元。通过以下步骤使用PyTorch实现第一个神经网络。第1步首先,需要使用以下命令导入PyTorch库。In [1]:import torchimport torch.nnas nn第2步定......
  • Python基于PyQt5和卷积神经网络分类模型(ResNet50分类算法)实现生活垃圾分类系统GUI界
    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取。1.项目背景在当今社会,随着人们对环境保护意识的增强以及科技的快速发展,智能化的垃圾分类系统成为了一个热门的研究方向。结合深度学习技术,尤其是先进的图像识......
  • 使用python基本库代码实现神经网络常见层
    一:批量归一化(BatchNormalization)代码解释:函数定义:batch_norm函数接受输入数据X、缩放参数gamma、平移参数beta和一个小常数epsilon,用于防止除零错误。X的形状为(N,D),其中N是批量大小,D是特征维度。gamma和beta的形状为(1,D)。计算批量均值和方差:me......
  • AI算法04-自组织映射神经网络Self-Organizing Map | SOM
    自组织映射神经网络自组织映射(SOM)或自组织特征映射(SOFM)是一种类型的人工神经网络(ANN),其使用已训练的无监督学习以产生低维(通常为二维),离散的表示训练样本的输入空间,称为地图,因此是一种减少维数的方法。自组织映射与其他人工神经网络不同,因为它们应用竞争学习而不是纠错学习(例如......
  • element 手写季度组件
    组件:<template><divclass="time_quarter"><markstyle="position:fixed;top:0;bottom:0;left:0;right:0;background:rgba(0,0,0,0);z-index:999;"v-show="showSeason"@click.stop="showSeason=false&q......
  • 基于卷积神经网络的交通标志识别系统(通过TensorFlow构建LeNet-5模型,并使用GTSRB德国交
    完成程序下载点此下载1、资源项目源码均已通过严格测试验证,保证能够正常运行;2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通;3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业......
  • TensorFlow中numpy与tensor数据相互转化(支持tf1.x-tf2.x)
    TensorFlow中numpy与tensor数据相互转化(支持tf1.x-tf2.x)TF1.x版本有时候解决起来很简单,就是错误比较难找到,所以我推荐的方法为将数据进行显式的转化。Numpy2Tensor虽然TensorFlow网络在输入Numpy数据时会自动转换为Tensor来处理,但是我们自己也可以去显式的转换:data_tensor......
  • 基于卷积神经网络的交通标志识别系统(通过TensorFlow构建LeNet-5模型,并使用GTSRB德国交
    完成程序下载点此下载1、资源项目源码均已通过严格测试验证,保证能够正常运行;2、项目问题、技术讨论,可以给博主私信或留言,博主看到后会第一时间与您进行沟通;3、本项目比较适合计算机领域相关的毕业设计课题、课程作业等使用,尤其对于人工智能、计算机科学与技术等相关专业......