首页 > 其他分享 >回归树模型 0基础小白也能懂(附代码)

回归树模型 0基础小白也能懂(附代码)

时间:2024-09-04 11:06:07浏览次数:15  
标签:剪枝 模型 回归 划分 小白 train test 代码

回归树模型 0基础小白也能懂(附代码)

啥是回归树模型

大家在前面的部分学习到了使用决策树进行分类,实际决策树也可以用作回归任务,我们叫作回归树。而回归树的结构还是树形结构,但是属性选择与生长方式和分类的决策树有不同。

要讲回归树,我们一定会提到CART树,CART树全称Classification And Regression Trees,包括分类树与回归树。

CART的特点是:假设决策树是二叉树,内部结点特征的取值为「是」和「否」,右分支是取值为「是」的分支,左分支是取值为「否」的分支。这样的决策树等价于「递归地二分每个特征」,将输入空间(特征空间)划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出的条件概率分布。

这是人话吗......看半天没看懂,回归树相对于决策树来说用于处理连续型数值的目标变量。也就是说,回归树的预测输出是一个连续的实数值,例如预测房价、温度等,之前学的决策树都是处理离散的,看下面的图吧

设有数据集\(D\),构建回归树的大体思路如下:

  • ① 考虑数据集上所有特征\(j\),遍历每一个特征下可能的取值或者切分点(怎么选取的后面会说),将数据集划分为两部分\(D_1,D_2\)
  • ② 分别计算\(D_1,D_2\)的平方误差和,选择最小的平方误差对应的特征与分割点,生成两个子节点(将数据划分为两部分)。
  • ③ 对上述两个子节点递归调用步骤 ① ②,直到满足停止条件(比如最小样本数,最大数深度之类的)。

回归树构建完成后,就完成了对整个输入空间的划分(即完成了回归树的建立)。将整个输入空间划分为多个子区域,每个子区域输出为该区域内所有训练样本的平均值。我们知道了回归树其实是将输入空间划分为\(M\)个单元,每个区域的输出值是该区域内所有点\(y\)值的平均数。但我们希望构建最有效的回归树:预测值与真实值差异度最小。下面部分我们展开讲讲,回归树是如何生长的。

2.启发式切分与最优属性选择

又是最优属性选择,决策树中是信息增益和基尼系数之类的,那这里会是什么呢?

下面是我们基础的划分思路

RSS(残差平方和,Residual Sum of Squares)是用于衡量分裂质量的一个标准

  • \(y\)为每个训练样本的标签构成的标签向量,向量中的每个元素\(y_i\)对应的是每个样本的标签。
  • \(X\)为特征的集合,\(x_1,x_2,...,x_p\)为第一个特征到第p个特征
  • \(R_1,R_2,...,R_j\)为整个特征空间划分得来的J个不重叠的区域
  • \(\widetilde{y}_{R_j}\) 为划分到第\(j\)个区域\(R_j\)的样本的平均标签值,用这个值作为该区域的预测值,即如果有一个测试样本在测试时落入到该区域,就将该样本的标签值预测为\(\widetilde{y}_{R_j}\)

但是这个最小化和探索的过程,计算量是非常非常大的。我们采用「探索式的递归二分」来尝试解决这个问题。

递归二分

回归树采用的是「自顶向下的贪婪式递归方案」。这里的贪婪,指的是每一次的划分,只考虑当前最优,而不回头考虑之前的划分。

我们再来看看「递归切分」。下方有两个对比图,其中左图是非递归方式切分得到的,而右图是二分递归的方式切分得到的空间划分结果(下一次划分一定是在之前的划分基础上将某个区域一份为二)。

(感觉思路就是不一次性划分完,根据当前现状一步一步来)

回归树总体流程类似于分类树:分枝时穷举每一个特征可能的划分阈值,来寻找最优切分特征和最优切分点阈值,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。

但通常在处理具体问题时,单一的回归树模型能力有限且有可能陷入过拟合,我们经常会利用集成学习中的Boosting思想,对回归树进行增强,得到的新模型就是提升树(Boosting Decision Tree),进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。通过多棵回归树拟合残差,不断减小预测值与标签值的偏差,从而达到精准预测的目的,会在后面介绍这些高级算法。

