首页 > 编程语言 >【算法】决策树算法:ID3

【算法】决策树算法:ID3

时间:2023-12-20 18:45:39浏览次数:25  
标签:info ID3 dataset 算法 gain label feat best 决策树

import math
from collections import Counter

# 创建数据集
def create_dataset():
    dataset = [
        # 年龄, 工作, 房子,信用,标签
        ['青年', 0, 0, '一般', '0'],
        ['青年', 0, 0, '好', '0'],
        ['青年', 1, 0, '好', '1'],
        ['青年', 1, 1, '一般', '1'],
        ['青年', 0, 0, '一般', '0'],
        ['中年', 0, 0, '一般', '0'],
        ['中年', 0, 0, '好', '0'],
        ['中年', 1, 1, '好', '1'],
        ['中年', 0, 1, '很好', '1'],
        ['中年', 0, 1, '很好', '1'],
        ['老年', 0, 1, '很好', '1'],
        ['老年', 0, 1, '好', '1'],
        ['老年', 1, 0, '好', '1'],
        ['老年', 1, 0, '很好', '1'],
        ['老年', 0, 0, '一般', '0']
    ]
    return dataset

# 计算熵
def cal_entropy(dataset):
    label_count = {}
    # 统计样本标签
    for item in dataset:
        # 样本标签
        label = item[-1]
        # 不在字典中
        if label not in label_count:
            label_count[label] = 0
        # 计数+1
        label_count[label] += 1
    # 计算熵
    entropy = 0.0
    for label in label_count:
        # 概率 = 样本数 / 样本总数
        p = label_count[label] / len(dataset)
        # 计算熵
        if p == 0:
            continue
        entropy -= p * math.log(p, 2)
    return entropy

# 计算条件熵
def cal_cond_entropy(dataset, feature, value):
    ret_dataset = []
    for item in dataset:
        if item[feature] == value:
            # 抽取当前特征左侧的数据
            except_item = item[:feature]
            # 抽取当前特征右侧的数据
            except_item.extend(item[feature + 1:])
            ret_dataset.append(except_item)
    return ret_dataset

# 计算信息增益
def cal_info_gain(dataset):
    # 样本数
    num_feature = len(dataset[0]) - 1
    # 计算基本熵
    base_entropy = cal_entropy(dataset)
    # 最优的信息增益
    best_info_gain = 0.0
    # 最优的信息增益的索引
    best_info_gain_feature = 0
    for i in range(num_feature):
        feature_list = [example[i] for example in dataset]
        feature_set = set(feature_list)
        conditional_entropy = 0.0

        for value in feature_set:
            # 计算条件熵
            sub_dataset = cal_cond_entropy(dataset, i, value)
            p = float(len(sub_dataset)) / len(dataset)
            conditional_entropy += p * cal_entropy(sub_dataset)

        info_gain = base_entropy - conditional_entropy
        # 选取最大的信息索引
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_info_gain_feature = i
    return best_info_gain_feature, best_info_gain

# 多数表决法决定叶子节点分类
def majority_cnt(class_list):
    class_count = Counter(class_list)
    sorted_class_count = sorted(class_count.items(), key=lambda x: x[1], reverse=True)
    return sorted_class_count[0][0]

# 构建决策树
def build_decision_tree(dataset, labels):
    class_list = [data[-1] for data in dataset]
    if class_list.count(class_list[0]) == len(class_list):  # 类别完全相同则停止继续划分
        return class_list[0]
    if len(dataset[0]) == 1:  # 遍历完所有特征时返回出现次数最多的类别
        return majority_cnt(class_list)
    best_feat, best_info_gain = cal_info_gain(dataset)
    best_feat_label = labels[best_feat]
    my_tree = {best_feat_label: {}}
    new_labels = labels[:]
    del(new_labels[best_feat])
    feat_values = [data[best_feat] for data in dataset]
    unique_vals = set(feat_values)
    for value in unique_vals:
        sub_labels = new_labels[:]
        my_tree[best_feat_label][value] = build_decision_tree(cal_cond_entropy(dataset, best_feat, value), sub_labels)
    return my_tree

# 使用决策树进行分类
def classify(input_tree, feat_labels, test_data):
    first_str = list(input_tree.keys())[0]
    second_dict = input_tree[first_str]
    feat_index = feat_labels.index(first_str)
    key = test_data[feat_index]
    value_of_feat = second_dict[key]
    if isinstance(value_of_feat, dict):
        class_label = classify(value_of_feat, feat_labels, test_data)
    else:
        class_label = value_of_feat
    return class_label

