首页 > 其他分享 >【机器学习】利用逻辑回归对iris鸢尾花数据集进行分类

【机器学习】利用逻辑回归对iris鸢尾花数据集进行分类

时间:2024-11-21 16:45:41浏览次数:3  
标签:iris 逻辑 plt 分类 类别 train test import 鸢尾花

目标

本文旨在通过实现一个基础的逻辑回归分类模型,了解并应用逻辑回归模型,完成从数据加载、预处理到训练与评估的整个流程。通过使用Scikit-learn的逻辑回归模型,掌握如何进行模型训练与预测。学会评估模型性能,理解准确率、混淆矩阵及分类报告的含义。掌握混淆矩阵的可视化技术,通过图形化呈现分类结果,帮助分析模型性能。

环境

Python编程语言

Scikit-learn库

Matplotlib(用于数据可视化)

NumPy和Pandas库(用于数据处理)

Jupyter Notebook或类似IDE(用于代码编写和结果展示)

数据集

本实验使用的是鸢尾花数据集(Iris dataset),它是一个经典的多分类数据集,包含150个样本,4个特征(花萼长度、花萼宽度、花瓣长度、花瓣宽度),以及3个目标类别(Setosa、Versicolor、Virginica)。为了简化实验并将问题转化为二分类问题,我们将类别2(Virginica)标记为0,而类别0(Setosa)和类别1(Versicolor)标记为1。

步骤

1. 数据加载与初步探索

2. 数据集划分与标准化

3. 模型训练与预测

4. 模型性能评估

5. 可视化混淆矩阵

代码示例

引入实验用的包

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

加载鸢尾花数据集。

iris = load_iris()
X = iris.data
y = iris.target

将目标类别从三分类问题转换为二分类问题。

# 原始类别
# 0: Setosa
# 1: Versicolor
# 2: Virginica
# 将三分类问题转化为二分类问题:
# 定义规则:类别 2(Virginica) 转化为 0,其余类别 0 和 1 转化为 1
y_binary = np.where(y == 2, 0, 1)
# 查看二分类后的类别分布
print("原始类别分布:", np.bincount(y))
print("二分类后的类别分布:", np.bincount(y_binary))

图1

划分训练集测试集

使用 train_test_split 方法将数据划分为训练集(80%)和测试集(20%)。

# 数据集划分

X_train, X_test, y_train, y_test = train_test_split(X, y_binary, test_size=0.2, random_state=42, stratify=y_binary)

数据标准化

使用 StandardScaler 对数值型特征使用进行标准化,以确保每个特征具有相同的尺度,避免不同尺度的特征对模型训练产生影响。

# 数据标准化

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

初始化逻辑回归模型

初始化LogisticRegression,使用模型进行训练,基于训练集数据进行拟合。

model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

预测

使用训练好的模型对测试集进行预测,得到预测结果。

# 测试集预测

y_pred = model.predict(X_test)
y_pred

图2

计算模型准确率

使用 accuracy_score 方法评估模型在测试集上的准确率。

# 评估准确率

accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

图3

计算混淆矩阵

通过 confusion_matrix 查看模型的混淆矩阵,以了解分类情况。

# 混淆矩阵

conf_matrix = confusion_matrix(y_test, y_pred)
print("混淆矩阵:\n", conf_matrix)

图4

输出分类报告

使用 classification_report 获得模型的精确率、召回率和 F1 分数等详细指标。

# 分类报告

class_report = classification_report(y_test, y_pred)
print("分类报告:\n", class_report)

图5

可视化混淆矩阵

使用Matplotlib可视化混淆矩阵,帮助直观分析模型的分类效果。

plt.figure(figsize=(8, 6))  # 调整图形尺寸
sns.set(font_scale=1.2)  # 调整字体大小
sns.heatmap(
    conf_matrix,
    annot=True,
    cmap="BuGn",  # 使用蓝绿色配色方案
    fmt="d",
    cbar=True,  # 显示颜色条
    annot_kws={"size": 14, "weight": "bold"},  # 注释字体大小和加粗
    xticklabels=["Class 0", "Class 1"],
    yticklabels=["Class 0", "Class 1"],
    linewidths=1.5,  # 增加单元格边框
    linecolor="gray"  # 边框颜色为灰色
)

# 图表标题和轴标签

plt.title("Confusion Matrix", fontsize=18, weight="bold", pad=20, color="teal")  # 标题加粗并增加间距,颜色为青绿色
plt.xlabel("Predicted Labels", fontsize=14, labelpad=10, color="darkblue")  # 调整标签颜色
plt.ylabel("True Labels", fontsize=14, labelpad=10, color="darkblue")

# 旋转 x 轴标签
plt.xticks(rotation=45, ha="right", fontsize=12, color="darkgreen")
plt.yticks(fontsize=12, color="darkgreen")

# 显示图形
plt.tight_layout()  # 调整布局避免溢出
plt.show()


图6

完整代码

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# 加载鸢尾花数据集
iris = load_iris()
X = iris.data
y = iris.target

# 原始类别
# 0: Setosa
# 1: Versicolor
# 2: Virginica

# 将三分类问题转化为二分类问题:
# 定义规则:类别 2(Virginica) 转化为 0,其余类别 0 和 1 转化为 1
y_binary = np.where(y == 2, 0, 1)

# 查看二分类后的类别分布
print("数据集形状:", X.shape)
print("原始类别分布:", np.bincount(y))
print("二分类后的类别分布:", np.bincount(y_binary))

# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y_binary, test_size=0.2, random_state=42, stratify=y_binary)

# 数据标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 初始化并训练逻辑回归模型
model = LogisticRegression(random_state=42)
model.fit(X_train, y_train)

