首页 > 其他分享 >基于信息增益和基尼指数的二叉决策树

基于信息增益和基尼指数的二叉决策树

时间:2024-11-07 10:57:19浏览次数:3  
标签:splitInfo para 样本 tree 二叉 基尼 samples 节点 决策树

# coding: UTF-8
'''
基于信息增益和基尼指数的二叉决策树的实现。
该决策树可以用于分类问题,通过选择合适的特征来划分样本。
'''

from collections import Counter

class biTree_node:
    '''
    二叉树节点定义
    每个节点可以是叶子节点或内部节点。
    '''

    def __init__(self, f=-1, fvalue=None, leafLabel=None, l=None, r=None, splitInfo="gini"):
        '''
        类初始化函数
        para f: int, 切分的特征,用样本中的特征次序表示
        para fvalue: float or int, 切分特征的决策值
        para leafLabel: int, 叶节点的标签
        para l: biTree_node指针, 当前节点的左子树
        para r: biTree_node指针, 当前节点的右子树
        para splitInfo: string, 切分的标准, 可取值'infogain'和'gini', 分别表示信息增益和基尼指数。
        每个节点都保存了其用于划分的特征以及该特征的具体值,并且指向其左右子树。
        如果是叶子节点,则保存了该节点的标签。
        '''
        self.f = f  # 特征索引,即样本中的特征次序
        self.fvalue = fvalue  # 特征切分值,用于决定样本走向左子树还是右子树
        self.leafLabel = leafLabel  # 如果是叶节点,则保存对应的类别标签
        self.l = l  # 左子树,指向当前节点的左子节点
        self.r = r  # 右子树,指向当前节点的右子节点
        self.splitInfo = splitInfo  # 切分标准,用于决定使用何种方法来计算最佳特征和特征值


def gini_index(samples):
    '''
    计算基尼指数。
    para samples: list, 样本列表,每个样本的最后一个元素是标签。
    return: float, 基尼指数。
    '''
    label_counts = sum_of_each_label(samples)
    total = len(samples)
    gini = 1.0
    for label in label_counts:
        prob = label_counts[label] / total
        gini -= prob ** 2
    return gini

def info_entropy(samples):
    '''
    计算信息熵。
    para samples: list, 样本列表,每个样本的最后一个元素是标签。
    return: float, 信息熵。
    '''
    label_counts = sum_of_each_label(samples)
    total = len(samples)
    entropy = 0.0
    for label in label_counts:
        prob = label_counts[label] / total
        entropy -= prob * (prob * 3.321928094887362)  # 以2为底的对数
    return entropy

def split_samples(samples, feature, value):
    '''
    根据特征和值分割样本集。
    para samples: list, 样本列表。
    para feature: int, 特征索引。
    para value: float or int, 特征值。
    return: tuple, 两个列表,分别为左子集和右子集。
    '''
    left = [sample for sample in samples if sample[feature] < value]
    right = [sample for sample in samples if sample[feature] >= value]
    return left, right

def sum_of_each_label(samples):
    '''
    统计样本中各类别标签的分布。
    para samples: list, 样本列表。
    return: dict, 标签及其出现次数的字典。
    '''
    labels = [sample[-1] for sample in samples]
    return Counter(labels)

