首页 > 其他分享 >基于mnist数据集的手写数字识别模型的训练可视化预测

基于mnist数据集的手写数字识别模型的训练可视化预测

时间:2024-07-20 16:25:17浏览次数:17  
标签:args test train 可视化 tf 手写 save model mnist

使用 tensorflow 库创建训练模型

数据集使用公开的 mnist 

一、构建模型

from  tensorflow.keras.layers import Dense, Dropout
import tensorflow as tf
def mnistModel():
    model = tf.keras.Sequential([
        tf.keras.layers.Flatten(input_shape=(28, 28)), #对其进行形状转化 转化成一维28*28
        tf.keras.layers.Dense(128, activation='relu'), # 添加隐藏层,使用全连接层,设置128个节点,使用relu作为激活函数
        tf.keras.layers.Dense(10, activation='softmax')  # 添加输出层,使用全连接层,设置10个节点,使用 softmax 作为激活函数
    ])
    return model

  1.  将输入28x28 转化成一维
  2. 添加隐藏层,使用全连接层,设置128个节点,使用relu作为激活函数
  3. 添加输出层,使用全连接层,设置10个节点,使用 softmax 作为激活函数

二、构建训练模型脚本

import argparse
import tensorflow as tf
from model import mnistModel

# 设置使用GPU,防止内存不足
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

def run(optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["sparse_categorical_accuracy"],
        batch_size=64,  # 每个批量使用64条数据
        epochs=5,  # 训练轮数为5
        validation_split=0.2,  # 划分出20%作为验证
        save_model="./model/mnist_model.h5"
        ):
    model = mnistModel()
    # 配置训练方法
    model.compile(optimizer=optimizer,  # 使用adam优化器
                  loss=loss,  # 损失函数使用稀疏叉熵损失函数
                  metrics=metrics  # 准确率使用稀疏分类准确率函数
                  )
    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    # 属性归一化
    X_train, X_test = tf.cast(train_x / 255.0, tf.float32), tf.cast(test_x / 255.0, tf.float32)
    y_train, y_test = tf.cast(train_y, tf.int16), tf.cast(test_y, tf.int16)
    # 训练模型
    model.fit(X_train,
              y_train,
              batch_size=batch_size,  # 每个批量使用64条数据
              epochs=epochs,  # 训练轮数为5
              validation_split=validation_split  # 划分出20%作为验证
              )
    if save_model:
        model.save(save_model, overwrite=True, save_format=None)
    # 评估模型
    accuracy = model.evaluate(X_test, y_test)
    print(accuracy)

def parse_args():
    parser = argparse.ArgumentParser(description="Train a MNIST model with custom parameters.")
    parser.add_argument('--optimizer', type=str, default='adam', help='Optimizer to use for training.')
    parser.add_argument('--loss', type=str, default='sparse_categorical_crossentropy', help='Loss function to use.')
    parser.add_argument('--metrics', type=str, nargs='+', default=['sparse_categorical_accuracy'], help='List of metrics to evaluate during training.')
    parser.add_argument('--batch_size', type=int, default=64, help='Number of samples per gradient update.')
    parser.add_argument('--epochs', type=int, default=5, help='Number of epochs to train the model.')
    parser.add_argument('--validation_split', type=float, default=0.2, help='Fraction of the training data to be used as validation data.')
    parser.add_argument('--save_model', type=str, default='./model/mnist_model.h5', help='Path to save the trained model.')

    return parser.parse_args()

if __name__ == "__main__":
    args = parse_args()
    run(optimizer=args.optimizer,
        loss=args.loss,
        metrics=args.metrics,
        batch_size=args.batch_size,
        epochs=args.epochs,
        validation_split=args.validation_split,
        save_model=args.save_model)
  1. 使用adam 当作优化器
  2. 使用sparse_categorical_crossentropy 作为评价指标
  3. 划分出两成数据当作验证集validation_split=0.2
  4. batch_size=65,每个批量使用64条数据
  5. epochs=5 设置训练轮数,可以设置成200甚至更高
  6. 设置保存路径,配置训练方法
    model.compile(optimizer=optimizer,  # 使用adam优化器
                  loss=loss,  # 损失函数使用稀疏叉熵损失函数
                  metrics=metrics  # 准确率使用稀疏分类准确率函数
                  )

       7. 读取数据集并进行归一化

    mnist = tf.keras.datasets.mnist
    (train_x, train_y), (test_x, test_y) = mnist.load_data()
    # 属性归一化
    X_train, X_test = tf.cast(train_x / 255.0, tf.float32), tf.cast(test_x / 255.0, tf.float32)
    y_train, y_test = tf.cast(train_y, tf.int16), tf.cast(test_y, tf.int16)

 8. 训练模型并保存

    # 训练模型
    model.fit(X_train,
              y_train,
              batch_size=batch_size,  # 每个批量使用64条数据
              epochs=epochs,  # 训练轮数为5
              validation_split=validation_split  # 划分出20%作为验证
              )
    if save_model:
        model.save(save_model, overwrite=True, save_format=None)

