首页 > 编程语言 >决策树算法 0基础小白也能懂(附代码)

决策树算法 0基础小白也能懂(附代码)

时间:2024-09-03 16:53:27浏览次数:4  
标签:剪枝 1.00 小白 算法 test 属性 节点 决策树

决策树算法

原文链接

啥是决策树

决策树(Decision tree)是基于已知各种情况(特征取值)的基础上,通过构建树型决策结构来进行分析的一种方式,是常用的有监督的分类算法(也就是带有标签的训练数据集训练的,比如后文中使用到的训练集中的好瓜坏瓜就是标签,形容瓜的就是特征)

决策树模型(Decision Tree model)模拟人类决策过程。

根节点:决策树的起点,代表数据集的整体。

内部节点:表示对某个特征进行的判断或测试,也可以说是类别二选一。

分支:从一个节点到另一个节点的路径,根据特征的取值进行分割,表示一个测试输出。

叶节点:代表最终的决策或预测结果。

下面的是决策树家族,后面我们来讲讲他们分别是如何构建起来的

决策树构建

选择特征:决策树通过选取最能分割数据的特征来构建内部节点。通常使用信息增益(Information Gain)或基尼系数(Gini Impurity)等标准来衡量特征的重要性,这些标准后面还会谈到。

分裂:根据选定的特征,将数据集分成若干子集,每个子集对应一个特定的特征取值或范围。

递归分裂:对子集重复上述过程,构建子树,直到满足停止条件(如节点纯度达到阈值、最大深度达到、数据量不足等)。

终止条件:当不能再有效分裂时,节点转化为叶节点,叶节点的输出即为分类标签或回归值。

图里搞的很复杂,重点其实就在递归。

最优属性选择(内部节点)

要了解决策树的「最优属性」选择,我们需要先了解一个信息论的概念「信息熵(entropy)」,它是消除不确定性所需信息量的度量,也是未知事件可能含有的信息量。

假设数据集\(D\)中有\(y\)类,其中第\(k\)类样本占比为\(p_k\),则信息熵的计算公式如下:

\(ENT(D)=-\sum_{k=1}^{|y|}p_k\log_2p_k\)

\(p_k\)为1时,信息熵最小为0,很明显为必然事件,\(p_k\)为均匀分布(概率相等)时,信息商取最大值(\(p_k=\frac{1}{y}\))\(\log_2(|y|)\)(概率同等,不确定性最大)

信息增益(ID3)

还记得我们之前的决策树家族中的ID3吗?构建时用的就是信息增益信息增益(Information Gain),它衡量的是我们选择某个属性进行划分时信息熵的变化(可以理解为基于这个规则划分,不确定性降低的程度)。典型的决策树算法ID3就是基于信息增益来挑选每一节点分支用于划分的属性(特征)的。

这里面的\(D^v\)可能有点难理解,它是将数据集\(D\)根据属性\(a\)的那些取值划分成了\(v\)个子集\(\{D_1,D_2,...,D_v\}\),那划分后的信息熵又是咋来的,其实是一种条件熵\(H(D|a)\),是数据集\(D\)在基于属性\(a\)进行划分后的不确定性。

下面拿一个西瓜的数据集举个例子,一共17个数据,9个好瓜,8个坏瓜

以色泽属性为条件计算信息熵,一共三类色泽:\(青绿,乌黑,浅白\),看看他们在好坏瓜中的占比进行计算

同样的方法,计算其他属性的信息增益为:

对比不同属性,我们发现「纹理」信息增益最大,它就要作为决策树的根节点,可以看到里面被分为三个属性:\(清晰,模糊,稍糊\),也就是下一层的节点要根据这三个属性来看,计算各属性信息增益

图中只给出了纹理=清晰这一个分支的结果,有三个属性信息增益都一样,那么说明这三个特征都是最能分割数据的特征,均作为决策树的节点。纹理=稍糊以及其他属性的计算过程略去了,最后的结果如下图

信息增益率(C4.5)

大家已经了解了信息增益作为特征选择的方法,但信息增益有一个问题,它偏向取值较多的特征。原因是,当特征的取值较多时,根据此特征划分更容易得到纯度更高的子集,因此划分之后的熵更低,由于划分前的熵是一定的。因此信息增益更大,因此信息增益比较偏向取值较多的特征。

