首页 > 编程语言 >机器学习算法原理实现——cart决策树

机器学习算法原理实现——cart决策树

时间:2023-09-10 17:46:10浏览次数:40  
标签:node gini feature cart 算法 right 决策树 class left

 

 

cart决策树示例:

 

本文目标,仿照sklearn写一个cart树,但是仅仅使用max_depth作为剪枝依据。

 

 

 

我们本次实现cart分类,因此用到gini指数:

 为了帮助理解:

 

 

好了,理解了基尼指数。我们看下cart树的构建步骤:

注意还有几个细节:

 

cart树每个treenode存储了哪些数据?

在CART决策树中,每个节点(TreeNode)通常存储以下数据:

  1. 划分特征:这是用于根据某种条件划分数据集的特征。例如,如果一个节点用"年龄 > 30"作为分割条件,那么"年龄"就是这个节点的划分特征。

  2. 划分阈值:与划分特征配合使用,定义了数据应如何分割。在上面的例子中,阈值是30。

  3. 左子节点:满足划分条件的数据子集的节点。例如,在上面的"年龄 > 30"例子中,大于30岁的数据会被划分到左子节点。

  4. 右子节点:不满足划分条件的数据子集的节点。在上面的例子中,30岁及以下的数据会被划分到右子节点。

  5. 类标签:只在叶节点中有效。表示该节点所代表的数据子集中最常见的类别。当新数据通过决策树进行预测时,最终到达的叶节点的类标签就是其预测结果。

  6. 数据子集:节点当前代表的数据子集。在许多实际实现中,为了节省内存,节点可能不直接存储数据子集,而是存储数据索引或其他引用。

  7. 基尼不纯度或其他不纯度指标:代表当前数据子集的不纯度。在构建树的过程中,这个指标用于判断是否应该继续划分当前节点。

  8. 其他可选信息:如节点深度、父节点引用、数据点的数量等。

这些数据允许决策树在训练过程中进行递归分割,以及在预测过程中导航通过树结构。

 

 好了,实现代码如下:

import numpy as np

class TreeNode:
    def __init__(self, gini, num_samples, num_samples_per_class, predicted_class):
        self.gini = gini
        self.num_samples = num_samples
        self.num_samples_per_class = num_samples_per_class
        self.predicted_class = predicted_class
        self.feature_index = 0
        self.threshold = 0
        self.left = None
        self.right = None

def gini(y):
    m = len(y)
    return 1.0 - sum([(np.sum(y == c) / m) ** 2 for c in np.unique(y)])

def grow_tree(X, y, depth=0, max_depth=None):
    classes = np.unique(y)
    num_samples_per_class = [np.sum(y == c) for c in classes]
    predicted_class = classes[np.argmax(num_samples_per_class)]
    node = TreeNode(
        gini=gini(y),
        num_samples=len(y),
        num_samples_per_class=num_samples_per_class,
        predicted_class=predicted_class,
    )

    if depth < max_depth:
        idx, thr = best_split(X, y)
        if idx is not None:
            indices_left = X[:, idx] < thr
            X_left, y_left = X[indices_left], y[indices_left]
            X_right, y_right = X[~indices_left], y[~indices_left]
            node.feature_index = idx
            node.threshold = thr
            node.left = grow_tree(X_left, y_left, depth + 1, max_depth)
            node.right = grow_tree(X_right, y_right, depth + 1, max_depth)
    return node


def best_split(X, y):
    """
    用numpy实现best_split,见下,可以先看不用numpy的实现
    """
    n_samples, n_features = X.shape
    
    if len(np.unique(y)) == 1:
        return None, None
    
    best = {}
    min_gini = float('inf')
    
    for feature_idx in range(n_features):
        thresholds = np.unique(X[:, feature_idx])
        for threshold in thresholds:
            left_mask = X[:, feature_idx] < threshold
            right_mask = ~left_mask
            
            gini_left = gini(y[left_mask])
            gini_right = gini(y[right_mask])
            
            weighted_gini = len(y[left_mask]) / n_samples * gini_left + len(y[right_mask]) / n_samples * gini_right
            if weighted_gini < min_gini:
                best = {
                    'feature_index': feature_idx,
                    'threshold': threshold,
                    'left_labels': y[left_mask],
                    'right_labels': y[right_mask],
                    'gini': weighted_gini
                }
                min_gini = weighted_gini
                
    return best['feature_index'], best['threshold']