def build_biTree(samples, splitInfo="gini"):
    '''构建二叉决策树
    para samples: list, 样本的列表,每样本也是一个列表,样本的最后一项为标签,其它项为特征。
    para splitInfo: string, 切分的标准,可取值'infogain'和'gini', 分别表示信息增益和基尼指数。
    return: biTree_node, 二叉决策树的根节点。
    该函数递归地构建决策树,每次选择一个最佳特征和其值来切分样本集,直到无法有效切分为止。
    '''
    # 如果没有样本,则返回空节点
    if len(samples) == 0:
        return biTree_node()

    # 检查切分标准是否合法
    if splitInfo != "gini" and splitInfo != "infogain":
        return biTree_node()

    bestInfo = 0.0  # 最佳信息增益或基尼指数减少量
    bestF = None  # 最佳特征
    bestFvalue = None  # 最佳特征的切分值
    bestlson = None  # 左子树
    bestrson = None  # 右子树

    # 计算当前集合的基尼指数或信息熵
    curInfo = gini_index(samples) if splitInfo == "gini" else info_entropy(samples)

    sumOfFeatures = len(samples[0]) - 1  # 样本中特征的个数
    for f in range(0, sumOfFeatures):  # 遍历每个特征
        featureValues = [sample[f] for sample in samples]  # 提取特征值
        for fvalue in featureValues:  # 遍历当前特征的每个值
            lson, rson = split_samples(samples, f, fvalue)  # 根据特征及其值切分样本
            # 计算分裂后两个集合的基尼指数或信息熵
            if splitInfo == "gini":
                info = (gini_index(lson) * len(lson) + gini_index(rson) * len(rson)) / len(samples)
            else:
                info = (info_entropy(lson) * len(lson) + info_entropy(rson) * len(rson)) / len(samples)

            gain = curInfo - info  # 计算增益或基尼指数的减少量

            # 找到最佳特征及其切分值
            if gain > bestInfo and len(lson) > 0 and len(rson) > 0:
                bestInfo = gain
                bestF = f
                bestFvalue = fvalue
                bestlson = lson
                bestrson = rson

    # 如果找到了最佳切分
    if bestInfo > 0.0:
        l = build_biTree(bestlson, splitInfo)  # 递归构建左子树
        r = build_biTree(bestrson, splitInfo)  # 递归构建右子树
        return biTree_node(f=bestF, fvalue=bestFvalue, l=l, r=r, splitInfo=splitInfo)
    else:
        # 如果没有有效切分,则生成叶节点
        label_counts = sum_of_each_label(samples)
        return biTree_node(leafLabel=max(label_counts, key=label_counts.get), splitInfo=splitInfo)


def predict(sample, tree):
    '''
    对给定样本进行预测
    para sample: list, 需要预测的样本
    para tree: biTree_node, 构建好的分类树
    return: int, 预测样本所属的类别
    '''
    if tree.leafLabel is not None:  # 如果当前节点是叶节点
        return tree.leafLabel
    else:
        # 否则根据特征值选择子树
        sampleValue = sample[tree.f]
        branch = tree.r if sampleValue >= tree.fvalue else tree.l
        return predict(sample, branch)  # 递归下去


def print_tree(tree, level='0'):
    '''简单打印树的结构
    para tree: biTree_node, 树的根节点
    para level: str, 当前节点在树中的深度,0表示根,0L表示左子节点,0R表示右子节点
    '''
    if tree.leafLabel is not None:  # 如果是叶节点
        print('*' + level + '-' + str(tree.leafLabel))  # 打印标签
    else:
        print('+' + level + '-' + str(tree.f) + '-' + str(tree.fvalue))  # 打印特征索引及切分值
        print_tree(tree.l, level + 'L')  # 打印左子树
        print_tree(tree.r, level + 'R')  # 打印右子树


if __name__ == "__main__":

    # 示例数据集:某人相亲的数据
    blind_date = [[35, 176, 0, 20000, 0],
                  [28, 178, 1, 10000, 1],
                  [26, 172, 0, 25000, 0],
                  [29, 173, 2, 20000, 1],
                  [28, 174, 0, 15000, 1]]

    print("信息增益二叉树:")
    tree = build_biTree(blind_date, splitInfo="infogain")  # 构建信息增益的二叉树
    print_tree(tree)  # 打印树结构
    print('信息增益二叉树对样本进行预测的结果:')

    test_sample = [[24, 178, 2, 17000],
                   [27, 176, 0, 25000],
                   [27, 176, 0, 10000]]

    # 对测试样本进行预测
    for x in test_sample:
        print(predict(x, tree))

    print("基尼指数二叉树:")
    tree = build_biTree(blind_date, splitInfo="gini")  # 构建基尼指数的二叉树
    print_tree(tree)  # 打印树结构
    print('基尼指数二叉树对样本进行预测的结果:')

    # 再次对测试样本进行预测
    for x in test_sample:
        print(predict(x, tree))  # 预测并打印结果

输出结果:


标签:splitInfo,para,样本,tree,二叉,基尼,samples,节点,决策树
From: https://www.cnblogs.com/h4o3/p/18531757