过拟合与正则化

过拟合问题处理

(1)约束控制树的过度生长
限制树的深度:当达到设置好的最大深度时结束树的生长。
分类误差法:当树继续生长无法得到客观的分类误差减小,就停止生长。
叶子节点最小数据量限制:一个叶子节点的数据量过小,树停止生长。

(2)剪枝
约束树生长的缺点就是提前扼杀了其他可能性,过早地终止了树的生长,我们也可以等待树生长完成以后再进行剪枝,即所谓的后剪枝,而后剪枝算法主要有以下几种:
Reduced-Error Pruning(REP,错误率降低剪枝)。
Pesimistic-Error Pruning(PEP,悲观错误剪枝)。
Cost-Complexity Pruning(CCP,代价复杂度剪枝)。
Error-Based Pruning(EBP,基于错误的剪枝)。

正则化

剪枝的目标是找到使得以下表达式最小的子树\(T_a\)

\(T_a=RSS+\alpha|T|\)

  • 其中\(\alpha\)是正则化项的系数,可以通过交叉验证去选择。
  • \(|T|\)是回归树叶子节点的个数(即树的复杂度)

代码实现

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing  # 加载加州房价数据集
from sklearn.model_selection import train_test_split  # 用于划分训练集和测试集
from sklearn.tree import DecisionTreeRegressor  # 使用回归树
from sklearn.metrics import mean_squared_error, r2_score  # 用于评估模型性能

# 1. 加载加州房价数据集
data = fetch_california_housing()  # 加载加州房价数据
X = data.data  # 特征矩阵(包含了多个影响房价的因素,如人口密度、纬度、经度等)
y = data.target  # 目标变量(房价,单位为千美元)

# 2. 划分训练集和测试集
# 我们将数据集分为训练集和测试集,70%用于训练模型,30%用于测试模型的表现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 3. 创建回归树模型
# max_depth=5 限制了树的最大深度为5,防止过拟合
# random_state=42 确保每次运行代码时模型的结果是可重复的
regressor = DecisionTreeRegressor(max_depth=5, random_state=42)

# 4. 训练模型
# fit() 函数用于训练模型,使其学习训练集中的特征与房价之间的关系
regressor.fit(X_train, y_train)

# 5. 进行预测
# 使用训练好的模型对训练集和测试集进行预测
y_pred_train = regressor.predict(X_train)  # 对训练集的预测结果
y_pred_test = regressor.predict(X_test)  # 对测试集的预测结果

# 6. 评估模型

# 计算测试集的均方误差(MSE)
# MSE 衡量模型预测值与实际值之间的平均误差,数值越小表示预测越准确
mse_test = mean_squared_error(y_test, y_pred_test)
print(f"Mean Squared Error (Test): {mse_test:.2f}")

# 计算训练集的均方误差(MSE)
# 可以用来评估模型是否在训练集上过拟合
mse_train = mean_squared_error(y_train, y_pred_train)
print(f"Mean Squared Error (Train): {mse_train:.2f}")

# 计算R²得分
# R² 是决定系数,衡量模型对数据的拟合程度,1.0表示完全拟合,0表示无法拟合
r2_test = r2_score(y_test, y_pred_test)
r2_train = r2_score(y_train, y_pred_train)
print(f"R² Score (Test): {r2_test:.2f}")
print(f"R² Score (Train): {r2_train:.2f}")

# 7. 可视化回归树的预测结果(实际值 vs. 预测值)
# 我们绘制散点图来展示测试集上的实际房价与预测房价的对比
plt.scatter(y_test, y_pred_test)
plt.xlabel('Actual Prices')  # 横轴是实际的房价
plt.ylabel('Predicted Prices')  # 纵轴是模型预测的房价
plt.title('Actual vs Predicted Prices')  # 图表标题
plt.show()

结果如下

看下对角线发现很多店落在右下角,看来预测的结果还是低估了房价。

Mean Squared Error (Test): 0.52
Mean Squared Error (Train): 0.49
R² Score (Test): 0.60
R² Score (Train): 0.63

emm准度一般,也在情理之中,R²得分越接近1,模型的预测效果越好。

