首页 > 其他分享 >手写数字识别-决策树

手写数字识别-决策树

时间:2024-12-19 20:28:46浏览次数:7  
标签:训练 val 分类 train 类别 import 手写 识别 决策树

手写数字识别-决策树

决策树

决策树是一种基于树形结构的分类算法,通过不断地根据特征划分数据集来实现分类。

数据集分析

在本任务中,我们使用的是著名的 MNIST 数据集(https://www.kaggle.com/code/nishan192/mnist-digit-recognition-using-svm中下载使用test即可),它包含了大量手写数字的图像,每个图像由 28x28 个像素组成。为了便于处理,数据集中的每个图像被展平为一个 784 维的向量。第一列是标签(label),它表示图像所代表的数字类别。

我们从数据集中提取了一张图像并进行可视化,代码如下:

import pandas as pd
import matplotlib.pyplot as plt

# 导入训练数据集
train_data = pd.read_csv(r"digit-recognizer\train.csv")  # 请根据实际路径更改

# 提取第四张图像的像素数据(从第2列开始)
four = train_data.iloc[3, 1:]

# 查看图像数据的形状
four.shape  # 应返回 (784,)

# 将像素数据重塑为28x28的二维数组
four = four.values.reshape(28, 28)

# 使用matplotlib显示图像
plt.imshow(four, cmap='gray')
plt.title("Digit 4") 
plt.show()

其中 four.values.reshape(28,28)将一维的 784 个像素值重塑为 28x28 的二维数组,以便能够显示为图像。其图像如下:

image-20241219195117850

数据处理

然后,通过决策树模型进行训练,利用训练集中的数据来构建决策规则。

#导入必要的包
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt

# 导入训练数据
train_data = pd.read_csv(r"digit-recognizer\train.csv")//更换为自己训练集的地址

# 获取特征和标签
X = train_data.iloc[:, 1:]  
y = train_data.iloc[:, 0]   

归一化的目的在于确保不同特征的数值范围相近,从而避免某些特征由于数值较大或较小而在模型训练中占据过多的权重。

# 将数据归一化
X = X / 255.0

​ 我们将原始数据集按照 80% 用于训练集,20% 用于验证集进行划分,确保模型在不同数据集上的泛化能力。具体操作是使用 train_test_split 函数,从特征矩阵 X 和标签向量 y 中分别拆分出训练集 (X_train, y_train) 和验证集 (X_val, y_val),并设定随机种子 random_state=42 以保证实验的可重复性。

# 将训练集拆分成训练集和验证集(80%训练集,20%验证集)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

模型的创建与训练

​ 我们使用决策树分类器(DecisionTreeClassifier)构建模型,并设定了如下参数:criterion='gini',选择基尼指数作为分裂标准,另一种常见的选择是信息增益(criterion='entropy'),该标准基于信息论中的熵来决定分裂方式。为了控制模型复杂度,max_depth=10 限制决策树的最大深度,从而避免过拟合。random_state=42 确保模型训练过程的可重复性。接着,使用训练集 (X_train, y_train) 对模型进行训练,使其能够学习数据中的规律。

# 创建并训练决策树模型
model = DecisionTreeClassifier(criterion='gini', max_depth=10, random_state=42)
model.fit(X_train, y_train)

模型预测与评估

预测与准确率计算:

使用训练好的决策树模型对验证集 X_val 进行预测,并将预测结果存储在 y_pred 中。随后,利用 accuracy_score 函数计算预测结果与真实标签 y_val 之间的准确率,并输出该准确率值,作为评估模型性能的基本指标。

# 使用验证集进行预测
y_pred = model.predict(X_val)
# 计算准确率
accuracy = accuracy_score(y_val, y_pred)
print("准确率:", accuracy)

​ 我得到的结果

image-20241219201052513

分类报告:
通过 classification_report 输出详细的分类报告,报告中包含了各类别的精确度(precision)、召回率(recall)、F1 分数(f1-score)等指标,有助于我们更全面地评估模型在各个类别上的表现。

# 输出分类报告
print("\n分类报告:\n", classification_report(y_val, y_pred))
image-20241219201128513

模型表现分析

  • 按类别分析
    • 表现最好的类别:类别 01 的精确率和召回率均超过 0.90,说明模型对这些类别的识别能力较强。
    • 表现较差的类别:类别 589 的精确率和召回率较低,尤其是类别 8,精确率仅为 0.76,说明模型对该类别的区分能力不足。
  • 整体性能
    • accuracy(准确率):模型的总体准确率为 85%,说明 85% 的样本被正确分类。
    • macro avg(宏平均):对所有类别的指标直接取平均值,不考虑类别样本不均衡的情况。精确率、召回率和 F1 分数均为 0.85。
    • weighted avg(加权平均):对所有类别的指标按 support 加权平均,更能反映整体表现的真实情况,值也为 0.85。

混淆矩阵

# 输出混淆矩阵
print("\n混淆矩阵:\n", confusion_matrix(y_val, y_pred))
image-20241219201653811

表示真实标签(从 0 到 9)。

表示预测标签(从 0 到 9)。

对角线上的值表示模型正确分类的样本数量。

非对角线的值表示分类错误的样本数量,具体来说:

  • 第 i行第 j列的值表示真实标签为 i的样本中,有多少被错误分类为 j。

  • 非对角线元素的值越大,说明该类别的错误分类越严重。

具体来看:

  • 类别 0

    • 正确分类:748 个样本被正确分类。
    • 主要误分类情况:有 18 个样本被错误分类为类别 6,这可能说明类别 0 和类别 6 的特征相似。
  • 类别 1

    • 正确分类:857 个样本被正确分类。
    • 误分类情况:误分类数量总体较少,说明模型对类别 1 的区分能力较强。
  • 类别 5

    • 正确分类:533 个样本被正确分类。
    • 主要误分类情况:有 30 个样本被错误分类为类别 4,这可能是类别 5 和类别 4 的特征不够明显导致的。
  • 类别 8

    • 正确分类:669 个样本被正确分类。
    • 主要误分类情况:有 36 个样本被错误分类为类别 3,说明类别 8 和类别 3 存在一定的混淆。
  • 类别 38、类别 45 之间的混淆较为严重。

  • 某些类别的误分类可能受限于特征的相似性或数据分布的问题。

源码

import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt

# 导入训练数据
train_data = pd.read_csv(r"更改为你的文件地址")

# 获取特征和标签
X = train_data.iloc[:, 1:]  # 取出像素数据
y = train_data.iloc[:, 0]   # 取出标签(数字)

# 将数据归一化
X = X / 255.0

# 将训练集拆分成训练集和验证集(80%训练集,20%验证集)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# 创建并训练决策树模型
model = DecisionTreeClassifier(criterion='gini', max_depth=10, random_state=42)
model.fit(X_train, y_train)

# 使用验证集进行预测
y_pred = model.predict(X_val)

# 计算准确率
accuracy = accuracy_score(y_val, y_pred)
print("准确率:", accuracy)

# 输出分类报告
print("\n分类报告:\n", classification_report(y_val, y_pred))

# 输出混淆矩阵
print("\n混淆矩阵:\n", confusion_matrix(y_val, y_pred))

标签:训练,val,分类,train,类别,import,手写,识别,决策树
From: https://blog.csdn.net/rose765/article/details/144594117

相关文章

  • 基于vgg16和efficientnet卷积神经网络的天气识别系统(pytorch框架) 图像识别与分类 前
    基于vgg16和efficientnet卷积神经网络的天气识别系统(pytorch框架)前端界面:flask+python,UI界面:pyqt5+python这是一个完整项目,包括代码,数据集,模型训练记录,前端界面,ui界面,各种指标图:包括准确率,精确率,召回率,F1值,损失曲线,准确率曲线等卷积模型采用vgg16模型或efficien......
  • 基于OpenCV和Python的人脸识别系统
    一、系统概述基于OpenCV和Python的人脸识别系统利用先进的算法和工具,提供高效、准确的人脸识别服务。该系统可以应用于安全监控、门禁系统、移动支付、智能设备解锁等多个场景,具有广泛的应用价值和商业价值。二、核心组件OpenCV:OpenCV是一个开源的计算机视觉和机器学习......
  • 人车防碰撞识别智慧矿山一体机矿山监控系统中的平台一体机和解码器如何选型?
    在构建高效、可靠的视频监控系统时,选择合适的平台一体机和解码器是至关重要的一步。这不仅关系到监控系统的稳定性和可靠性,还直接影响到监控画面的清晰度和系统的扩展性。以下是在选择过程中需要考虑的关键因素,以确保您的监控系统能够满足特定场景的需求,并在未来几年内保持其先进......
  • 毕业设计:python车牌识别系统 HyperLPR车牌识别(深度学习) 可视化 Django框架 大数据毕业
    python车牌识别系统HyperLPR车牌识别(深度学习)可视化Django框架大数据毕业设计(源码+文档)1、项目介绍技术栈:Python语言、Django框架、MySQL数据库、HyperLPR库车牌识别(深度学习)、Echarts可视化系统功能:车牌号码识别,车牌所属省份,再给他搞个各省统计分析,柱状图,折线图......
  • 使用 C++ 和 Tesseract 实现验证码识别
    安装TesseractOCR首先,你需要安装TesseractOCR库。如果你还没有安装,请按照以下步骤操作:在Ubuntu上安装Tesseract:bashsudoaptupdatesudoaptinstalltesseract-ocrsudoaptinstalllibleptonica-devsudoaptinstalllibtesseract-dev在Windows上安装Tesse......
  • 使用 PHP 和 Tesseract 实现验证码识别
    Tesseract是一个开源的OCR引擎,能识别图像中的文本。我们将通过PHP调用Tesseract来实现验证码的识别。安装PHP和Tesseract首先,确保你的系统中安装了PHP和TesseractOCR。Tesseract安装(Ubuntu):bash更多内容访问ttocr.com或联系1436423940sudoapt-getupdatesud......
  • 在移动端页面如何忽略自动识别电话和邮箱?
    在前端开发过程中,如果你想要防止移动端浏览器自动识别并格式化电话号码或电子邮件地址,你可以使用以下几种方法:1.使用HTML的meta标签你可以尝试在HTML的<head>部分添加meta标签来禁用电话号码和电子邮件地址的自动识别。虽然这种方法的效果可能因浏览器和设备而异,但它是一种简单......
  • 使用 Kotlin 实现验证码识别
    步骤安装Kotlin环境如果尚未安装Kotlin,可以通过以下方式安装:对于Android开发,可以通过安装AndroidStudio。对于其他平台,可以按照Kotlin官方文档中的指引进行安装。安装TesseractOCR在Kotlin中使用TesseractOCR,通常可以通过JNI(JavaNativeInterface)调用......
  • 使用 R 语言实现验证码识别
    在R中,我们可以使用tesseract包与TesseractOCR引擎进行验证码识别。这个包提供了对Tesseract的简单接口。步骤安装TesseractOCR引擎首先,你需要安装Tesseract引擎。可以通过以下方式安装:Linux:bashsudoapt-getinstalltesseract-ocrmacOS:bashbrewinstall......
  • 量化分析选优质基金实战-趋势识别
    一、基金量化分析的定义基金量化分析是指运用数学模型、统计方法和计算机技术对基金的各方面特征进行量化评估的过程。它涵盖了基金的业绩表现、风险水平、资产配置、投资风格等多个维度。业绩表现分析包括计算基金的各种收益率指标,如简单收益率、累计收益率、年化收......