首页 > 其他分享 >手写数字识别——KNN模型实现

手写数字识别——KNN模型实现

时间:2024-07-19 22:24:58浏览次数:20  
标签:KNN knn labels test train images import 手写 识别

MNIST手写数字识别

        MNIST 手写数字数据库有一个包含 60,000 个示例的训练集和一个包含 10,000 个示例的测试集。

  • 每个图像高 28 像素,宽28 像素,共784个像素。

  • 每个像素取值范围[0,255],取值越大意味着该像素颜色越深

        下载:http://yann.lecun.com/exdb/mnist/

import os
from torchvision import datasets

# 设置数据集的根目录
os.environ['TORCH_HOME'] = './MNIST'

# 下载数据集
train_dataset = datasets.MNIST(root=os.getenv('TORCH_HOME'), train=True, download=True,)
test_dataset = datasets.MNIST(root=os.getenv('TORCH_HOME'), train=False, download=True,)

print(train_dataset.train_data.shape)
print(train_dataset.train_labels.shape)
print(test_dataset.test_data.shape)
print(test_dataset.test_labels.shape)

torch.Size([60000, 28, 28])
torch.Size([60000])
torch.Size([10000, 28, 28])
torch.Size([10000])

KNN算法预测

import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import joblib
import numpy as np

# 加载MNIST数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# 数据归一化
train_images = train_images / 255.0
test_images = test_images /255.0

# 将数据全部组合
x = np.concatenate((train_images,test_images),axis=0).reshape(70000, -1)
y = np.concatenate((train_labels, test_labels),axis=0)

# 数据集划分
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, stratify=y, random_state=0)

# 定义并训练KNN分类器
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(x_train, y_train)

print('测试集准确率:',knn.score(x_test,y_test))

# 保存训练好的KNN模型
joblib.dump(knn, './knn_model.pth')

# 加载并预测新图像
img = plt.imread('./test_image.png')
plt.imshow(img)

# 加载保存的KNN模型
knn_loaded = joblib.load('./knn_model.pth')

# 对图像进行预测
img_flat = img.reshape(1, -1)  # 将图像展平成一维数组
y_predict = knn_loaded.predict(img_flat)
print(f'Predicted Label: {y_predict}')

测试集准确率: 0.9735
Predicted Label: [2]

 

标签:KNN,knn,labels,test,train,images,import,手写,识别
From: https://blog.csdn.net/weixin_74254879/article/details/140452511

相关文章

  • paddleocr识别并按行输出结果
    frompaddleocrimportPaddleOCR#初始化OCR引擎ocr=PaddleOCR(use_angle_cls=True,lang="ch")#使用中文模型#对图像进行OCR识别img_path='./imgs/img_3.png'result=ocr.ocr(img_path,cls=True)#按y坐标对文本块进行排序sorted_result=sorted(result,......
  • 63文章解读与程序——电力系统保护与控制EI\CSCD\北大核心《基于混沌集成决策树的电
    ......
  • 百度人脸识别Windows C++离线sdk C#接入
    百度人脸识别WindowsC++离线sdkC#接入目录说明设计背景•场景特点:•客户特点:•核心需求:SDK包结构效果代码说明自己根据SDK封装了动态库,然后C#调用。功能接口设计背景•场景特点:--网络:对于无网、局域网等情况,无法连接公网,API方式无法运作。如政府单......
  • 基于语音识别的会议记录系统
    目录核心功能页面展示使用技术方案功能结构设计数据库表展示核心功能页面展示视频展示功能1.创建会议在开始会议之前需要管理员先创建一个会议,为了能够快速开始会议,仅需填写会议的名称、会议举办小组、会议背景等简要会议信息即可成功创建。2.语音识别会议记录(最核心功......
  • Halcon的学习笔记(一)——非线性字符识别
    Halcon非线性模式的字符识别(ocr_cd_print_polar_trans.hdev例程分析)Halcon的学习笔记(一)——非线性字符识别项目上需要对非线性模式的字符进行识别,halcon中包含的例程,我搜了一下,网上对于该例程的解析比较少,因此自己便记录了一下自己的学习例程,也算自己的学习笔记。1.什......
  • RFID无线射频识别
    一、简要说明RFID是一种无线通信技术,通过无线电信号识别特定目标,并读取相关数据,而无需建立机械或者光学接触。RFID全称:Radio-FrequencyIdentification二、工作原理需要辨别的物体上附有标签,通过阅读器(双向无线电波收发器)向标签发出信号并解读其应答。传输方式:无线电的信号通......
  • 从零手写实现 nginx-31-load balance 负载均衡介绍
    前言大家好,我是老马。很高兴遇到你。我们为java开发者实现了java版本的nginxhttps://github.com/houbb/nginx4j如果你想知道servlet如何处理的,可以参考我的另一个项目:手写从零实现简易版tomcatminicat手写nginx系列如果你对nginx原理感兴趣,可以阅读:从零......
  • whisper-api语音识别语音翻译高性能兼容openai接口协议的开源项目
    whisper-api介绍使用openai的开源项目winsper语音识别开源模型封装成openaichatgpt兼容接口软件架构使用uvicorn、fastapi、openai-whisper等开源库实现高性能接口更多介绍[https://blog.csdn.net/weixin_40986713/article/details/138712293](https://blog.csdn.net......
  • uniapp对接人脸识别,人脸核身,双录 ,阿里云,以及腾讯云对接方法。
    腾讯云uniapp接入】第一步,申请人脸核身服务:https://cloud.tencent.com/apply/p/shcgszvmppc第二步,申请业务流程WBAppid:-获取WBappid方法指引:https://cloud.tencent.com/document/product/1007/49634-申请链接:https://console.cloud.tencent.com/faceid/access第三步,uni插件接入......