一、前言
前面用TensorFlow浅做了一个温度预测,使用的是全连接网络,同时我们还对网上的示例做了调试和修改,使得预测结果还能看。本篇我们更进一步使用CNN(卷积)网络,不过再预测温度就有点大材小用,所以本篇是做手写数字的识别。
手写数字识别是非常经典的分类问题,是入门必备的,门槛又比猫狗识别低很多(猫狗图片太大需要的计算资源太大)。
二、数据准备
手写数字识别由于过于经典,所以TensorFlow已经自带了mnist数据集,所以也不用我们自己去找数据来训练。
下面的简单代码可以自动下载mnist数据集,可以看到这个数据集是由28x28像素的图片组成,训练集有6W个图片,测试集有1W个图片,另外配套有等量的label。这里的x是图片,y是图片的值。
将x数据用图片输出,y数据先转为one-hot格式再输出。one-hot是做分类任务必备的,否则网络很难收敛,有兴趣的可以自己试下。
我们再将图片的数值输出来看,会发现5这个图片,其实是可以用数字矩阵来数字化的。矩阵中数值为0的地方就是纯黑色,数值为255的地方就是纯白色,10、20、30这种就是渐变的灰色。
三、构建网络
使用下面的代码构建了一个简单的CNN网络,甚至没有使用到激活函数。
与全连接层相比较:
1)CNN是layers.Conv2D,全连接层是layers.Dense
2)参数有很大不同,全连接层没这么多参数要设置,以后再详细介绍参数
- Filters 卷积核个数
- Kernel_size 卷积核尺寸
- Padding 边缘填充设置
- Input_shape 输入尺寸限制
3)有一个池化层
4)最后还是需要一个全连接层来得到一维的结果
5)直接输入的是二维数组,而不是一维数组。
四、训练
训练使用的参数与之前没什么区别,训练了100次就收敛的挺不错,而且没有明显的过拟合。
五、预测
这里我输出了前5个数字的识别,由于Y是one-hot格式,所以输出结果的格式是每一个“位”都输出一个值,这个值描述的是,图片数值在该位的概率。
以第一个数为例:
[ 0.04167002 0.05634148 -0.03675765 -0.06590495 0.06823181 -0.0774302
0.08698683 0.93804836 -0.06439071 0.16399711]
其含义是图片数字为0的概率是0.04167002,1的概率是0.05634148,7的概率是0.93804836。所以我们取最高概率位即可,也就是7。
六、回顾
由于有了前面温度预测的基础,这里我们只是简单过一下代码,后面的篇幅再详细说。
标签:数字,卷积,网络,图片,识别,mnist,浅用 From: https://www.cnblogs.com/cation/p/17379491.html