# ID3 算法举例
if __name__ == '__main__':
    dataset = create_dataset()
    labels = ['年龄', '工作', '房子', '信用']
    print("熵:", cal_entropy(dataset))
    best_info_gain_feature, best_info_gain = cal_info_gain(dataset)
    print("信息增益:", best_info_gain_feature, best_info_gain)

    tree = build_decision_tree(dataset, labels)
    print("决策树:", tree)
    print("测试数据:", dataset[0])
    result = classify(tree, labels, ['老年', 1, 0, '一般'])
    print("预测结果:", result)

运行效果:

标签:info,ID3,dataset,算法,gain,label,feat,best,决策树
From: https://www.cnblogs.com/yangyxd/p/17917226.html

相关文章

  • 安防升级!羚通视频智能分析平台助力安全帽、反光衣算法识别,让安全无处不在!
    在现代社会中,公共安全和个体防护已经成为了我们日常生活的重要组成部分。特别是在一些高风险的工作环境中,如建筑工地、交通警察等,安全帽和反光衣的使用是保障工作人员安全的重要手段。然而,传统的人工监控方式往往无法做到实时、准确的监控和识别,这就为羚通视频智能分析平台的出现......
  • 【算法】K-means 算法学习
    fromnumpyimport*importpandasaspdimportmatplotlib.pyplotasplt#计算两点之间的欧式距离defdist(a,b):returnsqrt(sum((a-b)**2))#生成聚类中心defcreate_center(data,k,defaultPts=[0,3,6]):#固定的几个点作为聚类中心ifdefaultP......
  • 羚通视频智能分析平台视频监控算法分析玩手机打电话检测
    在当今数字化时代,视频监控技术已经广泛应用于我们生活的各个领域。然而,传统的视频监控方式往往需要大量的人力进行监控和分析,这不仅效率低下,而且容易出错。为了解决这个问题,羚通公司推出了一款全新的视频智能分析平台,该平台利用先进的视频监控算法,可以实时检测并分析手机打电话的......
  • 《算法、C++、Linux、Android》
    ......
  • 羚通视频智能分析平台视频监控算法分析玩手机打电话检测
    在当今数字化时代,视频监控技术已经广泛应用于我们生活的各个领域。然而,传统的视频监控方式往往需要大量的人力进行监控和分析,这不仅效率低下,而且容易出错。为了解决这个问题,羚通公司推出了一款全新的视频智能分析平台,该平台利用先进的视频监控算法,可以实时检测并分析手机打电话的行......
  • python 数据结构与算法知识图
    1.算法思想:递归、分治(归并排序、二分查找、快速排序)、贪心(贪心策略排序+当前最优)、动态规划(最优子结构+递推式)、回溯(解空间:排列树+子集树、深度搜索+剪枝)、分支限界(解空间:排列树+子集树、广度搜索+剪枝))2.排序算法:(low:冒泡、插入、选择;mid:快排、归并、堆排(完全二叉树),其他:桶排序、基......
  • 2023最新初级难度算法面试题,包含答案。刷题必备!记录一下。
    好记性不如烂笔头内容来自面试宝典-初级难度算法面试题合集问:什么是排序?说出常见的排序算法有哪几种?排序是计算机科学中的一种基本操作,它将一组数据按照某种顺序进行排列。排序算法是实现排序过程的具体方法。常见的排序算法有多种,它们可以根据不同的数据结构、时间复杂......
  • 详解十大经典排序算法(五):归并排序(Merge Sort)
    算法原理归并排序的核心思想是将一个大的数组分割成多个小的子数组,然后分别对这些子数组进行排序,最后将排序后的子数组合并起来,得到一个有序的大数组。算法描述归并排序(MergeSort)是一种经典的排序算法,其原理基于分治(DivideandConquer)策略。它的核心思想是将一个大问题分割成多个......
  • 详解十大经典排序算法(四):希尔排序(Shell Sort)
    算法原理希尔排序是一种基于插入排序的排序算法,也被称为缩小增量排序。它通过将待排序的序列分割成若干个子序列,对每个子序列进行插入排序,然后逐步缩小增量,最终使整个序列有序。算法描述希尔排序(ShellSort)是一种基于插入排序的算法,由DonaldShell于1959年提出。它是插入排序的一种......
  • 单调栈求解算法
    例题:503. 下一个更大元素II给定一个循环数组 nums ( nums[nums.length-1] 的下一个元素是 nums[0] ),返回 nums 中每个元素的 下一个更大元素 。数字 x 的 下一个更大的元素 是按数组遍历顺序,这个数字之后的第一个比它更大的数,这意味着你应该循环地搜索它的下一......