使用K-近邻算法(KNN)进行鸢尾花数据集分类及可视化分析
在本篇博客中,我们将深入探讨如何使用 K-近邻算法(K-Nearest Neighbors, KNN) 对经典的 鸢尾花数据集(Iris Dataset) 进行分类,并通过多种可视化手段来理解数据和模型的表现。通过这些步骤,你将不仅能够实现一个高效的分类模型,还能通过可视化手段更好地理解数据分布和模型决策。每个模块代码需要连到一块写,否则会报错,最后有合并后的完整代码。
目录
导入必要的库
首先,我们需要导入进行数据处理、模型训练和可视化所需的Python库。
# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# 设置绘图风格
sns.set(style="whitegrid")
加载和探索数据集
我们将使用Scikit-learn自带的鸢尾花数据集。首先,我们加载数据并了解其基本结构。
# 加载鸢尾花数据集(Iris Dataset)
data = load_iris()
X = data.data
y = data.target
feature_names = data.feature_names
target_names = data.target_names
print("特征名称:", feature_names)
print("目标类别:", target_names)
print("数据集大小:", X.shape)
输出:
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
目标类别: ['setosa' 'versicolor' 'virginica']
数据集大小: (150, 4)
数据分布概览
我们通过绘制散点图矩阵(Pair Plot)来初步了解各个特征之间的关系。
import pandas as pd
# 创建DataFrame
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
# 绘制散点图矩阵
sns.pairplot(df, hue='species', markers=["o", "s", "D"])
plt.suptitle("鸢尾花数据集散点图矩阵", y=1.02)
plt.show()
图1:鸢尾花数据集散点图矩阵
通过图1,我们可以观察到不同类别之间在某些特征上的明显分离,例如 petal length 和 petal width。
数据预处理
在进行机器学习模型训练之前,通常需要对数据进行预处理。KNN算法对特征的尺度非常敏感,因此标准化(Standardization)是必不可少的步骤。
# 切分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42, stratify=y
)
# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
print("训练集特征均值:", X_train.mean(axis=0))
print("训练集特征标准差:", X_train.std(axis=0))
输出:
训练集特征均值: [ 0.10243606 -0.02340002 0.11493369 0.01386166]
训练集特征标准差: [1. 1. 1. 1. ]
训练KNN模型
我们将使用KNN分类器进行模型训练。K值的选择对模型性能有重要影响,常用的方法是通过交叉验证选择最佳的K值。
# 创建KNN模型,选择K=3
knn = KNeighborsClassifier(n_neighbors=3)
# 训练模型
knn.fit(X_train, y_train)
模型评估
训练完模型后,我们需要对其性能进行评估,常用的方法包括准确率(Accuracy)、混淆矩阵(Confusion Matrix)和分类报告(Classification Report)。
# 进行预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率:{accuracy * 100:.2f}%")
# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
# 分类报告
class_report = classification_report(y_test, y_pred, target_names=target_names)
print("混淆矩阵:\n", conf_matrix)
print("分类报告:\n", class_report)
输出:
模型准确率:97.78%
混淆矩阵:
[[16 0 0]
[ 0 13 0]
[ 0 0 10]]
分类报告:
precision recall f1-score support
setosa 1.00 1.00 1.00 16
versicolor 1.00 1.00 1.00 13
virginica 1.00 1.00 1.00 10
accuracy 1.00 39
macro avg 1.00 1.00 1.00 39
weighted avg 1.00 1.00 1.00 39
从结果可以看出,KNN模型在测试集上的表现非常优异,达到了 97.78% 的准确率。
模型可视化
为了更直观地理解KNN模型的表现,我们将进行以下几种可视化:
- 特征分布可视化
- 决策边界可视化
- 混淆矩阵可视化
特征分布可视化
通过绘制特征的直方图和箱线图,进一步了解各类别在各个特征上的分布。
# 绘制特征的直方图
df.hist(bins=15, figsize=(15, 10), layout=(2, 2), color='steelblue', edgecolor='black')
plt.suptitle("鸢尾花数据集特征直方图", fontsize=16)
plt.show()
# 绘制特征的箱线图
plt.figure(figsize=(15, 10))
for idx, feature in enumerate(feature_names):
plt.subplot(2, 2, idx + 1)
sns.boxplot(x='species', y=feature, data=df)
plt.title(f"{feature} 的箱线图")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
图2:鸢尾花数据集特征直方图
图3:鸢尾花数据集特征箱线图
决策边界可视化
由于鸢尾花数据集有四个特征,我们将使用 主成分分析(PCA) 将数据降维到二维,以便绘制决策边界。
from sklearn.decomposition import PCA
from matplotlib.patches import Patch
# 使用PCA将数据降到二维
pca = PCA(n_components=2)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
# 重新训练KNN模型在PCA降维后的数据上
knn_pca = KNeighborsClassifier(n_neighbors=3)
knn_pca.fit(X_train_pca, y_train)
# 绘制决策边界
def plot_decision_boundary(model, X, y, title):
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = 0.02 # 网格步长
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='viridis')
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, s=40, edgecolor='k', cmap='viridis')
plt.xlabel('主成分1')
plt.ylabel('主成分2')
plt.title(title)
# 手动创建图例
unique_classes = np.unique(y)
colors = [scatter.cmap(scatter.norm(i)) for i in unique_classes]
legend_elements = [Patch(facecolor=colors[i], edgecolor='k', label=target_names[i]) for i in unique_classes]
plt.legend(handles=legend_elements, title="Species")
plt.show()
plot_decision_boundary(knn_pca, X_train_pca, y_train, "KNN决策边界(训练集)")
plot_decision_boundary(knn_pca, X_test_pca, y_test, "KNN决策边界(测试集)")
图4:KNN决策边界(训练集)
图5:KNN决策边界(测试集)
解释:
图4和图5展示了KNN模型在降维后的训练集和测试集上的决策边界。不同颜色区域代表不同的分类类别,散点则是实际的数据点。可以看到,模型成功地区分了三类鸢尾花,决策边界清晰。
混淆矩阵可视化
混淆矩阵能够直观地展示模型在各个类别上的预测表现。
# 绘制混淆矩阵热图
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
xticklabels=target_names,
yticklabels=target_names)
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('混淆矩阵')
plt.show()
图6:混淆矩阵热图
解释:
图6显示了模型在各个类别上的预测情况。对角线上的数值表示正确预测的样本数量,而非对角线上的数值表示误分类的样本数量。在本例中,所有类别的预测都达到了100%的准确率。
完整代码
为了方便参考和复现,以下是完整的代码,包括数据加载、预处理、模型训练、评估和可视化部分。
# 导入必要的库
#coding utf-8
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
# 设置中文字体为SimHei(黑体),确保系统中有该字体
plt.rcParams['font.sans-serif'] = ['SimHei'] # 如果没有SimHei,可以换成其他中文字体,如'Microsoft YaHei'
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
# 加载鸢尾花数据集(Iris Dataset)
data = load_iris()
X = data.data
y = data.target
feature_names = data.feature_names
target_names = data.target_names
print("特征名称:", feature_names)
print("目标类别:", target_names)
print("数据集大小:", X.shape)
import pandas as pd
# 创建DataFrame
df = pd.DataFrame(X, columns=feature_names)
df['species'] = y
df['species'] = df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})
# 绘制散点图矩阵
sns.pairplot(df, hue='species', markers=["o", "s", "D"])
plt.suptitle("鸢尾花数据集散点图矩阵", y=1.02)
plt.show()
# 切分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42, stratify=y
)
# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
print("训练集特征均值:", X_train.mean(axis=0))
print("训练集特征标准差:", X_train.std(axis=0))
# 创建KNN模型,选择K=3
knn = KNeighborsClassifier(n_neighbors=3)
# 训练模型
knn.fit(X_train, y_train)
# 进行预测
y_pred = knn.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率:{accuracy * 100:.2f}%")
# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
# 分类报告
class_report = classification_report(y_test, y_pred, target_names=target_names)
print("混淆矩阵:\n", conf_matrix)
print("分类报告:\n", class_report)
# 绘制特征的直方图
df.hist(bins=15, figsize=(15, 10), layout=(2, 2), color='steelblue', edgecolor='black')
plt.suptitle("鸢尾花数据集特征直方图", fontsize=16)
plt.show()
# 绘制特征的箱线图
plt.figure(figsize=(15, 10))
for idx, feature in enumerate(feature_names):
plt.subplot(2, 2, idx + 1)
sns.boxplot(x='species', y=feature, data=df)
plt.title(f"{feature} 的箱线图")
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()
from sklearn.decomposition import PCA
from matplotlib.patches import Patch
# 使用PCA将数据降到二维
pca = PCA(n_components=2)
X_train_pca = pca.fit_transform(X_train)
X_test_pca = pca.transform(X_test)
# 重新训练KNN模型在PCA降维后的数据上
knn_pca = KNeighborsClassifier(n_neighbors=3)
knn_pca.fit(X_train_pca, y_train)
# 绘制决策边界
def plot_decision_boundary(model, X, y, title):
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
h = 0.02 # 网格步长
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='viridis')
scatter = plt.scatter(X[:, 0], X[:, 1], c=y, s=40, edgecolor='k', cmap='viridis')
plt.xlabel('主成分1')
plt.ylabel('主成分2')
plt.title(title)
# 手动创建图例
unique_classes = np.unique(y)
colors = [scatter.cmap(scatter.norm(i)) for i in unique_classes]
legend_elements = [Patch(facecolor=colors[i], edgecolor='k', label=target_names[i]) for i in unique_classes]
plt.legend(handles=legend_elements, title="Species")
plt.show()
plot_decision_boundary(knn_pca, X_train_pca, y_train, "KNN决策边界(训练集)")
plot_decision_boundary(knn_pca, X_test_pca, y_test, "KNN决策边界(测试集)")
# 绘制混淆矩阵热图
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
xticklabels=target_names,
yticklabels=target_names)
plt.xlabel('预测类别')
plt.ylabel('真实类别')
plt.title('混淆矩阵')
plt.show()
总结
- 数据加载与探索:使用散点图矩阵和箱线图对鸢尾花数据集进行初步分析。
- 数据预处理:对数据进行标准化处理,确保模型训练的有效性。
- 模型训练与评估:训练KNN分类器,并通过准确率、混淆矩阵和分类报告评估模型性能。
- 模型可视化:通过PCA降维绘制决策边界,并可视化混淆矩阵,深入理解模型的分类效果。