首页 > 其他分享 >使用TensorFlow实现MNIST数据集分类

使用TensorFlow实现MNIST数据集分类

时间:2023-03-19 12:14:24浏览次数:77  
标签:分类 batch train mnist tf TensorFlow data MNIST

1 MNIST数据集

MNIST数据集由70000张28x28像素的黑白图片组成,每一张图片都写有0~9中的一个数字,每个像素点的灰度值在0 ~ 255(0是黑色,255是白色)之间。
在这里插入图片描述
MINST数据集是由Yann LeCun教授提供的手写数字数据库文件,其官方下载地址THE MNIST DATABASE of handwritten digits
在这里插入图片描述
下载好MNIST数据集后,将其放在Spyder工作目录下(若使用Jupyter编程,则放在Jupyter工作目录下),如图:
在这里插入图片描述
G:\Anaconda\Spyder为笔者Spyder工作目录,MNIST_data为新建文件夹,读者也可以自行命名。

2 实验

为方便设计神经网络输入层,将每张28x28像素图片的像素值按行排成一行,故输入层设计28x28=784个神经元,隐藏层设计600个神经元,输出层设计10个神经元。使用read_data_sets()函数载入数据集,并返回一个类,这个类将MNIST数据集划分为train、validation、test 3个数据集,对应图片数分别为55000、5000、10000。本文采用交叉熵损失函数,并且为防止过拟合问题产生,引入正则化方法。
mnist.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

#载入数据集
mnist=input_data.read_data_sets("MNIST_data",one_hot=True)

#每批次的大小
batch_size=100
#总批次数
batch_num=mnist.train.num_examples//batch_size
#训练轮数
training_step = tf.Variable(0,trainable=False)

#定义两个placeholder
x=tf.placeholder(tf.float32, [None,784])
y=tf.placeholder(tf.float32, [None,10])

#神经网络layer_1
w1=tf.Variable(tf.random_normal([784,600]))
b1=tf.Variable(tf.constant(0.1,shape=[600]))
z1=tf.matmul(x,w1)+b1
a1=tf.nn.tanh(z1)

#神经网络layer_2
w2=tf.Variable(tf.random_normal([600,10]))
b2=tf.Variable(tf.constant(0.1,shape=[10]))
z2=tf.matmul(a1,w2)+b2

#交叉熵代价函数
cross_entropy=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y,1),logits=z2)
#cross_entropy=tf.nn.softmax_cross_entropy_with_logits_v2(labels=y,logits=z2) 
#L2正则化函数
regularizer=tf.contrib.layers.l2_regularizer(0.0001)
#总损失
loss=tf.reduce_mean(cross_entropy)+regularizer(w1)+regularizer(w2)
#学习率(指数衰减法)
laerning_rate = tf.train.exponential_decay(0.8,training_step,batch_num,0.999)
#梯度下降法优化器
train=tf.train.GradientDescentOptimizer(laerning_rate).minimize(loss,global_step=training_step)

#预测精度
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(z2,1))
accuracy=tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

#初始化变量
init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    test_feed={x:mnist.test.images,y:mnist.test.labels}
    for epoch in range(51):
        for batch in range(batch_num):
            x_,y_=mnist.train.next_batch(batch_size)
            sess.run(train,feed_dict={x:x_,y:y_})
        acc=sess.run(accuracy,feed_dict=test_feed)
        if epoch%10==0:
            print("epoch:",epoch,"accuracy:",acc)  

在这里插入图片描述
迭代50次后,精度达到97.68%。

​ 声明:本文转自使用TensorFlow实现MNIST数据集分类

标签:分类,batch,train,mnist,tf,TensorFlow,data,MNIST
From: https://www.cnblogs.com/zhyan8/p/17232722.html

相关文章

  • tensorflow中高维数组乘法运算
    1前言声明:本博客里的数组乘法运算是指矩阵乘法运算,不是对应元素相乘。在线性代数或高等代数中,我们学习了矩阵乘法,那么,什么样的高维数组才能相乘?tensorflow又是如何定义......
  • tensorflow中交叉熵损失函数详解
    1前言tensorflow中定义了3个交叉熵损失函数:softmax_cross_entropy_with_logits(logits,labels)softmax_cross_entropy_with_logits_v2(logits,labels)sparse_softm......
  • 图像处理(1):PyTorch垃圾分类 数据预处理
    基于深度学习框架PyTorchtransforms方法进行数据的预处理产品和技术负责人,专注于NLP、图像、推荐系统整个过程主要包括:缩放、裁剪、归一化、标准化几个基本步骤。图像归一......
  • tensorflow yolov3训练自己的数据集,详细教程
    这个教程是我在自己学习的过程中写的,当作一个笔记,写的比较详细在github上下载yolov3的tensorflow1.0版本:​​​https://github.com/YunYang1994/tensorflow-yolov3​​​......
  • 基于keras采用LSTM实现多标签文本分类
    我先抓取博客园知识库的文章标题和分类代码:#coding=utf-8importosimportsysimportrequestsfromlxmlimportetree,htmlimportlxmlimporttimeimportref......
  • tensorflow.keras.datasets 中关于imdb.load_data的使用说明
    python深度学习在加载数据时(num_words=10000)所代表的意义首先写一段深度学习加载数据集的代码:fromkeras.datasetsimportreuters(train_data,train_labels),(test_dat......
  • 人工智能python3+tensorflow人脸识别_使用 face-api.js 在你的浏览器中做人脸识别(基于
    我很兴奋地告诉你,终于可以在浏览器中运行人脸识别了!这篇文章我将介绍face-api.js,这个类库构建于tensorflow.js之上。它实现了多个CNNs(卷积神经网络)以解决人脸检测、......
  • 06.深度学习--分类模型
    分类模型输入对象x,输出是这个对象属于哪一个类class,这样的应用同样有很多,比如:在金融上可以通过分类模型来决定是否贷款给某人;图像识别方面;人脸辨识方面,等等。这里依然使......
  • Matlab建立SVM,KNN和朴素贝叶斯模型分类绘制ROC曲线|附代码数据
    原文链接:http://tecdat.cn/?p=15508最近我们被客户要求撰写关于SVM,KNN和朴素贝叶斯模型的研究报告,包括一些图形和统计输出。绘制ROC曲线通过Logistic回归进行分类 加......
  • YOLOv5 图片分类
     官网链接、1、命令1)没有模型配置,只能通过--model配置加载预训练模型pythontrain.py--epochs5--batch-size4--workers4--img224--dataE:\数据集\flower_pho......