def best_split2(X, y):
    """
    不用numpy实现best_split
    """
    n_samples, n_features = len(X), len(X[0])
    
    # 如果样本中只有一种输出标签或样本为空,则返回None
    if len(set(y)) == 1:
        return None, None
    
    # 初始化最佳分割的信息
    best = {}
    min_gini = float('inf')
    
    # 遍历每个特征
    for feature_idx in range(n_features):
        # 获取当前特征的所有唯一值,并排序
        unique_values = sorted(set(row[feature_idx] for row in X))
        
        # 遍历每个唯一值,考虑将其作为分割阈值
        for value in unique_values:
            left_y, right_y = [], []
            
            # 对于每个样本,根据其特征值与阈值的关系分到左子集或右子集
            for i, row in enumerate(X):
                if row[feature_idx] < value:
                    left_y.append(y[i])
                else:
                    right_y.append(y[i])
            
            # 计算左子集和右子集的基尼指数
            gini_left = 1.0 - sum([(left_y.count(label) / len(left_y)) ** 2 for label in set(left_y)])
            gini_right = 1.0 - sum([(right_y.count(label) / len(right_y)) ** 2 for label in set(right_y)])
            
            # 计算加权基尼指数
            weighted_gini = len(left_y) / len(y) * gini_left + len(right_y) / len(y) * gini_right
            
            # 如果当前基尼值小于已知的最小基尼值,更新最佳分割
            if weighted_gini < min_gini:
                best = {
                    'feature_index': feature_idx,
                    'threshold': value,
                    'left_labels': left_y,
                    'right_labels': right_y,
                    'gini': weighted_gini
                }
                min_gini = weighted_gini
                
    return best['feature_index'], best['threshold']

def predict_tree(node, X):
    if node.left is None and node.right is None:
        return node.predicted_class
    if X[node.feature_index] < node.threshold:
        return predict_tree(node.left, X)
    else:
        return predict_tree(node.right, X)

def predict_tree2(node, X):
    if node.left is None and node.right is None:
        return node.predicted_class
    if X[node.feature_index] < node.threshold:
        return predict_tree(node.left, X)
    else:
        return predict_tree(node.right, X)

class CARTClassifier:
    def __init__(self, max_depth=None):
        self.max_depth = max_depth

    def fit(self, X, y):
        self.tree_ = grow_tree(X, y, max_depth=self.max_depth)

    def predict(self, X):
        return [predict_tree(self.tree_, x) for x in X]

# 使用示例
if __name__ == "__main__":
    """
    # 好好理解下这个分割的函数
    X = np.array([[2.5], [3.5], [1], [1.5], [2], [3], [0]])
    y = np.array([1, 1, 0, 0, 1, 0, 2])
    best_idx, best_thr = best_split(X, y)
    """

    from sklearn.datasets import load_iris

    data = load_iris()
    X, y = data.data, data.target

    clf = CARTClassifier(max_depth=4)
    clf.fit(X, y)
    preds = clf.predict(X)

    accuracy = sum(preds == y) / len(y)
    print(f"Accuracy: {accuracy:.4f}")

    from sklearn.tree import DecisionTreeClassifier
    # 创建分类树实例
    clf = DecisionTreeClassifier(max_depth=4)
    # 分类树训练
    clf.fit(X, y)
    preds = clf.predict(X)
    accuracy = sum(preds == y) / len(y)
    print(f"sklearn Accuracy: {accuracy:.4f}")

  

输出:

Accuracy: 0.9933 sklearn Accuracy: 0.9933  

标签:node,gini,feature,cart,算法,right,决策树,class,left
From: https://www.cnblogs.com/bonelee/p/17691555.html