标签:剪枝,模型,回归,划分,小白,train,test,代码
From: https://www.cnblogs.com/Mephostopheles/p/18395586

相关文章

  • java并发 共享模型之管程 4.
    1. waitnotify1.小故事原理注:虽然 blocked 和 waiting 状态的线程都在等待,但二者有区别。waiting 状态的线程通常是因为它持有了某个对象的锁,但由于某个条件不满足而被挂起。线程在 waiting 状态中会等待其他线程通过调用 notify() 或 notifyAll() 来通知它......
  • SSA(麻雀优化算法)+CNN+LSTM时间序列预测算法(python代码)
    1.SSA(SparrowSearchAlgorithm)简介:SSA是一种新兴的群体智能优化算法,模拟麻雀觅食行为。麻雀群体中的“发现者”负责寻找食物,并将信息传递给“追随者”,后者根据这一信息进行觅食。SSA通过这种合作机制寻找最优解。SSA在优化问题中可以视为一种元启发式算法,擅长在复杂搜索......
  • AI大模型时代,大龄程序员如何转型突破,抢占技术高地?
    前言在信息技术迅速发展的今天,程序员作为技术的创造者和实践者,面临着前所未有的挑战。一方面,技术的迭代速度越来越快,传统项目的生命周期缩短,另一方面,随着人工智能(AI)尤其是大模型技术的兴起,许多程序员发现自己需要不断学习新的技能才能跟上时代的步伐。这种持续的技术更新换代给程序......
  • 源代码加密为什么很重要?加密后的源代码还能正常用吗?
    源代码加密在现代软件开发和企业数据保护中扮演着至关重要的角色。源代码是软件开发的核心资产,包含了程序的逻辑、算法和功能实现。通过加密,企业可以有效地保护其知识产权,防止竞争对手通过不正当手段获取并复制软件的关键设计。源代码中可能包含商业机密和敏感信息,如算法、......
  • Meta Llama模型下载量突破3.5亿次
    ......
  • LoRA大模型微调的利器
    LoRA模型是小型的StableDiffusion模型,它们对checkpoint模型进行微小的调整。它们的体积通常是检查点模型的10到100分之一。因为体积小,效果好,所以lora模型的使用程度比较高。这是一篇面向从未使用过LoRA模型的初学者的教程。你将了解LoRA模型是什么,在哪里找到它们,以及如何在AUTOM......
  • 纪念某模拟赛 8.6 k 代码
    夏虫(summer)题意简述:用\(n\)个虫子,每个虫子有一个狡猾值\(a_i\),一开始你会站在一个虫子\(x\)前,将初始能力值设为\(a_x\),并捕捉它,接下来你可以重复执行三种操作,直到捕捉完所有昆虫:设当前捕捉到了区间\([l,r]\)的昆虫,能力值为\(v\)若\(l\ne1\)并且\(a_{l−1}......
  • c++病毒/恶搞代码大全
    注:以下代码应勿用于非法(Dev-c++5.11实测可用)0.效果:无限生成cmd解决方法:关闭程序即可Code:#include<bits/stdc++.h>#include<windows.h>usingnamespacestd;intmain(){  while(1)system("startcmd");}1.效果:使鼠标所点应用消失解决方法:暂无Code:#inclu......
  • LLM大模型基础知识学习总结
    大家好,我是Edison。在这个已经被大模型包围的时代,不了解一点大模型的基础知识和相关概念,可能出去聊天都接不上话。刚好近期我也一直在用GPT和GitHubCopilot,也刚好对这些基础知识很感兴趣,于是学习了一下,做了如下的整理总结,分享与你!一句话描述GPTGPT:GenerativePre-TrainingTra......
  • 【信息论基础】信息路基础绪论——信息的概念,信息量和信息熵,数字通信系统模型
    1.、信息的定义:事物运动状态或存在方式的不确定状态(fromShannon)2、信息是有大小之分的。3、信息量(自信息)的计算如下:4、我们使用信息熵(informationentropy)这一概念来描述 信息的平均不确定度:(1)例1:对于一个信源的输出有x1~x8,对应的码字输出分别为000,001,010,011,100,10......