
Machine Learning: Decision Trees


import math
import pickle

from matplotlib import pyplot as plt


def calc_shang(dataset: list):
    """
    Compute the Shannon entropy of the given dataset.
    :param dataset: list of samples; the last element of each sample is the class label
    :return: the entropy in bits
    """
    length = len(dataset)
    label_count_map = {}
    for item in dataset:
        current_label = item[-1]
        if current_label not in label_count_map:
            label_count_map[current_label] = 0
        label_count_map[current_label] += 1
    shang = 0.0
    for label, count in label_count_map.items():
        prob = count / length
        shang -= prob * math.log(prob, 2)
    return shang
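
As a quick worked example: a five-sample set with two "yes" and three "no" labels has entropy -(2/5)·log2(2/5) - (3/5)·log2(3/5) ≈ 0.971 bits. A minimal check against calc_shang, using the same rows as the create_dataset sample below:

sample = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(calc_shang(sample))  # ≈ 0.9710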


def create_dataset():
    dataset = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"]
    ]
    labels = ["no surfacing", "flippers"]
    return dataset, labels


def split_dataset(dataset, axis, value):
    new_dataset = []
    for item in dataset:
        if item[axis] == value:
            reduced_item = item[:axis]
            reduced_item.extend(item[axis + 1:])
            new_dataset.append(reduced_item)
    return new_dataset
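
For example, splitting the sample data on feature 0 with value 1 keeps the three matching rows and strips out the splitting column, so each sub-dataset is ready for the next round of splitting:

sample = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(split_dataset(sample, 0, 1))  # [[1, 'yes'], [1, 'yes'], [0, 'no']]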


def choose_best_feature(dataset):
    num = len(dataset[0]) - 1
    shang = calc_shang(dataset)
    best_info_gain = 0
    best_feature = -1
    for i in range(num):
        feat_list = [_[i] for _ in dataset]
        unique_list = set(feat_list)
        _shang = 0
        for feat in unique_list:
            sub_dataset = split_dataset(dataset, i, feat)
            prob = len(sub_dataset) / len(dataset)
            _shang += prob * calc_shang(sub_dataset)
        info_gain = shang - _shang
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = i
    return best_feature
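
On the sample data, feature 0 splits the labels into [yes, yes, no] (entropy ≈ 0.918) and [no, no] (entropy 0), giving a conditional entropy of (3/5)·0.918 ≈ 0.551 and an information gain of 0.971 - 0.551 ≈ 0.420. Feature 1 gains only 0.971 - 0.800 ≈ 0.171, so feature 0 is chosen:

sample = [[1, 1, "yes"], [1, 1, "yes"], [1, 0, "no"], [0, 1, "no"], [0, 1, "no"]]
print(choose_best_feature(sample))  # 0, i.e. the "no surfacing" feature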


def classify(class_list):
    """Return the majority class label; used when no features remain to split on."""
    class_count_map = {}
    for item in class_list:
        if item not in class_count_map:
            class_count_map[item] = 0
        class_count_map[item] += 1
    sorted_class_count_map = sorted(class_count_map.items(), key=lambda x: x[1], reverse=True)
    return sorted_class_count_map[0][0]


def create_tree(dataset, labels):
    class_list = [_[-1] for _ in dataset]
    # Stop if every sample has the same class.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]
    # Stop if no features are left to split on; fall back to the majority class.
    if len(dataset[0]) == 1:
        return classify(class_list)
    best_feature = choose_best_feature(dataset)
    best_class_label = labels[best_feature]
    tree = {best_class_label: {}}
    del labels[best_feature]  # note: this mutates the caller's labels list
    feat_values = [_[best_feature] for _ in dataset]
    unique_values = set(feat_values)
    for value in unique_values:
        sub_labels = labels[:]
        tree[best_class_label][value] = create_tree(split_dataset(dataset, best_feature, value), sub_labels)
    return tree
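
On the sample data this recursion produces the familiar nested-dict tree, with feature names as inner keys, feature values as edge keys, and class labels at the leaves. Note the copy of the labels list, since create_tree mutates the one it receives:

sample, feat_names = create_dataset()
print(create_tree(sample, feat_names[:]))
# {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}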


def plot_tree(tree, root_name):
    def _plot_tree(ax, tree, parent_name, parent_x, parent_y, dx, dy):
        # Draw the root node's box on the initial call only.
        if parent_name and parent_x == 0 and parent_y == 0:
            ax.text(0, 0, parent_name, ha='center', va='center', bbox=dict(facecolor='white', edgecolor='black'))
        if isinstance(tree, dict):
            # Walk each edge (feature value -> subtree or leaf) under this node
            for edge_label, child in tree.items():
                # Position the child: value 0 branches left, anything else right
                child_x = parent_x - dx / 2 if edge_label == 0 else parent_x + dx / 2
                child_y = parent_y - dy

                if isinstance(child, dict):
                    child_name = list(child.keys())[0]
                else:
                    child_name = child

                # Draw the edge and its value label at the midpoint
                ax.plot([parent_x, child_x], [parent_y, child_y], 'k-')
                mid_x = (parent_x + child_x) / 2
                mid_y = (parent_y + child_y) / 2
                ax.text(mid_x, mid_y, str(edge_label), ha='center', va='center', fontsize=8,
                        bbox=dict(facecolor='yellow', edgecolor='black'))

                # Draw the child node box
                ax.text(child_x, child_y, child_name, ha='center', va='center',
                        bbox=dict(facecolor='white', edgecolor='black'))

                # Recurse into the subtree
                if isinstance(child, dict):
                    _plot_tree(ax, child[child_name], child_name, child_x, child_y, dx / 2, dy)

    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1.5, 0.5)
    ax.axis('off')

    _plot_tree(ax, tree[root_name], root_name, 0, 0, 1, 0.5)

    plt.show()


def classify_tree(tree: dict, labels: list, test_vec):
    """Classify test_vec by walking the tree; labels maps feature names to indices in test_vec."""
    first_str = list(tree.keys())[0]
    second_dict = tree[first_str]
    feat_index = labels.index(first_str)
    class_label = ""
    for key, value in second_dict.items():
        if test_vec[feat_index] == key:
            if isinstance(value, dict):
                class_label = classify_tree(value, labels, test_vec)
            else:
                class_label = value
    return class_label
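
Classification walks the same nested dict. Keep an unmutated copy of the feature names around so labels.index can find the right column of the test vector:

sample, feat_names = create_dataset()
tree = create_tree(sample, feat_names[:])  # copy: create_tree mutates its labels argument
print(classify_tree(tree, feat_names, [1, 0]))  # 'no'
print(classify_tree(tree, feat_names, [1, 1]))  # 'yes'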


def store_tree(tree: dict, file_path: str):
    with open(file_path, "wb") as f:
        pickle.dump(tree, f)


def grab_tree(file_path):
    with open(file_path, "rb") as f:
        return pickle.load(f)
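
A quick round trip through pickle (the file name here is just an example):

sample, feat_names = create_dataset()
tree = create_tree(sample, feat_names[:])
store_tree(tree, "tree.pkl")  # hypothetical file name
assert grab_tree("tree.pkl") == tree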


if __name__ == '__main__':
    mat, labels = create_dataset()
    # Pass a copy: create_tree deletes entries from the labels list it is given.
    tree = create_tree(dataset=mat, labels=labels[:])
    plot_tree(tree, 'no surfacing')

Further decision tree examples, along with implementations built on mainstream machine learning frameworks, are available at:

https://gitee.com/navysummer/machine-learning/tree/master/decision_tree

  

From: https://www.cnblogs.com/navysummer/p/18239639

    Java中的线程中断是一个复杂但非常重要的概念,它允许一个线程告知另一个线程希望它停止正在做的事情。这个机制是协作式的,意味着被请求中断的线程需要自己检查中断状态,并且决定如何响应中断请求。下面我将系统地讲解Java中的线程中断知识点。 1.中断标志每个线程都有一个内部......