那有没有解决这个小问题的方法呢?有的,这就是我们要提到信息增益率(Gain Ratio),信息增益率相比信息增益,多了一个衡量本身属性的分散程度的部分作为分母,而著名的决策树算法C4.5就是使用它作为划分属性挑选的原则。

\(Grain\_ratio(D,a)=\frac{Gain(D,a)}{IV(a)}\)
\(IV(a)=-\sum_{v=1}^{V}\frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|}\)

下面那一块就是熵公式的变式,固有熵通过计算特征自身的“熵”,使得信息增益率能够公平地评价特征的分裂能力,不偏向多值特征。

基尼指数(CART)

数学上用于信息量(或者纯度)衡量的不止有上述的熵相关的定义,我们还可以使用基尼指数来表示数据集的不纯度。基尼指数越大,表示数据集越不纯。

基尼系数
\(Gini(D)=\sum_{k=1}^{|y|} \sum_{k^{'}\not=k}p_kp_{k^{'}}=1-\sum_{k=1}^{|y|}p_k^2\)

为什么它可以作为纯度的量度呢?大家可以想象在一个漆黑的袋里摸球,有不同颜色的球,其中第k类占比记作\(p_k\),那两次摸到的球都是第k类的概率就是\(p_k^2\),那两次摸到的球颜色不一致的概率就是\(1-\sum p_k^2\),它的取值越小,两次摸球颜色不一致的概率就越小,纯度就越高。

过拟合和剪枝

如果我们让决策树一直生长,最后得到的决策树可能很庞大,而且因为对原始数据学习得过于充分会有过拟合的问题。缓解决策树过拟合可以通过剪枝操作完成。而剪枝方式又可以分为:预剪枝和后剪枝。并使用「留出法」进行评估剪枝前后决策树的优劣。

我们来看一个例子,下面的数据集,为了评价决策树模型的表现,会划分出一部分数据作为验证集

在上述西瓜数据集上生成的一颗完整的决策树,如下图所示。

剪枝基本策略包含「预剪枝」和「后剪枝」两个:

预剪枝(pre-pruning):在决策树生长过程中,对每个结点在划分前进行估计,若当前结点的划分不能带来决策树泛化性能的提升,则停止划分并将当前结点标记为叶结点。

后剪枝(post-pruning):先从训练集生成一颗完整的决策树,然后自底向上地对非叶结点进行考察,若将该结点对应的子树替换为叶结点能带来决策树泛化性能的提升,则将该子树替换为叶结点。

预剪枝

根据我们的验证集,如果只按照好坏瓜来进行分的话验证集精度为3/7x100%=42.9%

但是加入决策树的➀结点后,验证集精度为5/7x100%=71.4%,比没有划分前大,所以不剪枝,最后分出来如下图

后剪枝

和预剪枝一样是判断精度,只是从下面开始,没剪之前精度42.9,如果把结点⑥的标记为好瓜,精度57.1,可以剪

最终结果为

时间开销:
预剪枝:训练时间开销降低,测试时间开销降低。
后剪枝:训练时间开销增加,测试时间开销降低。

过/欠拟合风险:
预剪枝:过拟合风险降低,欠拟合风险增加。
后剪枝:过拟合风险降低,欠拟合风险基本不变。

泛化性能:后剪枝通常优于预剪枝。

连续值与缺失值的处理

连续值处理

我们用于学习的数据包含了连续值特征和离散值特征,之前的例子中使用的都是离散值属性(特征),决策树当然也能处理连续值属性,我们来看看它的处理方式。

对于离散取值的特征,决策树的划分方式是:选取一个最合适的特征属性,然后将集合按照这个特征属性的不同值划分为多个子集合,并且不断的重复这种操作的过程。

对于连续值属性,显然我们不能以这些离散值直接进行分散集合,否则每个连续值将会对应一种分类。那我们如何把连续值属性参与到决策树的建立中呢?

因为连续属性的可取值数目不再有限,因此需要连续属性离散化处理,常用的离散化策略是二分法,这个技术也是 C4.5 中采用的策略。

缺失值处理

原始数据很多时候还会出现缺失值,决策树算法也能有效的处理含有缺失值的数据。缺失值处理的基本思路是:样本赋权,权重划分。