相关文章

  • 华为OD机试真题-数组二叉树码-2024年OD统一考试(E卷)
    最新华为OD机试考点合集:华为OD机试2024年真题题库(E卷+D卷+C卷)_华为od机试题库-CSDN博客     每一题都含有详细的解题思路和代码注释,精编c++、JAVA、Python三种语言解法。帮助每一位考生轻松、高效刷题。订阅后永久可看,发现新题及时跟新。题目描述二叉树也可以用数组来......
  • 数据结构 ——— 链式二叉树oj题:相同的树
    目录题目要求手搓两个链式二叉树代码实现 题目要求给你两棵二叉树的根节点 p 和 q ,编写一个函数来检验这两棵树是否相同。如果两个树在结构上相同,并且节点具有相同的值,则认为它们是相同的。手搓两个链式二叉树代码演示://数据类型typedefintBTDataType;......
  • 数据结构 ——— 计算链式二叉树第k层的节点个数
    目录链式二叉树示意图手搓一个链式二叉树计算链式二叉树第k层的节点个数 链式二叉树示意图手搓一个链式二叉树代码演示://数据类型typedefintBTDataType;//二叉树节点的结构typedefstructBinaryTreeNode{BTDataTypedata;//每个节点的数据s......
  • 代码随想录算法训练营第十八天|leetcode530.二叉搜索树的最小绝对差、leetcode501.二
    1leetcode530.二叉搜索树的最小绝对差题目链接:530.二叉搜索树的最小绝对差-力扣(LeetCode)文章链接:代码随想录视频链接:你对二叉搜索树了解的还不够!|LeetCode:98.验证二叉搜索树_哔哩哔哩_bilibili思路:定义一个极大值作为结果,然后在中序遍历过程中进行比较出结果1.1自己的......
  • 红黑树:自平衡的二叉搜索树
    简介红黑树(Red-BlackTree)是一种自平衡的二叉搜索树,其中每个节点都有一个颜色属性,可以是红色或黑色。红黑树在计算机科学中被广泛用于各种应用,如关联数组、数据库和调度程序。它们提供了一种有效的方式来保持数据的有序性,同时在插入和删除操作中保持较低的时间复杂度。红黑树......
  • 数据结构树与二叉树
    语雀链接:https://www.yuque.com/g/wushi-ls7km/ga9rkw/qw8kwzxigbx61kxy/collaborator/join?token=2vdSjDBgJyJb0VSL&source=doc_collaborator#《树与二叉树》......
  • 实验4:二叉树的基本操作
    c++解释:new相当于malloc()函数,其他没有区别!点击查看代码#include<iostream>usingnamespacestd;structtree{ intdata; tree*light,*ture;};intjie,shen,maxx;//创建tree*chu(){ tree*head; head=newtree; cout<<"请输入数值:\n"; cin>&g......
  • 代码随想录算法训练营第十六天|leetcode513.找树左下角的值、leetcode112.路径总和、l
    1leetcode513.找树左下角的值题目链接:513.找树左下角的值-力扣(LeetCode)文章链接:代码随想录视频链接:怎么找二叉树的左下角?递归中又带回溯了,怎么办?|LeetCode:513.找二叉树左下角的值_哔哩哔哩_bilibili思路:就是用一个东西存储result,使用后续遍历,如果遇到了最深的那一个值,就......
  • 【算法】递归+深搜:106.从中序与后序遍历序列构造二叉树(medium)
    目录1、题目链接相似题目:2、题目3、解法函数头-----找出重复子问题函数体---解决子问题4、代码1、题目链接106.从中序与后序遍历序列构造二叉树(LeetCode)相似题目:105.从前序与中序遍历序列构造二叉树889.根据前序和后序遍历构造二叉树(LeetCode)2、题目3、解法......
  • 代码随想录算法训练营第十四天|leetcode226. 翻转二叉树、leetcode101.对称二叉树、le
    1leetcode226.翻转二叉树题目链接:226.翻转二叉树-力扣(LeetCode)文章链接:代码随想录视频链接:听说一位巨佬面Google被拒了,因为没写出翻转二叉树|LeetCode:226.翻转二叉树哔哩哔哩bilibili自己的思路:之前想过就是使用层序遍历的方法来做这一道题目,后来感觉有一些行不通,就......