利用KNN算法实现手写数字识别
MNIST手写数字识别 是计算机视觉领域中 "hello world"级别的数据集
- 1999年发布,成为分类算法基准测试的基础
- 随着新的机器学习技术的出现,MNIST仍然是研究人员和学习者的可靠资源。
本次案例中,我们的目标是从数万个手写图像的数据集中正确识别数字。
数据介绍
数据文件 train.csv 和 test.csv 包含从 0 到 9 的手绘数字的灰度图像。
-
每个图像高 28 像素,宽28 像素,共784个像素。
-
每个像素取值范围[0,255],取值越大意味着该像素颜色越深
-
训练数据集(train.csv)共785列。第一列为 “标签”,为该图片对应的手写数字。其余784列为该图像的像素值
-
训练集中的特征名称均有pixel前缀,后面的数字([0,783])代表了像素的序号。
像素组成图像如下:
000 001 002 003 ... 026 027
028 029 030 031 ... 054 055
056 057 058 059 ... 082 083
| | | | ...... | |
728 729 730 731 ... 754 755
756 757 758 759 ... 782 783
数据集示例如下:
# 导入工具包
import joblib
from sklearn.model_selection import train_test_split, GridSearchCV # 分割训练集和测试集的, 网格搜索 + 交叉验证.
from sklearn.neighbors import KNeighborsClassifier # KNN算法 分类对象
import matplotlib.pyplot as plt # 绘图.
import pandas as pd
from collections import Counter
# 需求 定义函数 接收索引 将该行的手写数字 识别为 图片并绘制出来
def dm01_show_digit(idx):
# 1. 读取文件 获取df对象
data = pd.read_csv('./data/手写数字识别.csv')
# 2.判断用户传入值 是否合法
if idx < 0 or idx >= len(data):
print('传入的索引有误 程序结束! ')
return
# 走到这里说明 没问题 查看下所有的数据集
x = data.iloc[:, 1:]
y = data.iloc[:, 0]
print(
f'数字的种类: {Counter(y)}') # Counter({1: 4684, 7: 4401, 3: 4351, 9: 4188, 2: 4177, 6: 4137, 0: 4132, 4: 4072, 8: 4063, 5: 3795})
print(f'像素的形状: {x.shape}')
# 根据传入的索引获取到该行的数据
print(f'您传入的所有 对应的数字是: {y[idx]}')
# 绘制图片
# 把图片的像素点 转为 28*28的图片
digit = x.iloc[idx].values.reshape(28, 28)
# 绘制图片
plt.imshow(digit, cmap='gray') # 灰度图
plt.axis('off') # 关闭坐标
# plt.savefig('./data/demo2.png')
plt.show()
# 需求2 定义函数 使用KNN算法 用于识别 手写数字 保存模型
def dm02_train_mdoel():
data = pd.read_csv('./data/手写数字识别.csv')
# 数据预处理
x = data.iloc[:, 1:]
y = data.iloc[:, 0]
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=22, stratify=y)
# 特征工程
x_train = x_train / 255
# 模型训练
estimator = KNeighborsClassifier(n_neighbors=9)
estimator.fit(x_train, y_train)
# 模型评估
print(f'准确率: {estimator.score(x_test, y_test)}')
# 模型保存
joblib.dump(estimator, './model/knn.pkl')
def dm03_use_model():
# 读取图片 绘制图片
img = plt.imread('./data/demo.png')
plt.imshow(img,cmap='gray')
plt.show()
# 读取模型 获取模型对象
knn = joblib.load('./model/knn.pkl')
# 模型预测
y_predict = knn.predict(img.reshape(1,-1))
print(f'预测结果为:{y_predict}')
if __name__ == '__main__':
# dm01_show_digit(20)
# dm02_train_mdoel()
dm03_use_model()
坚持分享 共同进步
标签:KNN,plt,像素,Pyhton,train,test,手写,data From: https://blog.csdn.net/weixin_45423893/article/details/144835500