样本赋权是为每个训练样本分配一个权重,用以影响模型的学习过程。在训练过程中,权重较大的样本对模型的贡献更大,模型会更多地关注这些样本。

权重划分是指将整体权重分配给不同的部分或类别,确保模型能够有效地学习这些部分。例如,在决策树的构建过程中,使用样本的权重来影响节点的分裂决策。

我们来通过下图这份有缺失值的西瓜数据集,看看具体处理方式。

仅通过无缺失值的样例来判断划分属性的优劣,学习开始时,根结点包含样例集\({D}\)中全部17个样例,权重均为1。

\(\widetilde{D^1},\widetilde{D^2},\widetilde{D^3}\)分别表示在属性「色泽」上取值为「青绿」「乌黑」以及「浅白」的样本子集:

再计算其他属性的增益

因此选择「纹理」作为接下来的划分属性。感觉权重可能就体现在那个\(\widetilde{r_v}\)里,就是排除了缺失值的占比

代码实现

用的是iris数据集,直接用sklearn库

样本数量:150个。
特征数量:4个连续特征。
类别数量:3个类别,每个类别包含50个样本。
数据平衡:每个类别的样本数量相同,均为50个。

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn import metrics
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# 1. 加载数据集
iris = load_iris()
X = iris.data
y = iris.target

# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 3. 训练决策树模型
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# 4. 进行预测
y_pred = clf.predict(X_test)

# 5. 评估模型
print("Accuracy:", metrics.accuracy_score(y_test, y_pred))
print("Classification Report:\n", metrics.classification_report(y_test, y_pred))

# 6. 可视化决策树
plt.figure(figsize=(12, 8))
plot_tree(clf, filled=True, feature_names=iris.feature_names, class_names=iris.target_names)
plt.show()

结果如下

Accuracy: 1.0
Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       1.00      1.00      1.00        13
           2       1.00      1.00      1.00        13

    accuracy                           1.00        45
   macro avg       1.00      1.00      1.00        45
weighted avg       1.00      1.00      1.00        45

颜色深浅代表该节点中的样本纯度(越纯的节点颜色越深)

原生实现

import numpy as np
from collections import Counter
from math import log2
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# 计算熵的函数,用于衡量数据集的纯度
def entropy(labels):
    counts = Counter(labels)  # 统计每个类别的数量
    total = len(labels)  # 样本总数
    # 熵公式:sum(-p * log2(p)),其中p是每个类别的概率
    return -sum((count / total) * log2(count / total) for count in counts.values())

# 分裂数据集,根据某个特征的某个值将数据分为两部分
def partition(data, labels, index, value):
    left_data, left_labels = [], []  # 保存分裂后左子集的数据和标签
    right_data, right_labels = [], []  # 保存分裂后右子集的数据和标签

    # 遍历数据集,按特征值将数据分配到左子集或右子集
    for i, row in enumerate(data):
        if row[index] <= value:  # 小于等于分裂值的放到左子集
            left_data.append(row)
            left_labels.append(labels[i])
        else:  # 大于分裂值的放到右子集
            right_data.append(row)
            right_labels.append(labels[i])

    return left_data, left_labels, right_data, right_labels

# 计算信息增益,衡量某次分裂的效果
def info_gain(left_labels, right_labels, current_uncertainty):
    p = float(len(left_labels)) / (len(left_labels) + len(right_labels))  # 计算左子集样本占总样本的比例
    # 信息增益 = 当前不确定性 - (左子集不确定性 * 左子集权重 + 右子集不确定性 * 右子集权重)
    return current_uncertainty - p * entropy(left_labels) - (1 - p) * entropy(right_labels)

# 找到最佳分裂方式,返回最佳分裂的特征索引和分裂值
def find_best_split(data, labels):
    best_gain = 0  # 初始化最佳信息增益为0
    best_index = 0  # 初始化最佳分裂特征索引为0
    best_value = 0  # 初始化最佳分裂值为0
    current_uncertainty = entropy(labels)  # 当前数据集的不确定性(熵)

    # 遍历每个特征
    for index in range(len(data[0])):
        values = set(row[index] for row in data)  # 获取该特征的所有可能取值

        # 对每个特征的每个取值进行分裂,并计算信息增益
        for value in values:
            left_data, left_labels, right_data, right_labels = partition(data, labels, index, value)

            # 如果分裂后没有数据,跳过该分裂方式
            if not left_labels or not right_labels:
                continue

            # 计算信息增益
            gain = info_gain(left_labels, right_labels, current_uncertainty)

            # 如果信息增益更大,则更新最佳分裂方式
            if gain > best_gain:
                best_gain, best_index, best_value = gain, index, value

    return best_index, best_value