# 测试集预测
y_pred = model.predict(X_test)
y_pred

# 评估准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# 混淆矩阵
conf_matrix = confusion_matrix(y_test, y_pred)
print("混淆矩阵:\n", conf_matrix)

# 分类报告
class_report = classification_report(y_test, y_pred)
print("分类报告:\n", class_report)

plt.figure(figsize=(8, 6))  # 调整图形尺寸
sns.set(font_scale=1.2)  # 调整字体大小
sns.heatmap(
    conf_matrix,
    annot=True,
    cmap="BuGn",  # 使用蓝绿色配色方案
    fmt="d",
    cbar=True,  # 显示颜色条
    annot_kws={"size": 14, "weight": "bold"},  # 注释字体大小和加粗
    xticklabels=["Class 0", "Class 1"],
    yticklabels=["Class 0", "Class 1"],
    linewidths=1.5,  # 增加单元格边框
    linecolor="gray"  # 边框颜色为灰色
)

# 图表标题和轴标签
plt.title("Confusion Matrix", fontsize=18, weight="bold", pad=20, color="teal")  # 标题加粗并增加间距,颜色为青绿色
plt.xlabel("Predicted Labels", fontsize=14, labelpad=10, color="darkblue")  # 调整标签颜色
plt.ylabel("True Labels", fontsize=14, labelpad=10, color="darkblue")

# 旋转 x 轴标签
plt.xticks(rotation=45, ha="right", fontsize=12, color="darkgreen")
plt.yticks(fontsize=12, color="darkgreen")

# 显示图形
plt.tight_layout()  # 调整布局避免溢出
plt.show()

标签:iris,逻辑,plt,分类,类别,train,test,import,鸢尾花
From: https://blog.csdn.net/fukase_mio/article/details/143919761

相关文章

  • 实验二:逻辑回归算法实现与测试
    实验二:逻辑回归算法实现与测试 一、实验目的深入理解对数几率回归(即逻辑回归的)的算法原理,能够使用Python语言实现对数几率回归的训练与测试,并且使用五折交叉验证算法进行模型训练与评估。 二、实验内容(1)从scikit-learn库中加载iris数据集,使用留出法留出1/3的样......
  • 【C#】【winforms】MVP架构中从 Model 或 View 层主动向 Presenter 传递数据或调用处
    背景使用winforms做上位机软件,软件功能简单来说就是与串口通信。因为一个软件要应用于不同型号的下位机,采用MVP架构提高代码复用性。 其中Model层中实例化SerialPort对象:privateSerialPort_serialPort;只关注串口收发。 presenter层负责主要业务逻辑。view层负责......
  • SAP付款日期计算逻辑
    1.根据采购订单找付款条件→EKKO-ZTERM2.根据付款条件找付款条件的配置→T052-ZTERM,基准日期计算方式缺省值对应T052日期类型B凭证日期,D记账日期,空没有默认值,C输入日期。检查付款条件的基准日期计算,优先按照固定日期计算。BSIK表未清-ZFBDT(付款起算日期,基准日期)。到期日......
  • 一件事有A点和B点, 并且有路径能从A点准确到达B点, 这就是有逻辑。
    一件事有A点和B点,并且有路径能从A点准确到达B点,这就是有逻辑。这段话阐明了逻辑的一个基本概念,即从一个起点到达一个终点的过程如果是可行的、可预见的,并且能够遵循一定的规则或步骤,那么就可以认为这是一个有逻辑的过程。逻辑的基本含义在这句话中,“A点”和“B点”代表......
  • React+AntD文件上传并自定义上传逻辑
    上传组件DragClickUpload.tsximport{CloudUploadOutlined}from'@ant-design/icons';importtype{UploadProps}from'antd';import{message,Upload}from'antd';importReact,{useState}from'react';importaxiosfrom&......
  • 逻辑回归算法实现与测试
    逻辑回归算法实现与测试一、实验目的 深入理解对数几率回归(即逻辑回归的)的算法原理,能够使用Python语言实现对数几率回归的训练与测试,并且使用五折交叉验证算法进行模型训练与评估。 二、实验内容 (1)从scikit-learn库中加载iris数据集,使用留出法留出1/3的样本作......
  • #渗透测试#SRC漏洞挖掘#网络运维# 黑客脚本编写05之字符串运算符与逻辑运算
    免责声明本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停止本文章阅读。                            ......
  • Luogu P9869 NOIp2023 三值逻辑 题解 [ 绿 ] [ 带权并查集 ]
    三值逻辑:有点坑并且细节较繁琐,但有点板子的并查集。修改操作发现对于每个点,只有对他的最后一次操作才是有用的,所以记录下最终的祖先即可。然而这里并不能用并查集来实现,因为并查集它具有的是传递性,无论你路不路径压缩,每次修改一个父节点时它的子节点一定会被修改,所以我们不能使......
  • MySQL 逻辑备份与恢复指南
    MySQL逻辑备份与恢复指南引言逻辑备份将数据库数据和结构导出为SQL文件,用于数据迁移或恢复。本文提供常用备份和恢复命令,适用于单表、单数据库、多数据库及所有数据库场景。命令行参数说明<参数>:尖括号内的内容为用户需替换的实际值(例如:主机、端口、用户名等)。>:表示输出重......
  • 机器学习-38-对ML的思考之探寻Iris数据集的来源及并非完美的标准数据集
    文章目录1标准数据集的滥用1.1机器学习不研究采集数据1.2基于别人采集的数据学习2经典的数据集2.1Iris鸢尾花数据集2.2探寻Iris数据集的源头2.2.1Iris物种2.2.2侦探工作2.2.3解开谜团2.2.4哪里可以找到Iris数据集2.3Iris数据的采集过程......