9. 评估模型

    # 评估模型
    accuracy = model.evaluate(X_test, y_test)

三、训练模型

 运行train.py脚本

 

四、使用

使用PyQt5构建可视化操作界面

​​​​​​​

标签:args,test,train,可视化,tf,手写,save,model,mnist
From: https://blog.csdn.net/shxhzxj/article/details/140572610

相关文章

  • UnicodeEncodeError: ‘gbk‘ codec can‘t encode character ‘\xb5‘ in position
    报错UnicodeEncodeError是由于文件写入过程中编码格式不匹配导致的。为了避免这种问题,可以显式指定使用UTF-8编码来写入文件。以下是修改后的代码,确保在写入HTML文件时使用UTF-8编码:importnumpyasnpimportpandasaspdfromsklearn.datasetsimportload_iri......
  • MP+XML手写sql语句分页查询
    原则:让IPage接收从数据库查处的记录@AutowaireprivateUserMapperuserMapper;publicPageDTO<UserVO>pageUser(UserPageQueryquery){IPage<UserVO>page=newPage<>(query.getPageNo(),2);page=userMapper.PageAndXml(query,page);List<U......
  • 手写数字识别——KNN模型实现
    MNIST手写数字识别        MNIST手写数字数据库有一个包含60,000个示例的训练集和一个包含10,000个示例的测试集。每个图像高28像素,宽28像素,共784个像素。每个像素取值范围[0,255],取值越大意味着该像素颜色越深    下载:http://yann.lecun.com/e......
  • 视觉探秘:sklearn中聚类标签的可视化之道
    视觉探秘:sklearn中聚类标签的可视化之道在数据科学领域,聚类分析是一种无监督学习方法,用于将数据集中的样本划分为若干个组或“簇”,使得同一组内的样本相似度高,而不同组之间的样本相似度低。Scikit-Learn(简称sklearn),作为Python中广受欢迎的机器学习库,不仅提供了多种聚类算法......
  • 使用免费工具,大屏可视化古董展览
    传统的古董展览,虽能展现文物的精美与历史的厚重,但往往受限于物理空间与展示形式的单一。而今,随着可视化平台的兴起,以及3D建模、虚拟现实(VR)、增强现实(AR)等技术的广泛应用,我们得以以前所未有的方式“走进”历史。 山海鲸可视化通过高精度的3D建模,可以一比一还原古董样貌,古董的每......
  • 这款免费可视化工具能帮你实现零代码GIS场景编辑
    在当今快速发展的科技时代,GIS场景编辑已成为各行业不可或缺的一部分。然而,复杂的操作和昂贵的软件成本常常令许多人望而却步。幸运的是,现在有了一款免费的可视化工具——山海鲸可视化,它能帮你轻松实现零代码GIS场景编辑,满足你从三维GIS需求出发的所有要求。山海鲸可视化在GIS场......
  • 易优CMS模板标签uitype栏目调用在模板文件index.htm中调用uitype标签,实现指定栏目可视
    【基础用法】标签:uitype描述:栏目编辑,比uitext、uihtml、uiupload标签多了一个typeid属性,使用时结合html一起才能完成可视化布局,只针对具有可视化功能的模板。用法:<divclass="eyou-edit"e-id="文件模板里唯一的数字ID"e-page='文件模板名'e-type="type">{eyou:uitypetypeid=......
  • 计算机毕业设计Python+Tensorflow小说推荐系统 K-means聚类推荐算法 深度学习 Kears
    2、基于物品协同过滤推荐算法2.1、基于⽤户的协同过滤算法(UserCF)该算法利⽤⽤户之间的相似性来推荐⽤户感兴趣的信息,个⼈通过合作的机制给予信息相当程度的回应(如评分)并记录下来以达到过滤的⽬的进⽽帮助别⼈筛选信息,回应不⼀定局限于特别感兴趣的,特别不感兴趣信息的纪录也相......
  • 【译】使 Visual Studio 更加可视化
    任何Web、桌面或移动开发人员都经常使用图像。你可以从C#、HTML、XAML、CSS、C++、VB、TypeScript甚至代码注释中引用它们。有些图像是本地的,有些存在于线上或网络共享中,而其他图像可能仅以base64编码字符串的形式存在。我们在代码中以多种方式引用它们,但总是作为字符串......
  • GIS地图可视化怎么做?这款免费工具帮你轻松搞定
    GIS地图可视化怎么做?山海鲸可视化这款免费可视化工具帮你轻松搞定。从三维GIS地图可视化需求出发,山海鲸可视化提供了强大的GIS场景编辑功能,包括支持添加倾斜摄影和地形编辑。无论是复杂的地形调整还是细致的倾斜摄影添加,这款工具都能轻松实现。山海鲸可视化是一款非常易用的软件......