# 决策树的节点类,包含分裂条件和子节点或预测值
class DecisionNode:
    def __init__(self, feature=None, value=None, true_branch=None, false_branch=None, prediction=None):
        self.feature = feature  # 用于分裂的特征索引
        self.value = value  # 分裂值
        self.true_branch = true_branch  # 左子树(True分支)
        self.false_branch = false_branch  # 右子树(False分支)
        self.prediction = prediction  # 如果是叶节点,保存预测值

# 递归地构建决策树
def build_tree(data, labels):
    # 如果数据集中的所有标签都相同,创建叶节点
    if len(set(labels)) == 1:
        return DecisionNode(prediction=labels[0])

    # 找到最佳分裂方式
    index, value = find_best_split(data, labels)
    # 如果找不到有效分裂,创建叶节点(返回样本最多的类别)
    if index is None:
        return DecisionNode(prediction=Counter(labels).most_common(1)[0][0])

    # 根据最佳分裂方式分裂数据集
    left_data, left_labels, right_data, right_labels = partition(data, labels, index, value)

    # 递归构建左子树和右子树
    true_branch = build_tree(left_data, left_labels)
    false_branch = build_tree(right_data, right_labels)

    # 返回当前节点
    return DecisionNode(feature=index, value=value, true_branch=true_branch, false_branch=false_branch)

# 使用构建好的决策树对单个样本进行分类
def classify(row, node):
    # 如果当前节点是叶节点,返回预测值
    if node.prediction is not None:
        return node.prediction

    # 根据分裂条件递归调用分类函数
    if row[node.feature] <= node.value:
        return classify(row, node.true_branch)
    else:
        return classify(row, node.false_branch)

# 打印决策树,便于理解树的结构
def print_tree(node, spacing=""):
    # 如果当前节点是叶节点,打印预测结果
    if node.prediction is not None:
        print(spacing + "Predict", node.prediction)
        return

    # 打印当前节点的分裂条件
    print(spacing + f"[Feature {node.feature}] <= {node.value}")

    # 递归打印左子树
    print(spacing + '--> True:')
    print_tree(node.true_branch, spacing + "  ")

    # 递归打印右子树
    print(spacing + '--> False:')
    print_tree(node.false_branch, spacing + "  ")

# 1. 加载数据集
iris = load_iris()  # 加载Iris数据集,包含150个样本,每个样本有4个特征
X = iris.data  # 特征矩阵(花的测量值)
y = iris.target  # 标签(花的品种)

# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 使用train_test_split函数将数据集分为训练集和测试集,测试集占30%

# 3. 训练决策树模型
tree = build_tree(X_train.tolist(), y_train.tolist())
# 调用build_tree函数使用训练集数据递归构建决策树

# 4. 进行预测
y_pred = [classify(row, tree) for row in X_test]
# 使用构建好的决策树对测试集进行预测

# 5. 评估模型
accuracy = sum(1 for actual, predicted in zip(y_test, y_pred) if actual == predicted) / len(y_test)
# 计算准确率:正确预测的数量 / 总测试样本数量
print("Accuracy:", accuracy)

# 打印分类报告
print(classification_report(y_test, y_pred))
# 使用classification_report打印精确度、召回率和F1分数

# 6. 打印决策树
print_tree(tree)
# 调用print_tree函数打印决策树的结构

结果

Accuracy: 0.9111111111111111
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        19
           1       0.85      0.85      0.85        13
           2       0.85      0.85      0.85        13

    accuracy                           0.91        45
   macro avg       0.90      0.90      0.90        45
weighted avg       0.91      0.91      0.91        45

标签:剪枝,1.00,小白,算法,test,属性,节点,决策树
From: https://www.cnblogs.com/Mephostopheles/p/18385104

