"""ID3 decision tree utilities.

Provides Shannon-entropy computation, best-split selection by information
gain, recursive tree construction, tree plotting, classification with a built
tree, and pickle-based persistence.

A tree is a nested dict of the form
``{feature_name: {feature_value: subtree_or_class_label, ...}}``.
"""

import math
import pickle
from collections import Counter

try:
    from matplotlib import pyplot as plt
except ImportError:  # plotting is optional; only plot_tree needs matplotlib
    plt = None


def calc_shang(dataset: list):
    """Compute the Shannon entropy (in bits) of the class labels of a dataset.

    :param dataset: sequence of rows; the last element of each row is the
        class label.
    :return: the entropy as a float; ``0.0`` for an empty dataset.
    """
    total = len(dataset)
    label_counts = Counter(item[-1] for item in dataset)
    entropy = 0.0
    for count in label_counts.values():
        prob = count / total
        entropy -= prob * math.log2(prob)
    return entropy


def create_dataset():
    """Return the small demo dataset and the names of its feature columns.

    :return: ``(dataset, labels)`` where each dataset row is
        ``[feature_0, feature_1, class_label]`` and ``labels`` names the
        feature columns in order.
    """
    dataset = [
        [1, 1, "yes"],
        [1, 1, "yes"],
        [1, 0, "no"],
        [0, 1, "no"],
        [0, 1, "no"],
    ]
    labels = ["no surfacing", "flippers"]
    return dataset, labels


def split_dataset(dataset, axis, value):
    """Select the rows whose column ``axis`` equals ``value``, dropping that column.

    :param dataset: sequence of rows (lists).
    :param axis: index of the feature column to filter on and remove.
    :param value: feature value a row must have to be kept.
    :return: a new list of new rows; the input rows are not modified.
    """
    return [
        item[:axis] + item[axis + 1:]
        for item in dataset
        if item[axis] == value
    ]


def choose_best_feature(dataset):
    """Return the index of the feature with the largest information gain.

    :param dataset: non-empty sequence of rows; the last element of each row
        is the class label, all preceding elements are features.
    :return: the column index of the best feature, or ``-1`` when no feature
        yields a strictly positive information gain (including when the rows
        contain no feature columns at all).
    :raises IndexError: if ``dataset`` is empty.
    """
    feature_count = len(dataset[0]) - 1
    total = len(dataset)
    base_entropy = calc_shang(dataset)
    best_info_gain = 0.0
    best_feature = -1
    for index in range(feature_count):
        # Entropy remaining after splitting on this feature, weighted by the
        # share of rows that fall into each branch.
        conditional_entropy = 0.0
        for value in {item[index] for item in dataset}:
            subset = split_dataset(dataset, index, value)
            conditional_entropy += len(subset) / total * calc_shang(subset)
        info_gain = base_entropy - conditional_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_feature = index
    return best_feature


def classify(class_list):
    """Return the most frequent class label in ``class_list`` (majority vote).

    Ties are resolved in favour of the label that appears first.

    :param class_list: non-empty sequence of hashable class labels.
    :return: the majority class label.
    :raises IndexError: if ``class_list`` is empty.
    """
    return Counter(class_list).most_common(1)[0][0]


def create_tree(dataset, labels):
    """Recursively build an ID3 decision tree.

    :param dataset: non-empty sequence of rows; the last element of each row
        is the class label.
    :param labels: names of the feature columns, in column order. This
        sequence is NOT modified, so it can be reused afterwards (for example
        with ``classify_tree``).
    :return: a class label when the node is a leaf, otherwise a nested dict
        ``{feature_name: {feature_value: subtree_or_class_label}}``.
    :raises IndexError: if ``dataset`` is empty.
    """
    class_list = [item[-1] for item in dataset]
    # Pure node: every row has the same class.
    if class_list.count(class_list[0]) == len(class_list):
        return class_list[0]

    best_feature = choose_best_feature(dataset)
    if best_feature == -1:
        # No features left, or none of them separates the classes: stop here
        # with a majority vote instead of indexing labels[-1] and splitting
        # on the class column.
        return classify(class_list)

    best_class_label = labels[best_feature]
    # Build a new label sequence rather than deleting from the caller's one.
    sub_labels = labels[:best_feature] + labels[best_feature + 1:]
    branches = {}
    for value in {item[best_feature] for item in dataset}:
        subset = split_dataset(dataset, best_feature, value)
        branches[value] = create_tree(subset, sub_labels)
    return {best_class_label: branches}


