首页 > 编程语言 >用Python实现9大回归算法详解——09. 决策树回归算法

用Python实现9大回归算法详解——09. 决策树回归算法

时间:2024-08-23 11:24:35浏览次数:8  
标签:模型 Python 回归 算法 test 数据 节点 决策树

1. 决策树回归的基本概念

决策树回归(Decision Tree Regression)是一种树状结构的回归模型,通过对数据集进行递归分割,将数据分成更小的子集,并在每个子集上进行简单的线性回归。决策树的核心思想是通过选择特征及其阈值来最大化每次分裂后的目标函数增益,从而找到使误差最小化的模型。

2. 决策树回归的算法流程

(1)节点分裂

节点分裂的目标是最小化每个子节点中的均方误差(MSE),具体公式如下:

  • 从当前数据集中选择一个特征及其阈值,将数据集分割为两个子集。
  • 选择一个使得分裂后的两部分数据能够最大化目标函数增益的特征和阈值。

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2

其中:

  • y_i 是第 i 个样本的真实值。
  • \hat{y}_i 是第 i 个样本的预测值(子节点中的均值)。
  • N 是子节点中的样本数。

(2)树的构建

  • 递归地对每个子节点进行分裂,直到达到某个停止条件(如最大深度、最小样本数等)。
  • 每个叶节点的预测值为该节点中所有样本目标值的平均值。

叶节点的预测值公式:

\hat{y}_{\text{leaf}} = \frac{1}{N_{\text{leaf}}} \sum_{i=1}^{N_{\text{leaf}}} y_i

其中:

  • N_{\text{leaf}}​ 是叶节点中的样本数。
  • y_i 是叶节点中第 i 个样本的真实值。

(3)树的剪枝(可选):

  • 为了避免过拟合,可以使用剪枝技术,对已经生成的决策树进行剪枝,去掉那些对最终预测贡献较小的节点。
3. 决策树回归的数学表达

(1)均方误差(MSE)

\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( y_i - \hat{y}_i \right)^2

(2)叶节点的预测值

\hat{y}_{\text{leaf}} = \frac{1}{N_{\text{leaf}}} \sum_{i=1}^{N_{\text{leaf}}} y_i

(3)分裂节点的选择

对于特征j 和阈值 t,选择能够最小化分裂后的两个子节点的总 MSE 的分裂方式:

\text{Split} = \arg\min_{j, t} \left[ \frac{N_{\text{left}}}{N} \text{MSE}_{\text{left}} + \frac{N_{\text{right}}}{N} \text{MSE}_{\text{right}} \right]

其中:

  • \text{MSE}_{\text{left}} 和 \text{MSE}_{\text{right}}​ 分别是左子节点和右子节点的均方误差。
4. 决策树回归的优缺点

优点

  1. 易于理解和解释:决策树结构直观易懂,能够很容易地解释模型的决策过程。
  2. 处理非线性数据:决策树可以处理非线性数据,而不需要对数据进行特殊的处理。
  3. 无需特征缩放:决策树对数据的尺度不敏感,无需进行特征缩放。

缺点

  1. 容易过拟合:决策树容易生成复杂的模型,对训练数据拟合过度,从而降低对新数据的泛化能力。
  2. 不稳定性:小的扰动可能导致完全不同的树结构,因为树的分裂方式可能会对训练数据中的小变化产生较大影响。

5. 决策树回归案例

我们将通过一个具体的案例来展示如何使用决策树回归进行预测,并对结果进行详细分析。

5.1 数据加载与预处理

我们使用加利福尼亚州房价数据集(California Housing Dataset)进行回归预测。

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error, r2_score

# 加载加利福尼亚州房价数据集
housing = fetch_california_housing()
X, y = housing.data, housing.target

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

解释

  • 数据加载:我们选择加利福尼亚州房价数据集,该数据集包含加利福尼亚州的房屋特征数据,用于预测房屋的价格中位数。
  • 数据划分:将数据集划分为训练集和测试集,80% 的数据用于训练,20% 的数据用于测试。
5.2 模型训练与预测

我们使用 DecisionTreeRegressor 进行模型训练,并对测试集进行预测。