相关文章

  • LeetCode_0028. 找出字符串第一个匹配项的下标,KMP算法的实现
    题目描述  给你两个字符串haystack和needle,请你在haystack字符串中找出needle字符串的第一个匹配项的下标(下标从0开始)。如果needle不是haystack的一部分,则返回-1。示例1:输入:haystack="sadbutsad",needle="sad"输出:0解释:"sad"在下标0和6处匹......
  • 召回策略算法-粗排算法-精排算法
     召回策略算法召回策略算法用于在海量文档中快速识别和选择与用户查询相关的文档,以满足用户的检索需求:提高检索效率:召回策略算法能够快速过滤出与用户查询相关的文档,减少了后续排序和排除不相关文档的计算量,从而提高了检索效率。提高搜索结果的相关性:通过选择与用户查询......
  • 过滤策略算法
    过滤策略算法过滤策略算法是指根据特定的规则或条件,从一组数据中筛选出符合要求的数据集合的方法。在信息检索和搜索引擎领域,过滤策略算法常用于对搜索结果或推荐结果进行过滤,以提供更符合用户需求的结果集合。比如针对过滤用户拉黑的内容和不感兴趣的内容,可以采用基于用户行......
  • 掌握检索技术:构建高效知识检索系统的架构与算法5
    在检索专业知识层需要涵盖更高级的检索技术,包括工程架构和算法策略。一、工程架构工程架构在构建检索系统中决定了系统的可扩展性、高可用性和性能。比如需要考虑的基本点:分布式架构:水平扩展:采用分布式架构,将检索任务分布到多个节点上,实现水平扩展。这可以通过将索引数据......
  • 图算法太难懂?凸包算法搞不通?看这篇文章就够了
    标题:你以为凸包算法只是数学游戏?不,这才是竞赛中的制胜法宝!你以为几何算法只是竞赛中的小儿科,顶多画个漂亮图形?但是,朋友,你要知道,如果你还停留在这样的认知,那你已经out了!凸包(ConvexHull)——听起来像个不起眼的小问题,但实际上,它是算法竞赛中的核武器,是能让你在众多参赛者中脱......
  • 求职季来了,是时候让豆包MarsCode 陪你刷算法题了
    金九银十求职季,对于广大技术岗求职者来说,拥有扎实的算法知识是打开理想职业大门的金钥匙。为了更好地帮助广大求职者找到自己心仪的岗位,豆包MarsCode特推出代码练习能力,将全功能的代码编辑器和AI能力相结合,希望帮助开发者更快速地在求职季进行算法题目练习,100道大厂真题,助你高效......
  • 卡尔曼滤波算法的学习总结
    本文为作者学习卡尔曼滤波算法后的学习总结,如有错误请指正,万分感谢!前言本文学自B站up主华南小虎队,原视频讲得很好,推荐去观看。原视频卡尔曼滤波讲解一、简介(1)作用在学习卡尔曼滤波之前,我们首先要明白在使用该滤波器后,可以给我们带来什么好处?在此给读者举出一个例子,方......
  • 【数据结构与算法】:十大经典排序算法
    文章目录前言一、冒泡排序(BubbleSort)1.1冒泡排序原理1.2冒泡排序代码1.3输出结果二、选择排序(SelectionSort)2.1选择排序原理2.2选择排序代码2.3输出结果三、插入排序(InsertionSort)3.1插入排序原理3.2插入排序代码3.3输出结果四、希尔排序4.1希尔排序原......
  • 基于SIR模型的疫情发展趋势预测算法matlab仿真
    1.程序功能描述基于SIR模型的疫情发展趋势预测算法.对病例增长进行SIR模型拟合分析,并采用模型参数拟合结果对疫情防控力度进行比较。整体思路为采用SIR微分方程模型,对疫情发展进行过程进行拟合。2.测试软件版本以及运行结果展示MATLAB2022a版本运行3.核心程序Opt.LargeScale......
  • 常见算法和lambda的使用
    常见的七种查找算法:数据结构是数据存储的方式,算法是数据计算的方式。所以在开发中,算法和数据结构息息相关。今天的讲义中会涉及部分数据结构的专业名词,如果各位铁粉有疑惑,可以先看一下哥们后面录制的数据结构,再回头看算法。1.基本查找也叫做顺序查找说明:顺序查找适合于存储结......