从入门到实战:深入解析 K 最近邻(KNN)算法在手写数字分类中的应用
K最近邻(K-Nearest Neighbors, KNN)算法
基本原理
K最近邻(KNN)是一种基于距离度量的监督学习算法,既可以用于分类,也可以用于回归任务。其核心思想是,对于一个需要预测的样本点,通过计算它与训练集中所有样本点之间的距离,选取距离最近的 K 个邻居,并基于这些邻居的特性来决定待预测点的输出。KNN 是一种懒惰学习算法,因为它在训练阶段不生成显式的模型,而是在预测阶段利用整个训练数据集来进行计算。
具体来说,KNN 的工作原理可以分为以下几步:
-
计算距离:对于每个待预测的数据点,计算它与训练集中所有样本点之间的距离。常用的距离度量包括:
- 欧氏距离(Euclidean Distance):最常用的距离度量,计算两个点在特征空间中的直线距离。公式为:
d ( x , y ) = ∑ i = 1 n ( x i − y i ) 2 d(x, y) = \sqrt{\sum_{i=1}^n (x_i - y_i)^2} d(x,y)=i=1∑n(xi−yi)2 - 曼哈顿距离(Manhattan Distance):适合衡量两个点在网格空间中的路径距离,公式为:
d ( x , y ) = ∑ i = 1 n ∣ x i − y i ∣ d(x, y) = \sum_{i=1}^n |x_i - y_i| d(x,y)=i=1∑n∣xi−yi∣ - 切比雪夫距离(Chebyshev Distance):衡量两个点在各个维度上的最大差异,公式为:
d ( x , y ) = max i = 1 n ∣ x i − y i ∣ d(x, y) = \max_{i=1}^n |x_i - y_i| d(x,y)=i=1maxn∣xi−yi∣ - 闵可夫斯基距离(Minkowski Distance):是欧氏距离和曼哈顿距离的泛化形式,公式为:
d ( x , y ) = ( ∑ i = 1 n ∣ x i − y i ∣ p ) 1 p d(x, y) = \left( \sum_{i=1}^n |x_i - y_i|^p \right)^{\frac{1}{p}} d(x,y)=(i=1∑n∣xi−yi∣p)p1
当 p = 2 p=2 p=2 时为欧氏距离,当 p = 1 p=1 p=1 时为曼哈顿距离。特别地,当 p = 3 p=3 p=3 时,公式为:
d ( x , y ) = ( ∑ i = 1 n ∣ x i − y i ∣ 3 ) 1 3 d(x, y) = \left( \sum_{i=1}^n |x_i - y_i|^3 \right)^{\frac{1}{3}} d(x,y)=(i=1∑n∣xi−yi∣3)31
该距离更注重点之间的较大差异,对某些特定任务可以体现出更高的区分能力。
- 欧氏距离(Euclidean Distance):最常用的距离度量,计算两个点在特征空间中的直线距离。公式为:
-
选择 K 值:K 值表示最近邻居的数量,是 KNN 的超参数。K 的大小直接影响分类效果:
- 如果 K K K 值较小,则分类决策主要受局部邻居的影响,可能会对噪声数据过于敏感,导致过拟合。
- 如果 K K K 值较大,则分类决策会基于更大的邻域,可能会忽略局部结构,导致欠拟合。
-
确定邻居:从计算出的距离中选择距离最近的 K 个样本点作为邻居。
-
分类或回归:
- 分类:根据 K 个邻居的类别,通过多数投票法决定待分类点的类别。
- 回归:通过对 K 个邻居的目标值取平均来预测目标值。
特点
-
简单直观:KNN 是一种直观易懂的算法,没有复杂的数学模型或训练过程。它直接基于样本间的距离计算和邻居的类别或目标值来进行预测。
-
懒惰学习:KNN 属于懒学习算法(Lazy Learning),即在训练阶段不构建显式模型,仅将训练数据存储起来,推迟计算到预测阶段。尽管这种特性简化了训练过程,但预测时计算量较大。
-
非参数化:KNN 不对数据的分布做任何假设,因而可以处理复杂的、非线性分布的数据。这使得它非常灵活,但也意味着它对高维数据的表现可能不佳(受维度灾难的影响)。
-
计算复杂度高:由于 KNN 需要对每个预测样本计算所有训练样本的距离,因此预测阶段的计算复杂度为 O ( n ⋅ m ) O(n \cdot m) O(n⋅m),其中 n n n 是训练样本数量, m m m 是特征维度。
-
对噪声敏感:KNN 对于数据中的异常点和噪声非常敏感。这是因为噪声数据可能会被误选为邻居,从而导致分类错误。通过使用加权距离(较近的邻居权重更大)可以部分缓解此问题。
-
维度灾难:在高维空间中,所有样本之间的距离差异可能会变得非常小,难以区分。此时,降维(如 PCA)或特征选择是有效的解决方法。
-
K 值选择的重要性:
- 小的 K K K 值能够捕获局部模式,但可能受噪声影响较大。
- 大的 K K K 值能够平滑决策边界,但可能会忽略局部细节,影响预测精度。
- 通常通过交叉验证来选择合适的 K K K 值。
-
需要标准化特征:KNN 直接基于距离度量,特征的尺度差异会显著影响结果。因此,在应用 KNN 前通常需要对数据进行标准化或归一化处理。
总结
KNN 算法以其直观性和易实现性著称,特别适合小型数据集和低维特征空间中的分类和回归任务。然而,其高计算复杂度和对噪声的敏感性在大规模、高维数据上表现不佳。通过调整 K 值、使用合适的距离度量(如闵可夫斯基距离)以及对数据进行预处理,可以有效提升 KNN 的性能。
实战
基于KNN对手写数字进行分类
首先,针对MNIST数据集的特点,我采用了合理的数据预处理策略。考虑到手写数字图像中存在的书写风格差异、笔画粗细不一等问题,我们使用StandardScaler对数据进行标准化处理。这种处理方式不仅规范化了特征的尺度,还有效减少了书写习惯差异带来的影响。经过处理的数据集包含1797个样本,每个样本由64个特征组成,这些特征对应于8×8像素图像的灰度值。
在数据集划分方面,采用了分层抽样的方式,按照8:2的比例将数据集划分为训练集(1437个样本)和测试集(360个样本)。分层抽样确保了训练集和测试集中各个数字类别的分布保持一致,这对于保证模型训练的有效性和评估结果的可靠性至关重要。
在分类器的具体实现中,我们重点关注了距离度量方式的选择。考虑到手写数字识别任务的特点,我们实现了多种距离度量方式,包括欧氏距离、曼哈顿距离、切比雪夫距离和闵可夫斯基距离(p=3)。这种多样化的距离度量方式设计,使得分类器能够从不同角度捕捉数字图像的特征差异[8]。
超参数调节
在KNN算法中,近邻数K和距离度量方式是两个关键的超参数。为了找到最优的参数组合,我们进行了系统的参数搜索。对于K值的选择,我们在1到30的范围内进行了遍历测试。如下图1所示,可以看出,在变量唯一的情况下,K=4的时候,KNN在MNIST数据集上的训练结果无论是从准确率、精确率、召回率还是F1分数,性能会远远超过在K等于其他值。所以在后面选择距离度量方式的时候,选择是基于K=4的情况下做修改。
模型训练与测试
为了提高分类器的计算效率,我在实现中采用了并行计算策略,充分利用多核处理器的优势。同时,我还实现了一个高效的近邻搜索机制,这在保证分类准确性的同时,显著提升了模型的预测速度。这种优化对于实际应用中的实时识别需求具有重要意义。
为了确保实验结果的可靠性,我采用了5折交叉验证的方式进行模型训练和评估。在每一折的验证中,都保持了数据的分层抽样特性,确保了各个数字类别的分布一致性。训练过程中,我记录了不同参数组合下的模型性能和训练时间,这为后续的模型选择提供了重要依据。如图4所示,是5折交叉验证之后的结果,可以发现,使用欧式距离和闵可夫斯基距离(p=3)的最后模型性能比较高,使用切比雪夫距离度量模型最后性能最差。总体来说,使用欧氏距离度量,在这次任务中有明显的优势。
结合图6可以看出,在对图片进行预测时,使用欧氏距离和曼哈顿距离预测准确率都高达100%,使用切比雪夫距离和Minkowski距离(p=3)预测准确率最终均为80%。因为样本只用了10张,所以这个准确率只能做一个参考。
性能评估与混淆矩阵绘制
为了从多方面评估模型性能,我把使用K=4、欧氏距离度量训练的KNN模型最终在手写数字数据集上的预测结果画出了相关的混淆矩阵,如图7所示。可以看出,模型在预测6个数字上准确率都达到了100%,3个数字准确率达到了90%以上,从各方面评估发现,KNN在手写数字的识别上有不错的效果。
通过对混淆矩阵中预测错误的样本进行分析,发现:
- 某些类别(如“3”和“8”)在特征空间中的相似性较高,导致误分类。
- 个别样本笔画模糊(如“4”写成了“9”),对模型的区分能力提出挑战。
针对这些问题,后续改进方案可包括:
- 增加数据增强(如旋转、模糊处理)提升模型鲁棒性。
- 结合特征提取方法(如 CNN)捕获更深层次的特征。
完整代码
这里提供完整的训练和测试代码,大家有兴趣可以自己跑一下
训练代码
import matplotlib
matplotlib.use('Agg')
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, cross_val_score, KFold
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import make_scorer, accuracy_score, precision_score, recall_score, f1_score
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
from sklearn.metrics import precision_recall_fscore_support
def load_and_preprocess_data():
"""加载并预处理数据"""
print("正在加载数据...")
digits = load_digits()
X = digits.data
y = digits.target
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
X_scaled, y, test_size=0.2, random_state=42, stratify=y
)
print(f"训练集大小: {X_train.shape[0]} 样本")
print(f"测试集大小: {X_test.shape[0]} 样本")
return X_train, X_test, y_train, y_test, scaler
def perform_cross_validation(X_train, X_test, y_train, y_test, k=5, n_folds=5):
"""使用交叉验证评估不同距离度量方式的性能"""
# 定义要测试的距离度量方式
metrics = {
'欧氏距离': {'metric': 'euclidean'},
'曼哈顿距离': {'metric': 'manhattan'},
'切比雪夫距离': {'metric': 'chebyshev'},
'Minkowski (p=3)': {'metric': 'minkowski', 'p': 3},
}
# 定义交叉验证
kf = KFold(n_splits=n_folds, shuffle=True, random_state=42)
# 定义评估指标
scoring = {
'准确率': make_scorer(accuracy_score),
'精确率': make_scorer(precision_score, average='macro'),
'召回率': make_scorer(recall_score, average='macro'),
'F1分数': make_scorer(f1_score, average='macro')
}
# 创建更详细的结果记录
detailed_results = []
for name, params in metrics.items():
print(f"\n评估 {name} 的性能...")
knn = KNeighborsClassifier(n_neighbors=k, **params)
start_time = time.time()
# 对每个评估指标进行交叉验证
cv_scores = {}
for score_name, scorer in scoring.items():
scores = cross_val_score(knn, X_train, y_train,
scoring=scorer, cv=kf, n_jobs=-1)
cv_scores[score_name] = scores.mean()
cv_scores[f'{score_name}_std'] = scores.std()
# 在完整训练集上训练并评估
knn.fit(X_train, y_train)
test_score = knn.score(X_test, y_test)
total_time = time.time() - start_time
# 保存详细结果
detailed_results.append({
'距离度量': name,
'交叉验证准确率': cv_scores['准确率'],
'交叉验证准确率标准差': cv_scores['准确率_std'],
'交叉验证精确率': cv_scores['精确率'],
'交叉验证精确率标准差': cv_scores['精确率_std'],
'交叉验证召回率': cv_scores['召回率'],
'交叉验证召回率标准差': cv_scores['召回率_std'],
'交叉验证F1分数': cv_scores['F1分数'],
'交叉验证F1分数标准差': cv_scores['F1分数_std'],
'测试集准确率': test_score,
'总耗时(秒)': total_time,
'准确率置信区间下限': cv_scores['准确率'] - 2*cv_scores['准确率_std'],
'准确率置信区间上限': cv_scores['准确率'] + 2*cv_scores['准确率_std']
})
# 打印当前结果
print(f"交叉验证准确率: {cv_scores['准确率']:.4f} (±{cv_scores['准确率_std']:.4f})")
print(f"准确率95%置信区间: [{detailed_results[-1]['准确率置信区间下限']:.4f}, "
f"{detailed_results[-1]['准确率置信区间上限']:.4f}]")
print(f"测试集准确率: {test_score:.4f}")
print(f"总耗时: {total_time:.4f}秒")
# 创建详细的DataFrame
results_df = pd.DataFrame(detailed_results)
# 保存完整的评估结果
results_df.to_csv('detailed_metrics_results.csv', index=False)
# 创建一个更易读的总结表格
summary_df = pd.DataFrame({
'距离度量': results_df['距离度量'],
'交叉验证准确率': results_df.apply(lambda x: f"{x['交叉验证准确率']:.4f} (±{x['交叉验证准确率标准差']:.4f})", axis=1),
'准确率置信区间': results_df.apply(lambda x: f"[{x['准确率置信区间下限']:.4f}, {x['准确率置信区间上限']:.4f}]", axis=1),
'测试集准确率': results_df['测试集准确率'].apply(lambda x: f"{x:.4f}"),
'总耗时(秒)': results_df['总耗时(秒)'].apply(lambda x: f"{x:.4f}")
})
# 保存总结表格
summary_df.to_csv('metrics_summary.csv', index=False)
return results_df, summary_df
def plot_cv_results(results_df, summary_df):
"""可视化交叉验证结果"""
# 1. 绘制平均性能对比图
plt.figure(figsize=(15, 8))
x = np.arange(len(results_df))
width = 0.2
# 设置y轴范围
plt.ylim(0, 1.0)
plt.bar(x - width*1.5, results_df['交叉验证准确率'], width,
yerr=results_df['交叉验证准确率标准差'],
label='准确率', color='#2ecc71', capsize=5)
plt.bar(x - width/2, results_df['交叉验证精确率'], width,
label='精确率', color='#e74c3c', capsize=5)
plt.bar(x + width/2, results_df['交叉验证召回率'], width,
label='召回率', color='#3498db', capsize=5)
plt.bar(x + width*1.5, results_df['交叉验证F1分数'], width,
label='F1分数', color='#f1c40f', capsize=5)
plt.title('不同距离度量方式的交叉验证性能比较', fontsize=14)
plt.xlabel('距离度量方式', fontsize=12)
plt.ylabel('性能指标值', fontsize=12)
plt.xticks(x, results_df['距离度量'], rotation=45)
# 设置y轴刻度为百分比格式
plt.yticks(np.arange(0, 1.1, 0.1), [f'{x:.0%}' for x in np.arange(0, 1.1, 0.1)])
plt.legend(prop={'family': 'SimHei'})
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('cv_performance_comparison.png', bbox_inches='tight', dpi=300)
plt.close()
def main():
# 加载和预处理数据
X_train, X_test, y_train, y_test, scaler = load_and_preprocess_data()
# 进行交叉验证
print("\n开始交叉验证评估...")
results_df, summary_df = perform_cross_validation(X_train, X_test, y_train, y_test)
# 可视化结果
print("\n生成可视化结果...")
plot_cv_results(results_df, summary_df)
# 打印总结表格
print("\n模型性能总结:")
print(summary_df.to_string(index=False))
# 找出最佳距离度量方式
best_model_info = results_df.loc[results_df['交叉验证准确率'].idxmax()]
best_metric = best_model_info['距离度量']
print("\n最佳模型信息:")
print(f"距离度量方式: {best_metric}")
print(f"交叉验证准确率: {best_model_info['交叉验证准确率']:.4f} (±{best_model_info['交叉验证准确率标准差']:.4f})")
print(f"测试集准确率: {best_model_info['测试集准确率']:.4f}")
# 使用最佳距离度量方式训练最终模型
print(f"\n使用最佳距离度量方式({best_metric})训练最终模型...")
metric_params = {
'欧氏距离': {'metric': 'euclidean'},
'曼哈顿距离': {'metric': 'manhattan'},
'切比雪夫距离': {'metric': 'chebyshev'},
'Minkowski (p=3)': {'metric': 'minkowski', 'p': 3},
}
final_model = KNeighborsClassifier(n_neighbors=5, **metric_params[best_metric])
final_model.fit(X_train, y_train)
final_accuracy = final_model.score(X_test, y_test)
# 分析每个数字的性能并生成图表
y_pred = final_model.predict(X_test)
precision, recall, f1, support = precision_recall_fscore_support(y_test, y_pred)
# 创建性能指标数据框
performance_df = pd.DataFrame({
'精确率': precision,
'召回率': recall,
'F1分数': f1,
'样本数量': support
}, index=range(10))
# 绘制每个数字的性能指标
plt.figure(figsize=(15, 6))
performance_df[['精确率', '召回率', 'F1分数']].plot(kind='bar', width=0.8)
plt.title(f'每个数字的识别性能 (使用{best_metric})', fontsize=14)
plt.xlabel('数字', fontsize=12)
plt.ylabel('性能指标', fontsize=12)
plt.legend(prop={'family': 'SimHei'})
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('digit_performance.png', bbox_inches='tight', dpi=300)
plt.close()
# 保存性能数据
performance_df.to_csv('digit_performance.csv', index=True)
print(f"最终模型测试集准确率: {final_accuracy:.4f}")
print("\n每个数字的识别性能已保存到 digit_performance.png 和 digit_performance.csv")
# 保存最终模型和数据处理器
joblib.dump(final_model, 'best_knn_model.pkl')
joblib.dump(scaler, 'scaler.pkl')
print("\n最终模型和数据处理器已保存")
if __name__ == "__main__":
main()
测试代码
import numpy as np
import matplotlib
matplotlib.use('Agg')
matplotlib.rcParams['font.sans-serif'] = ['SimHei']
matplotlib.rcParams['axes.unicode_minus'] = False
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
import joblib
import glob
def load_sample_digits():
"""从数据集中每个数字加载一个样本"""
digits = load_digits()
X = digits.data
y = digits.target
# 为每个数字(0-9)选择一个样本
sample_images = []
sample_data = []
sample_labels = []
np.random.seed(42) # 设置随机种子以确保结果可重现
for digit in range(10):
# 找到当前数字的所有索引
indices = np.where(y == digit)[0]
# 随机选择一个索引
sample_idx = np.random.choice(indices)
# 保存图像数据
sample_images.append(digits.images[sample_idx])
# 保存特征数据
sample_data.append(X[sample_idx])
# 保存真实标签
sample_labels.append(y[sample_idx])
return np.array(sample_images), np.array(sample_data), np.array(sample_labels)
def predict_and_visualize(models, images, data, true_labels):
"""使用所有模型预测并可视化结果"""
n_models = len(models)
n_samples = len(images)
# 创建图表,增加左边的空间用于标注
fig = plt.figure(figsize=(22, 4*n_models))
gs = fig.add_gridspec(n_models, n_samples + 1, width_ratios=[1] + [3]*n_samples)
# 添加总标题
fig.suptitle('不同距离度量方式的手写数字识别结果', fontsize=16, y=0.95)
# 定义距离度量方式的中文名称映射
distance_names = {
'Euclidean': '欧氏距离',
'Manhattan': '曼哈顿距离',
'Chebyshev': '切比雪夫距离',
'Minkowski': 'Minkowski (p=3)'
}
# 创建自定义图例
from matplotlib.patches import Patch
legend_elements = [
Patch(facecolor='none', edgecolor='green', label='预测正确'),
Patch(facecolor='none', edgecolor='red', label='预测错误')
]
fig.legend(handles=legend_elements,
loc='upper right',
bbox_to_anchor=(0.98, 0.98),
prop={'family': 'SimHei', 'size': 12})
# 为每个模型进行预测和可视化
for i, (model_name, model) in enumerate(models.items()):
# 添加距离标注
ax_text = fig.add_subplot(gs[i, 0])
ax_text.axis('off')
# 获取对应的中文名称
for name_en, name_zh in distance_names.items():
if name_en.lower() in model_name.lower():
display_name = name_zh
break
else:
display_name = model_name
ax_text.text(0.5, 0.5, f'第{i+1}行:{display_name}',
ha='center', va='center',
fontsize=12, fontfamily='SimHei',
bbox=dict(facecolor='white', edgecolor='black', pad=5))
# 进行预测
predictions = model.predict(data)
# 显示每个数字及其预测结果
for j, (image, pred, true) in enumerate(zip(images, predictions, true_labels)):
ax = fig.add_subplot(gs[i, j+1])
ax.imshow(image, cmap='gray')
ax.axis('off')
# 根据预测是否正确设置标题颜色和边框颜色
color = 'green' if pred == true else 'red'
ax.set_title(f'预测: {pred}\n真实: {true}',
color=color, pad=10)
# 添加带颜色的边框
for spine in ax.spines.values():
spine.set_edgecolor(color)
spine.set_linewidth(2)
spine.set_visible(True)
plt.tight_layout()
# 调整子图之间的间距,为图例留出空间
plt.subplots_adjust(right=0.95)
plt.savefig('prediction_results.png', bbox_inches='tight', dpi=300)
plt.close()
def main():
try:
# 加载数据处理器
print("正在加载数据处理器...")
scaler = joblib.load('scaler.pkl')
# 加载所有模型
print("正在加载模型...")
models = {}
model_files = glob.glob('knn_model_*.pkl')
if not model_files:
raise FileNotFoundError("未找到任何模型文件!")
for model_file in model_files:
# 从文件名提取模型名称
model_name = model_file.replace('knn_model_', '').replace('.pkl', '')
model_name = model_name.replace('_', ' ').title()
models[model_name] = joblib.load(model_file)
# 加载样本数据
print("正在准备测试样本...")
sample_images, sample_data, true_labels = load_sample_digits()
# 数据标准化
sample_data = scaler.transform(sample_data)
# 预测并可视化
print("正在进行预测和可视化...")
predict_and_visualize(models, sample_images, sample_data, true_labels)
print("预测完成!结果已保存为 'prediction_results.png'")
# 计算并显示每个模型的准确率
print("\n各模型在测试样本上的表现:")
distance_names = {
'Euclidean': '欧氏距离',
'Manhattan': '曼哈顿距离',
'Chebyshev': '切比雪夫距离',
'Minkowski': 'Minkowski (p=3)'
}
for name, model in models.items():
predictions = model.predict(sample_data)
correct = np.sum(predictions == true_labels)
# 获取对应的中文名称
for name_en, name_zh in distance_names.items():
if name_en.lower() in name.lower():
display_name = name_zh
break
else:
display_name = name
print(f"{display_name}:")
print(f" 正确预测: {correct}/10")
print(f" 准确率: {correct/10:.2%}")
except FileNotFoundError as e:
print(f"错误:{str(e)}")
print("请先运行训练程序!")
except Exception as e:
print(f"发生错误:{str(e)}")
if __name__ == "__main__":
main()
不同度量方法比较
import matplotlib
matplotlib.use('Agg') # 使用Agg后端
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei'] # 设置中文字体
plt.rcParams['axes.unicode_minus'] = False # 正确显示负号
import numpy as np
import pandas as pd
def load_results():
"""加载所有方法的结果"""
try:
# 加载KNN的结果
knn_results = pd.read_csv('detailed_metrics_results.csv')
# 获取最佳距离度量方式的结果
best_metric = knn_results.loc[knn_results['交叉验证准确率'].idxmax()]
knn_metrics = {
'方法': 'KNN (最佳距离)',
'准确率': best_metric['交叉验证准确率'],
'精确率': best_metric['交叉验证精确率'],
'召回率': best_metric['交叉验证召回率'],
'F1分数': best_metric['交叉验证F1分数'],
'训练时间(秒)': best_metric['总耗时(秒)'],
'距离度量': best_metric['距离度量']
}
# 加载SVM的结果
svm_results = pd.read_csv('svm_detailed_performance.csv')
svm_avg = svm_results[svm_results['数字'] == '平均值'].iloc[0]
svm_std = svm_results[svm_results['数字'] == '标准差'].iloc[0]
svm_metrics = {
'方法': 'SVM',
'准确率': float(svm_avg['准确率']),
'精确率': float(svm_avg['精确率']),
'召回率': float(svm_avg['召回率']),
'F1分数': float(svm_avg['F1分数']),
'训练时间(秒)': float(svm_avg['训练时间(秒)'])
}
# ���载逻辑回归的结果
logistic_results = pd.read_csv('logistic_detailed_performance.csv')
logistic_avg = logistic_results[logistic_results['数字'] == '平均值'].iloc[0]
logistic_std = logistic_results[logistic_results['数字'] == '标准差'].iloc[0]
logistic_metrics = {
'方法': '逻辑回归',
'准确率': float(logistic_avg['准确率']),
'精确率': float(logistic_avg['精确率']),
'召回率': float(logistic_avg['召回率']),
'F1分数': float(logistic_avg['F1分数']),
'训练时间(秒)': float(logistic_avg['训练时间(秒)'])
}
# 加载CNN的结果
cnn_results = pd.read_csv('cnn_results.csv')
cnn_metrics = {
'方法': 'CNN',
'准确率': float(cnn_results['准确率'].iloc[0]),
'精确率': float(cnn_results['精确率'].iloc[0]),
'召回率': float(cnn_results['召回率'].iloc[0]),
'F1分数': float(cnn_results['F1分数'].iloc[0]),
'训练时间(秒)': float(cnn_results['训练时间(秒)'].iloc[0])
}
# 合并所有结果
comparison_df = pd.DataFrame([knn_metrics, svm_metrics, logistic_metrics, cnn_metrics])
return comparison_df, knn_metrics['距离度量']
except Exception as e:
print(f"加载结果时出错: {str(e)}")
return None, None
def plot_comparison(comparison_df):
"""绘制性能对比图"""
plt.figure(figsize=(15, 8))
x = np.arange(len(comparison_df))
width = 0.15
# 设置y轴范围
plt.ylim(0, 1.0)
# 绘制性能指标
bars1 = plt.bar(x - width*1.5, comparison_df['准确率'], width, label='准确率', color='#2ecc71')
bars2 = plt.bar(x - width/2, comparison_df['精确率'], width, label='精确率', color='#e74c3c')
bars3 = plt.bar(x + width/2, comparison_df['召回率'], width, label='召回率', color='#3498db')
bars4 = plt.bar(x + width*1.5, comparison_df['F1分数'], width, label='F1分数', color='#f1c40f')
# 添加数值标签
def add_labels(bars):
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, height,
f'{height:.3f}',
ha='center', va='bottom', rotation=90)
add_labels(bars1)
add_labels(bars2)
add_labels(bars3)
add_labels(bars4)
plt.title('不同方法的性能比较', fontsize=14)
plt.xlabel('方法', fontsize=12)
plt.ylabel('性能指标值', fontsize=12)
plt.xticks(x, comparison_df['方法'])
# 设置y轴刻度为百分比格式
plt.yticks(np.arange(0, 1.1, 0.1), [f'{x:.0%}' for x in np.arange(0, 1.1, 0.1)])
plt.legend(prop={'family': 'SimHei'})
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('methods_comparison.png', bbox_inches='tight', dpi=300)
plt.close()
def plot_training_time_comparison(comparison_df):
"""绘制训练时间对比图"""
plt.figure(figsize=(10, 6))
bars = plt.bar(comparison_df['方法'], comparison_df['训练时间(秒)'], color='#3498db')
plt.title('不同方法的训练时间比较', fontsize=14)
plt.xlabel('方法', fontsize=12)
plt.ylabel('训练时间(秒)', fontsize=12)
plt.grid(True, alpha=0.3)
# 添加具体数值标签
for bar in bars:
height = bar.get_height()
plt.text(bar.get_x() + bar.get_width()/2, height,
f'{height:.2f}s',
ha='center', va='bottom')
plt.tight_layout()
plt.savefig('methods_training_time.png', bbox_inches='tight', dpi=300)
plt.close()
def main():
# 加载结果
comparison_df, best_knn_metric = load_results()
if comparison_df is None:
return
# 创建详细的比较表格
detailed_comparison = []
# 加载标准差信息
try:
svm_results = pd.read_csv('svm_detailed_performance.csv')
logistic_results = pd.read_csv('logistic_detailed_performance.csv')
knn_results = pd.read_csv('detailed_metrics_results.csv')
# 获取KNN的标准差
best_knn_idx = knn_results['交叉验证准确率'].idxmax()
knn_std = {
'准确率': knn_results.loc[best_knn_idx, '交叉验证准确率标准差'],
'精确率': knn_results.loc[best_knn_idx, '交叉验证精确率标准差'],
'召回率': knn_results.loc[best_knn_idx, '交叉验证召回率标准差'],
'F1分数': knn_results.loc[best_knn_idx, '交叉验证F1分数标准差']
}
# 获取SVM的标准差
svm_std = svm_results[svm_results['数字'] == '标准差'].iloc[0]
# 获取逻辑回归的标准差
logistic_std = logistic_results[logistic_results['数字'] == '标准差'].iloc[0]
# 为每个方法创建详细记录
for _, row in comparison_df.iterrows():
method_name = row['方法']
if method_name == 'KNN (最佳距离)':
method_name = f"KNN ({best_knn_metric})"
std_values = knn_std
elif method_name == 'SVM':
std_values = {k: float(svm_std[k].replace('±', '')) for k in ['准确率', '精确率', '召回率', 'F1分数']}
elif method_name == '逻辑回归':
std_values = {k: float(logistic_std[k].replace('±', '')) for k in ['准确率', '精确率', '召回率', 'F1分数']}
else: # CNN
# CNN没有标准差信息,使用0
std_values = {k: 0.0 for k in ['准确率', '精确率', '召回率', 'F1分数']}
detailed_comparison.append({
'方法': method_name,
'准确率': f"{row['准确率']:.4f} (±{std_values['准确率']:.4f})",
'精确率': f"{row['精确率']:.4f} (±{std_values['精确率']:.4f})",
'召回率': f"{row['召回率']:.4f} (±{std_values['召回率']:.4f})",
'F1分数': f"{row['F1分数']:.4f} (±{std_values['F1分数']:.4f})",
'训练时间(秒)': f"{row['训练时间(秒)']:.4f}"
})
# 创建一个更美观的表格
comparison_table = pd.DataFrame(detailed_comparison)
# 保存为Excel文件以保持格式
comparison_table.to_excel('methods_comparison_detailed.xlsx', index=False)
# 同时保存为CSV
comparison_table.to_csv('methods_comparison_detailed.csv', index=False)
# 打印比较结果
print("\n不同方法的性能比较:")
print(comparison_table.to_string(index=False))
# 找出最佳方法
best_method = comparison_df.loc[comparison_df['准确率'].idxmax()]
print(f"\n最佳方法: {best_method['方法']}")
print(f"准确率: {best_method['准确率']:.4f}")
print(f"精确率: {best_method['精确率']:.4f}")
print(f"召回率: {best_method['召回率']:.4f}")
print(f"F1分数: {best_method['F1分数']:.4f}")
print(f"训练时间: {best_method['训练时间(秒)']:.4f}秒")
print("\n详细比较结果已保存到 'methods_comparison_detailed.xlsx' 和 'methods_comparison_detailed.csv'")
except Exception as e:
print(f"创建详细比较表格时出错: {str(e)}")
# 绘制比较图
plot_comparison(comparison_df)
plot_training_time_comparison(comparison_df)
if __name__ == "__main__":
main()
总结
本文从理论到实践,系统性地探讨了 K 最近邻(KNN)算法的基本原理、超参数调节方法,以及其在手写数字分类任务中的具体应用。通过分析实验结果,我们发现 KNN 在小型数据集上的表现相对稳健,尤其是在低维空间下,其直观的设计和良好的分类性能使其成为许多任务中的一种有效选择。此外,通过不同超参数的调节(如 K 值的选择和距离度量方法的调整)以及优化策略(如标准化处理、交叉验证)可以显著提升模型性能。
然而,KNN 也有其局限性,比如计算复杂度高,对高维数据不够友好,以及对噪声数据的敏感性等。结合实验中的错误案例,我们进一步提出了可能的改进方向,包括使用 KD 树或球树加速最近邻搜索、结合降维技术减少计算复杂度,以及引入数据增强提升模型对噪声的鲁棒性。未来,我们还可以将 KNN 应用到更复杂的数据集(如 CIFAR-10)或任务中,进一步探索其实际价值和适用范围。
作为一名算法学习的初学者,在撰写这篇文章的过程中,我深感 KNN 算法的应用广泛和背后数学原理的深刻,同时也意识到自己的许多不足。若文中存在错误或不严谨之处,恳请大家批评指正,我也期待和大家共同讨论学习,感谢您的耐心阅读!
标签:KNN,plt,入门,df,results,距离,准确率,print,手写 From: https://blog.csdn.net/m0_74882984/article/details/144261544