# 定义决策树回归模型
dtr = DecisionTreeRegressor(max_depth=5, random_state=42)

# 训练模型
dtr.fit(X_train, y_train)

# 对测试集进行预测
y_pred = dtr.predict(X_test)

解释

  • 模型定义DecisionTreeRegressor 是决策树回归的实现。我们设置 max_depth=5 来限制树的最大深度,以防止过拟合。
  • 模型训练:使用训练集数据进行模型训练,构建决策树模型。
  • 模型预测:训练完成后,使用模型对测试集进行预测,得到预测值。
5.3 模型评估与结果分析

我们使用均方误差(MSE)和决定系数(R^2)来评估模型的性能。

# 计算均方误差 (MSE) 和决定系数 (R²)
mse = mean_squared_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("均方误差 (MSE):", mse)
print("决定系数 (R²):", r2)

 输出:

均方误差 (MSE): 0.5245146178314735
决定系数 (R²): 0.5997321244428706

解释

  • 均方误差 (MSE):模型的预测误差为 0.524,表明模型对测试集的预测有一定误差。
  • 决定系数 (R²):模型的 R^2 值为 0.599,说明模型能够解释 59.9% 的目标变量方差,模型拟合效果尚可。
5.4 决策树可视化

我们可以通过可视化决策树的结构,来更好地理解模型的决策过程。

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(20,10))
plot_tree(dtr, filled=True, feature_names=housing.feature_names, rounded=True)
plt.show()

输出: 

可视化解释

  • 树结构:每个节点显示了用于分裂的数据特征、阈值、节点中的样本数量以及预测的目标值。
  • 颜色深浅:颜色表示了节点中目标值的均值,颜色越深表示预测值越高。
5.5 结果可视化

我们还可以通过绘制预测值与实际值的散点图,来进一步验证模型的表现。

# 绘制预测值与实际值的散点图
plt.scatter(y_test, y_pred, color="blue", alpha=0.5)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
plt.xlabel("Actual")
plt.ylabel("Predicted")
plt.title("Decision Tree Regression: Actual vs Predicted")
plt.show()

输出:

 

可视化解释

  • 散点图:横轴表示测试集的实际房价,纵轴表示模型预测的房价。每个点代表一个测试样本的预测结果。
  • 红色虚线:表示理想情况下,预测值应与实际值完全一致的参考线(即 y = x 的线)。
  • 分析:如果大多数散点分布在红色虚线附近,说明模型的预测效果较好。散点分布越集中,表示模型的预测准确性越高。反之,如果散点分布较为分散,特别是在远离红色虚线的区域,说明模型的预测误差较大。
5.6 参数调优

为了进一步提升模型性能,我们可以通过网格搜索(Grid Search)来调优决策树的超参数,如最大深度(max_depth)、最小分裂样本数(min_samples_split)等。

from sklearn.model_selection import GridSearchCV

