1 朴素贝叶斯介绍
朴素贝叶斯(Naive Bayes)分类器是基于贝叶斯定理的一种简单概率分类器。它假设各特征之间相互独立,这一假设被称为“朴素”的假设。朴素贝叶斯分类器广泛应用于文本分类、垃圾邮件检测等领域。
2 公式
朴素贝叶斯分类器的核心公式是贝叶斯定理:
其中:
-
是给定特征 时,样本属于类别 的后验概率。
-
是类别 的先验概率。
-
是给定类别 时,特征 的条件概率。
-
是所有特征的联合概率,通常可以视为一个常数,因为在计算类别概率时会相互抵消。
3 公式推导
由于 对于所有类别都是相同的,我们可以忽略它,只关注分子部分:
在朴素贝叶斯分类器中,我们通常使用极大似然估计来计算先验概率和条件概率:
其中:
-
是类别 的样本数量。
-
是特征 在类别 中出现的次数
-
是总样本数量。
4 案例实现
1. 创建并训练模型,预测测试集:
加载鸢尾花数据集
,只取前两个特征以便于可视化,并用这些特征训练朴素贝叶斯模型。
2. 创建并训练模型,预测测试集:
创建一个高斯朴素贝叶斯模型,并使用训练集数据进行训练。
3. 评估模型:
计算模型的准确率,并打印结果。
4. 绘制混淆绘制:
使用 seaborn
的 heatmap
函数绘制混淆矩阵,以直观地展示模型的分类效果。
5. 绘制分类结果图:
-
创建网格以绘制决策边界。
-
预测网格中每个点的分类结果。
-
绘制决策边界和分类点,包括训练集和测试集的数据点。
5 完整代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.metrics import accuracy_score, confusion_matrix
import seaborn as sns
# 1. 加载数据集
iris = load_iris()
X = iris.data[:, :2] # 只取前两个特征以便于可视化
y = iris.target
# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 3. 创建并训练模型
model = GaussianNB()
model.fit(X_train, y_train)
# 4. 预测测试集
y_pred = model.predict(X_test)
# 5. 评估模型
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
# 6. 绘制混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=iris.target_names, yticklabels=iris.target_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
# 7. 绘制分类结果图
# 创建网格以绘制决策边界
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
np.arange(y_min, y_max, 0.02))
# 预测网格中每个点的分类结果
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
# 绘制决策边界和分类点
plt.figure(figsize=(12, 8))
plt.contourf(xx, yy, Z, alpha=0.4)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, edgecolor='k', s=20, label='Training data')
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, edgecolor='k', s=20, marker='x', label='Test data')
plt.xlabel('Sepal length')
plt.ylabel('Sepal width')
plt.title('Classification Result')
plt.legend()
plt.show()
运行结果:
我们可以直观地看到模型在二维特征空间中的分类效果,包括决策边界和分类点。这有助于我们理解模型的分类行为和性能。
-
准确率
-
绘制混淆矩阵
-
绘制分类结果图
标签:plt,贝叶斯,---,train,test,import,绘制,朴素 From: https://blog.csdn.net/qq_51749909/article/details/143248868文章持续跟新,可以微信搜一搜公众号 [ rain雨雨编程 ],第一时间阅读,涉及数据分析,机器学习,Java编程,爬虫,实战项目等。