def plot_tree(tree, root_name):
    """Draw a decision tree with matplotlib and show the figure.

    :param tree: nested dict as produced by ``create_tree``.
    :param root_name: key of the root feature in ``tree``.
    :raises ImportError: if matplotlib is not installed.
    """
    if plt is None:
        raise ImportError("matplotlib is required to plot the tree")

    node_box = dict(facecolor='white', edgecolor='black')
    edge_box = dict(facecolor='yellow', edgecolor='black')

    def _ordered_branches(branches):
        # Sort by branch value so that e.g. 0 is always drawn left of 1;
        # fall back to insertion order when the values are not comparable.
        try:
            return sorted(branches.items(), key=lambda pair: pair[0])
        except TypeError:
            return list(branches.items())

    def _levels_below(branches):
        # Number of node levels underneath a branch dict.
        if not isinstance(branches, dict) or not branches:
            return 0
        deepest = 0
        for child in branches.values():
            if isinstance(child, dict) and child:
                deepest = max(deepest, _levels_below(next(iter(child.values()))))
        return 1 + deepest

    def _plot_branches(ax, branches, parent_x, parent_y, dx, dy):
        if not isinstance(branches, dict):
            return
        ordered = _ordered_branches(branches)
        count = len(ordered)
        for position, (edge_label, child) in enumerate(ordered):
            # Spread the children evenly over a span of width dx centred on
            # the parent; a single child sits directly below it.
            if count > 1:
                child_x = parent_x - dx / 2 + dx * position / (count - 1)
            else:
                child_x = parent_x
            child_y = parent_y - dy
            if isinstance(child, dict):
                child_name = next(iter(child))
            else:
                child_name = child
            # Draw the edge and its label (the feature value).
            ax.plot([parent_x, child_x], [parent_y, child_y], 'k-')
            mid_x = (parent_x + child_x) / 2
            mid_y = (parent_y + child_y) / 2
            ax.text(mid_x, mid_y, str(edge_label), ha='center', va='center',
                    fontsize=8, bbox=edge_box)
            # Draw the child node.
            ax.text(child_x, child_y, child_name, ha='center', va='center',
                    bbox=node_box)
            # Recurse into the subtree, halving the horizontal span.
            if isinstance(child, dict):
                _plot_branches(ax, child[child_name], child_x, child_y,
                               dx / 2, dy)

    level_dy = 0.5
    branches = tree[root_name]
    _, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(-1, 1)
    # Keep the original viewport for shallow trees, extend it for deep ones.
    ax.set_ylim(min(-1.5, -level_dy * _levels_below(branches) - 0.5), 0.5)
    ax.axis('off')
    if root_name:
        ax.text(0, 0, root_name, ha='center', va='center', bbox=node_box)
    _plot_branches(ax, branches, 0, 0, 1, level_dy)
    plt.show()


def classify_tree(tree: dict, labels: list, test_vec):
    """Classify a feature vector with a decision tree.

    :param tree: nested dict as produced by ``create_tree``.
    :param labels: names of ALL feature columns, in the column order used by
        ``test_vec``.
    :param test_vec: feature values of the sample to classify.
    :return: the predicted class label, or ``""`` when the sample has a
        feature value for which the tree has no branch.
    :raises ValueError: if a feature named in the tree is not in ``labels``.
    """
    first_str = next(iter(tree))
    second_dict = tree[first_str]
    feat_index = labels.index(first_str)
    for key, value in second_dict.items():
        if test_vec[feat_index] == key:
            if isinstance(value, dict):
                return classify_tree(value, labels, test_vec)
            return value
    return ""


def store_tree(tree: dict, file_path: str):
    """Serialize ``tree`` to ``file_path`` with pickle."""
    with open(file_path, "wb") as f:
        pickle.dump(tree, f)


def grab_tree(file_path):
    """Load a tree previously written by ``store_tree``.

    NOTE: unpickling can execute arbitrary code; only load files from a
    trusted source.
    """
    with open(file_path, "rb") as f:
        return pickle.load(f)


if __name__ == '__main__':
    mat, labels = create_dataset()
    tree = create_tree(dataset=mat, labels=labels)
    plot_tree(tree, 'no surfacing')
# 其他决策树示例或者基于主流机器学习框架实现的决策树代码地址:
# https://gitee.com/navysummer/machine-learning/tree/master/decision_tree
# 标签:机器,tree,dataset,学习,child,label,class,best,决策树
# From: https://www.cnblogs.com/navysummer/p/18239639