# 定义参数网格
param_grid = {
    'max_depth': [3, 5, 7, 10],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

# 实例化决策树回归模型
dtr = DecisionTreeRegressor(random_state=42)

# 进行网格搜索
grid_search = GridSearchCV(estimator=dtr, param_grid=param_grid, cv=5, scoring='neg_mean_squared_error', n_jobs=-1)
grid_search.fit(X_train, y_train)

# 输出最佳参数
print("最佳参数:", grid_search.best_params_)

解释

  • 参数选择:通过网格搜索,我们定义了多个超参数的组合,如 max_depthmin_samples_splitmin_samples_leaf,用于寻找最优的参数设置。
  • 交叉验证:使用 5 折交叉验证(cv=5)来评估每个参数组合的表现,从而选择最优的参数。
  • 输出最佳参数:模型训练完成后,输出最佳的参数组合,这些参数可以用于构建性能更优的模型。

6. 总结

决策树回归是一种简单且强大的回归模型,能够有效处理线性和非线性数据。它通过递归地分割数据集,找到使得均方误差最小化的分裂方式,最终生成一个树状的回归模型。尽管决策树回归易于理解和解释,但它也容易过拟合,尤其是在树的深度较大的情况下。因此,在应用时,通常需要通过剪枝或调优超参数来控制模型的复杂度。

标签:模型,Python,回归,算法,test,数据,节点,决策树
From: https://blog.csdn.net/qq_41698317/article/details/141436883

相关文章

  • 【LLM & RAG & text2sql】大模型在知识图谱问答上的核心算法详细思路及实践
    前言本文介绍了一个融合RAG(Retrieval-AugmentedGeneration)思路的KBQA(Knowledge-BasedQuestionAnswering)系统的核心算法及实现步骤。KBQA系统的目标是通过自然语言处理技术,从知识图谱中提取和生成精确的答案。系统的实现包括多个关键步骤:mention识别、实体链接及排序、属......
  • Python下载安装全流程(Python 最新版本),新手小白必看!
    第一次接触Python,可能是爬虫或者是信息AI开发的小朋友,都说Python语言简单,那么多学一些总是有好处的,下面从一个完全不懂的Python的小白来安装Python等一系列工作的记录,并且遇到的问题也会写出,让完全不懂的小白也可上手安装,并且完成第一个Helloworld代码。需要安装包......
  • 字符串搜索算法
    目录二分搜索(适用于有序字符串数组)Trie树(前缀树)后缀树与后缀数组二分搜索(适用于有序字符串数组)基本概念二分搜索(BinarySearch)是一种高效的查找算法,适用于在有序数组中查找特定元素。与线性搜索相比,二分搜索通过逐步减少搜索范围,大幅提高查找效率。算法步骤确定中间元......
  • python socket编辑示例 UDP
    服务端:fromsocketimportsocket,AF_INET,SOCK_DGRAMrecv_socket=socket(AF_INET,SOCK_DGRAM)recv_socket.bind(('127.0.0.1',8888))whileTrue:data,addr=recv_socket.recvfrom(1024)#接收数据print('客户说:',data.decode('......
  • Python3测试mysql插入数据代码(chatgpt生成)
      实现的功能:先连接mysql数据库,然后读取某个目录所有以txt文件命名后缀的json内容文件,解析出对应的key和value,然后插入数据到mysql数据库,最后关闭数据库连接 importosimportjsonimportpymysqlimportre"""尝试插入json文件到MySQL数据库。dbInfo:MySQL数据库......
  • 零基础学习人工智能—Python—Pytorch学习(八)
    前言本文介绍卷积神经网络的上半部分。其实,学习还是需要老师的,因为我自己写文章的时候,就会想当然,比如下面的滑动窗口,我就会想当然的认为所有人都能理解,而实际上,我们在学习的过程中之所以卡顿的点多,就是因为学习资源中想当然的地方太多了。概念卷积神经网络,简称CNN,即Convoluti......
  • 4-线性回归
    python中*运算符的使用用于将可迭代对象(如列表或元组)的元素解压缩为单独的参数当我们从Dataloader取出来的时候,又会将压缩为的单独参数分开importtorchfromtorch.utilsimportdata#准备数据true_w=torch.tensor([2,-3.4])true_b=4.2defsynthetic_data(w,b......
  • Python-批量统计MySQL中表的数据量
    背景在数据中台中,有时为了核对数据,需要每天批量统计MySQL数据库中表的数据量,但是DMS中没有周期调度功能。MySQL创建表--统计的表清单CREATETABLE`dws_table_list`(`table_name`varchar(255)DEFAULTNULL,`flag`varchar(255)DEFAULTNULL);--每天的数据量CRE......
  • python socket编辑示例
    服务端代码:fromsocketimportsocket,AF_INET,SOCK_STREAM#1.创建socket对象AF_INET:用于internet之间的进程通信,SOCK_STREAM:表示TCP协议server_socket=socket(AF_INET,SOCK_STREAM)#2.绑定ip和端口号ip='127.0.0.1'port=8888server_socket.bind((ip,p......
  • 基于Python flask的图书借阅管理系统的设计与实现
    基于PythonFlask的图书借阅管理系统旨在为图书馆或类似机构提供一个高效、便捷的管理平台,覆盖图书借阅的各个环节,帮助管理员和读者更好地管理和使用图书资源。该系统采用Python编程语言和Flask框架进行开发,结合了数据库管理、用户认证、数据可视化等技术,确保系统的功能完备和......