相关文章

  • 粒子群优化算法
    写在前面在大大的花园里面挖呀挖呀挖,挖大大的坑呀寻大大的WA。官方解释利用群体中的个体对信息的共享使整个群体的运动在问题求解空间中产生从无序到有序的演化过程。(这个解释不美丽.......)诡异的故事法解释那是一个暴风雨之夜,伴随着一声巨响,空气开始震动,狂风忽然吹向东方,比......
  • 2023“钉耙编程”中国大学生算法设计超级联赛(5)
    1001Typhoon题意:给你台风的轨迹坐标以及避难所的坐标,台风的半径不可预测,求让每个避难所不安全的最小台风半径是多少。分析:枚举每个点到所有“线段”的距离取个min。代码:附上队友的代码(懒):#include<bits/stdc++.h>#include<math.h>#definerep(i,a,b)for(inti=a;i<......
  • 5 排序算法总结
    5排序算法总结首先总结表如下:排序方法平均时间复杂度最好情况最坏情况空间复杂度是否稳定排序方式冒泡排序\(O(n^2)\)\(O(n)\)\(O(n^2)\)\(O(1)\)稳定内部排序选择排序\(O(n^2)\)\(O(n^2)\)\(O(n^2)\)\(O(1)\)不稳定内部排序插入排序\(O(n^2)......
  • 代码随想录算法训练营-回溯算法|491.递增子序列
    491. 递增子序列 不对原数组进行排序,利用set对同层的子集进行去重。1classSolution:2deffindSubsequences(self,nums):3result=[]4path=[]5self.backtracking(nums,0,path,result)6returnresult78......
  • 机器学习算法原理实现——决策树里根据信息增益选择特征
    先说熵的定义:  再看信息增益信息增益是一种用于特征选择的指标,用于衡量特征对于数据集分类的贡献程度。它基于信息熵的概念,通过比较特征划分前后的信息熵差异来评估特征的重要性。信息熵是衡量数据集纯度的指标,表示数据集中的不确定性或混乱程度。信息熵越高,数据集的不确......
  • 代码随想录算法训练营第四天| 24. 两两交换链表中的节点, 19.删除链表的倒数第N个结点
    24.两两交换链表中的节点mydemo(超时)/***Definitionforsingly-linkedlist.*structListNode{*intval;*ListNode*next;*ListNode():val(0),next(nullptr){}*ListNode(intx):val(x),next(nullptr){}*ListNode(intx,Lis......
  • 《算法图解》的读书体会
    最近工作有点内耗严重,经常头痛,静下心来读一本书架里的书,好久没有练算法了,程序员算法还是不能丢,在这里分享读书体会。有时候看一本书并不一定是抱着学习的心思去读,那样太累,学习太枯燥了,抱着玩的心态去读,更能让我读下去。如果遇到问题,绕开它,我们不一定要解决问题,逃避问题也是处理问......
  • 文心一言 VS 讯飞星火 VS chatgpt (89)-- 算法导论8.3 2题
    二、用go语言,下面的排序算法中哪些是稳定的:插入排序、归并排序、堆排序和快速排序?给出一个能使任何排序算法都稳定的方法。你所给出的方法带来的额外时间和空间开销是多少?文心一言:在Go语言中,以下排序算法是稳定的:插入排序归并排序而以下的排序算法是不稳定的:堆排序快速排......
  • 文心一言 VS 讯飞星火 VS chatgpt (89)-- 算法导论8.3 2题
    二、用go语言,下面的排序算法中哪些是稳定的:插入排序、归并排序、堆排序和快速排序?给出一个能使任何排序算法都稳定的方法。你所给出的方法带来的额外时间和空间开销是多少?文心一言:在Go语言中,以下排序算法是稳定的:插入排序归并排序而以下的排序算法是不稳定的:堆排序快速排序要使任......
  • C#希尔排序算法
    前言希尔排序简单的来说就是一种改进的插入排序算法,它通过将待排序的元素分成若干个子序列,然后对每个子序列进行插入排序,最终逐步缩小子序列的间隔,直到整个序列变得有序。希尔排序的主要思想是通过插入排序的优势,减小逆序对的距离,从而提高排序效率。希尔排序实现原理首先要确......