手写数字识别-决策树
决策树
决策树是一种基于树形结构的分类算法,通过不断地根据特征划分数据集来实现分类。
数据集分析
在本任务中,我们使用的是著名的 MNIST 数据集(https://www.kaggle.com/code/nishan192/mnist-digit-recognition-using-svm中下载使用test即可),它包含了大量手写数字的图像,每个图像由 28x28 个像素组成。为了便于处理,数据集中的每个图像被展平为一个 784 维的向量。第一列是标签(label
),它表示图像所代表的数字类别。
我们从数据集中提取了一张图像并进行可视化,代码如下:
import pandas as pd
import matplotlib.pyplot as plt
# 导入训练数据集
train_data = pd.read_csv(r"digit-recognizer\train.csv") # 请根据实际路径更改
# 提取第四张图像的像素数据(从第2列开始)
four = train_data.iloc[3, 1:]
# 查看图像数据的形状
four.shape # 应返回 (784,)
# 将像素数据重塑为28x28的二维数组
four = four.values.reshape(28, 28)
# 使用matplotlib显示图像
plt.imshow(four, cmap='gray')
plt.title("Digit 4")
plt.show()
其中 four.values.reshape(28,28)
将一维的 784 个像素值重塑为 28x28 的二维数组,以便能够显示为图像。其图像如下:
数据处理
然后,通过决策树模型进行训练,利用训练集中的数据来构建决策规则。
#导入必要的包
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
# 导入训练数据
train_data = pd.read_csv(r"digit-recognizer\train.csv")//更换为自己训练集的地址
# 获取特征和标签
X = train_data.iloc[:, 1:]
y = train_data.iloc[:, 0]
归一化的目的在于确保不同特征的数值范围相近,从而避免某些特征由于数值较大或较小而在模型训练中占据过多的权重。
# 将数据归一化
X = X / 255.0
我们将原始数据集按照 80% 用于训练集,20% 用于验证集进行划分,确保模型在不同数据集上的泛化能力。具体操作是使用 train_test_split
函数,从特征矩阵 X
和标签向量 y
中分别拆分出训练集 (X_train
, y_train
) 和验证集 (X_val
, y_val
),并设定随机种子 random_state=42
以保证实验的可重复性。
# 将训练集拆分成训练集和验证集(80%训练集,20%验证集)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
模型的创建与训练
我们使用决策树分类器(DecisionTreeClassifier
)构建模型,并设定了如下参数:criterion='gini'
,选择基尼指数作为分裂标准,另一种常见的选择是信息增益(criterion='entropy'
),该标准基于信息论中的熵来决定分裂方式。为了控制模型复杂度,max_depth=10
限制决策树的最大深度,从而避免过拟合。random_state=42
确保模型训练过程的可重复性。接着,使用训练集 (X_train
, y_train
) 对模型进行训练,使其能够学习数据中的规律。
# 创建并训练决策树模型
model = DecisionTreeClassifier(criterion='gini', max_depth=10, random_state=42)
model.fit(X_train, y_train)
模型预测与评估
预测与准确率计算:
使用训练好的决策树模型对验证集 X_val
进行预测,并将预测结果存储在 y_pred
中。随后,利用 accuracy_score
函数计算预测结果与真实标签 y_val
之间的准确率,并输出该准确率值,作为评估模型性能的基本指标。
# 使用验证集进行预测
y_pred = model.predict(X_val)
# 计算准确率
accuracy = accuracy_score(y_val, y_pred)
print("准确率:", accuracy)
我得到的结果
分类报告:
通过 classification_report
输出详细的分类报告,报告中包含了各类别的精确度(precision)、召回率(recall)、F1 分数(f1-score)等指标,有助于我们更全面地评估模型在各个类别上的表现。
# 输出分类报告
print("\n分类报告:\n", classification_report(y_val, y_pred))
模型表现分析:
- 按类别分析:
- 表现最好的类别:类别 0 和 1 的精确率和召回率均超过 0.90,说明模型对这些类别的识别能力较强。
- 表现较差的类别:类别 5、8 和 9 的精确率和召回率较低,尤其是类别 8,精确率仅为 0.76,说明模型对该类别的区分能力不足。
- 整体性能:
- accuracy(准确率):模型的总体准确率为 85%,说明 85% 的样本被正确分类。
- macro avg(宏平均):对所有类别的指标直接取平均值,不考虑类别样本不均衡的情况。精确率、召回率和 F1 分数均为 0.85。
- weighted avg(加权平均):对所有类别的指标按
support
加权平均,更能反映整体表现的真实情况,值也为 0.85。
混淆矩阵
# 输出混淆矩阵
print("\n混淆矩阵:\n", confusion_matrix(y_val, y_pred))
行表示真实标签(从 0 到 9)。
列表示预测标签(从 0 到 9)。
对角线上的值表示模型正确分类的样本数量。
非对角线的值表示分类错误的样本数量,具体来说:
-
第 i行第 j列的值表示真实标签为 i的样本中,有多少被错误分类为 j。
-
非对角线元素的值越大,说明该类别的错误分类越严重。
具体来看:
-
类别 0:
- 正确分类:748 个样本被正确分类。
- 主要误分类情况:有 18 个样本被错误分类为类别 6,这可能说明类别 0 和类别 6 的特征相似。
-
类别 1:
- 正确分类:857 个样本被正确分类。
- 误分类情况:误分类数量总体较少,说明模型对类别 1 的区分能力较强。
-
类别 5:
- 正确分类:533 个样本被正确分类。
- 主要误分类情况:有 30 个样本被错误分类为类别 4,这可能是类别 5 和类别 4 的特征不够明显导致的。
-
类别 8:
- 正确分类:669 个样本被正确分类。
- 主要误分类情况:有 36 个样本被错误分类为类别 3,说明类别 8 和类别 3 存在一定的混淆。
-
类别 3 和 8、类别 4 和 5 之间的混淆较为严重。
-
某些类别的误分类可能受限于特征的相似性或数据分布的问题。
源码
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
# 导入训练数据
train_data = pd.read_csv(r"更改为你的文件地址")
# 获取特征和标签
X = train_data.iloc[:, 1:] # 取出像素数据
y = train_data.iloc[:, 0] # 取出标签(数字)
# 将数据归一化
X = X / 255.0
# 将训练集拆分成训练集和验证集(80%训练集,20%验证集)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
# 创建并训练决策树模型
model = DecisionTreeClassifier(criterion='gini', max_depth=10, random_state=42)
model.fit(X_train, y_train)
# 使用验证集进行预测
y_pred = model.predict(X_val)
# 计算准确率
accuracy = accuracy_score(y_val, y_pred)
print("准确率:", accuracy)
# 输出分类报告
print("\n分类报告:\n", classification_report(y_val, y_pred))
# 输出混淆矩阵
print("\n混淆矩阵:\n", confusion_matrix(y_val, y_pred))
标签:训练,val,分类,train,类别,import,手写,识别,决策树
From: https://blog.csdn.net/rose765/article/details/144594117