首页 > 其他分享 >决策树

决策树

时间:2023-05-10 15:34:27浏览次数:33  
标签:myTree no plotTree dataset yes 节点 决策树

决策树

基础概念

计算

image
image

实战代码

import matplotlib.pyplot as plt
from math import log
import operator
from matplotlib import font_manager
font = font_manager.FontProperties(fname=r"c:\windows\fonts\SimHei.ttf")
def createDataSet():
    # dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
    #            [1, 0, 1, 0, 0, 0, 'yes'],
    #            [1, 0, 0, 0, 0, 0, 'yes'],
    #            [0, 0, 1, 0, 0, 0, 'yes'],
    #            [2, 0, 0, 0, 0, 0, 'yes'],
    #            [0, 1, 0, 0, 1, 1, 'yes'],
    #            [1, 1, 0, 1, 1, 1, 'yes'],
    #            [1, 1, 0, 0, 1, 0, 'yes'],
    #            [1, 1, 1, 1, 1, 0, 'no'],
    #            [0, 2, 2, 0, 2, 1, 'no'],
    #            [2, 2, 2, 2, 2, 0, 'no'],
    #            [2, 0, 0, 2, 2, 1, 'no'],
    #            [0, 1, 0, 1, 0, 0, 'no'],
    #            [2, 1, 1, 1, 0, 0, 'no'],
    #            [1, 1, 0, 0, 1, 1, 'no'],
    #            [2, 0, 0, 2, 2, 0, 'no'],
    #            [0, 0, 1, 1, 1, 0, 'no']]
    dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
               ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
               ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
               ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'yes'],
               ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 'yes'],
               ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 'yes'],
               ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 'no'],
               ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 'no'],
               ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 'no'],
               ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 'no'],
               ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 'no'],
               ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 'no'],
               ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'no'],
               ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 'no'],
               ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 'no']]
    labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
    return dataSet, labels
def createTree(dataset,labels,featLabels):
    """创建根结点并实现其基本流程"""
    classList = [example[-1] for example in dataset] 
    # 设置递归停止条件
    # 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
    if (classList.count(classList[0])) == len(classList):
        return classList[0] 
    # 如果只剩下一列特征不能继续再划分时,停止递归
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    # 计算信息熵,取出最优属性,得出其索引值
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat] # 最优根结点
    featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
    myTree = {bestFeatLabel:{}}
    # 删除该属性
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    # 取出唯一值
    uniqueVals = set(featValue)
    # 利用其特征进行分叉
    for value in uniqueVals:
        sub_lables = labels[:]
        # 递归分叉
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
    return myTree
def createTree(dataset,labels,featLabels):
    """创建根结点并实现其基本流程"""
    classList = [example[-1] for example in dataset] 
    # 设置递归停止条件
    # 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
    if (classList.count(classList[0])) == len(classList):
        return classList[0] 
    # 如果只剩下一列特征不能继续再划分时,停止递归
    if len(dataset[0]) == 1:
        return majorityCnt(classList)
    # 计算信息熵,取出最优属性,得出其索引值
    bestFeat = chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat] # 最优根结点
    featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
    myTree = {bestFeatLabel:{}}
    # 删除该属性
    del labels[bestFeat]
    featValue = [example[bestFeat] for example in dataset]
    # 取出唯一值
    uniqueVals = set(featValue)
    # 利用其特征进行分叉
    for value in uniqueVals:
        sub_lables = labels[:]
        # 递归分叉
        myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
    return myTree
def splitDataSet(dataset, axis, val):
    retDataSet = []
    for featVec in dataset:
        if featVec[axis] == val:
            reduceFeatVec = featVec[:axis]
            reduceFeatVec.extend(featVec[axis + 1:])
            retDataSet.append(reduceFeatVec)
    return retDataSet
# 用于切分数据集。函数的输入参数包括:数据集(dataset)、切分数据集的特征(axis)和需要返回的特征值(val)。
# 函数的输出是一个列表,其中包含了数据集中所有特征值为val的数据行,并且这些行已经去掉了特征值为axis的那一列。
# 函数通过遍历数据集中的每一行,判断该行的axis列是否等于val。
# 如果是,就把该行的axis列去掉,并将其它列组成一个新的列表redceFeatVec,然后将这个列表添加到retDataSet中。
# 最后,函数返回retDataSet。
def calcShannonEnt(dataset):
    """计算熵值"""
    numexamples = len(dataset) # 总体数据
    labelCounts = {}
    # 取出yes和no便于后面计算概率
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    shannonEnt = 0
    # 计算熵值
    for key in labelCounts:
        prop = float(labelCounts[key])/numexamples
        shannonEnt -= prop * log(prop,2)
    return shannonEnt
def majorityCnt(classList):
    """计算多数类别是哪一个"""
    classCount = {}
    # 这个for循环是将classList中的yes和no统计出来保存在字典中
    for vote in classList:
        # 如果不在这个字典里面就将他的key设置为零
        if vote not in classCount.keys():classCount[vote] = 0
        # 在的话就+=1
        classCount[vote] +=1
    sortedclassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
    # print(sortedclassCount)
    return sortedclassCount[0][0]

def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs
def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = next(iter(myTree))
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = 1 + getTreeDepth(secondDict[key])
        else:
            thisDepth = 1
        if thisDepth > maxDepth: maxDepth = thisDepth
    return maxDepth
# 这段代码是一个递归函数,计算输入的决策树中叶节点的数量。
# 1、函数获取决策树的根节点,并进入其子树。
# 2、遍历该子树的每个分支,如果当前节点是一个字典类型,则递归调用该函数,继续遍历其子节点。
# 如果当前节点是叶节点,即不再包含子节点,则将叶节点的数量加1。
# 3、函数返回整个决策树中叶节点的总数。
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    arrow_args = dict(arrowstyle="<-")
    # font = FontProperties(fname=r"C:\Windows\Fonts\Corbel.ttf", size=14)
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
                            xytext=centerPt, textcoords='axes fraction',
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args,fontproperties=font)

# 是用来计算一个嵌套字典的深度(也可以理解为树的深度)。
# 其中,输入参数`myTree`是一个嵌套字典,表示一个树形结构,每个节点都是一个字典。函数返回值是这个树的深度。
# 1. 初始化变量`maxDepth`为0,表示当前树的深度为0。
# 2. 从字典`myTree`中获取第一个键值对,即根节点。将根节点的值(一个字典)赋值给变量`secondDict`。
# 3. 遍历`secondDict`中的每个键,判断对应的值是否为字典。如果是字典,则递归调用`getTreeDepth`函数,计算以该节点为根的子树的深度。
# 否则,该节点为叶子节点,深度为1。
# 4. 将当前节点的深度`thisDepth`与当前最大深度`maxDepth`比较,更新`maxDepth`为较大值。
# 5. 返回最大深度`maxDepth`。

# def plotMidText(cntrPt, parentPt, txtString):
#     xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
#     yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
#
#     createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotMidText(centerPt, parentPt, txtString):
    xMid = (parentPt[0] - centerPt[0]) / 2.0 + centerPt[0]
    yMid = (parentPt[1] - centerPt[1]) / 2.0 + centerPt[1]
    createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties=font)
# 该函数用于在父节点和子节点之间画出标签,显示它们之间的关系。具体解释如下:
# - centerPt:当前节点的坐标
# - parentPt:父节点的坐标
# - txtString:标签内容
# - xMid:标签的x坐标,计算方式为父节点和子节点的x坐标的平均值
# - yMid:标签的y坐标,计算方式为父节点和子节点的y坐标的平均值
# - va:标签的垂直对齐方式,"center"表示居中对齐
# - ha:标签的水平对齐方式,"center"表示居中对齐
# - rotation:标签的旋转角度,30表示旋转30度
# - fontproperties:字体属性,用于设置标签的字体大小、颜色等。
    
def plotTree(myTree, parentPt, nodeTxt):
    decisionNode = dict(boxstyle="sawtooth", fc="0.8")
    leafNode = dict(boxstyle="round4", fc="0.8")
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = next(iter(myTree))
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
    
# 代码用来绘制决策树的。
# 首先,定义了两个字典decisionNode和leafNode,分别表示决策节点和叶子节点的样式。
# 然后,获取决策树的叶子节点数量和深度,以及根节点的属性名firstStr。
# 接着,计算当前节点的中心位置cntrPt,并调用plotMidText和plotNode函数绘制节点的文本和样式。
# 然后,获取该节点的子节点secondDict,遍历其所有键值对,如果值是一个字典,则说明该节点不是叶子节点,需要递归调用plotTree函数来绘制其子节点。
# 如果值不是一个字典,则说明该节点是叶子节点,需要调用plotNode和plotMidText函数绘制该节点的样式和文本。
# 最后,更新plotTree的yOff值,以便绘制下一个节点。
def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')  # 创建fig
    fig.clf()  # 清空fig
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # 去掉x、y轴
    plotTree.totalW = float(getNumLeafs(inTree))  # 获取决策树叶结点数目
    plotTree.totalD = float(getTreeDepth(inTree))  # 获取决策树层数
    plotTree.xOff = -0.5 / plotTree.totalW
    plotTree.yOff = 1.0  # x偏移
    plotTree(inTree, (0.5, 1.0), '')  # 绘制决策树
    plt.show()

# 创建并显示决策树的可视化图形的。
# 首先,创建一个白色背景的fig对象,并清空该对象。
# 然后,定义一个字典axprops,用于去掉x、y轴的刻度。
# 接着,创建一个子图ax1,并将其frameon属性设置为False,以便去掉边框。
# 接下来,获取决策树的叶子节点数量和深度,并初始化plotTree的xOff和yOff值。
# 最后,调用plotTree函数绘制决策树,并显示图形。
if __name__ == '__main__':
    dataset, labels = createDataSet()
    featLabels = []
    myTree = createTree(dataset, labels, featLabels)
    createPlot(myTree)

# 这段代码是用来测试决策树算法的。
# 首先,调用createDataSet函数生成一个简单的数据集和标签。
# 然后,定义一个空列表featLabels,用于存储决策树的属性标签。
# 接着,调用createTree函数生成决策树,并将属性标签存储在featLabels中。
# 最后,调用createPlot函数绘制决策树的可视化图形。

png


标签:myTree,no,plotTree,dataset,yes,节点,决策树
From: https://www.cnblogs.com/fyuan0206/p/17388123.html

相关文章

  • R语言决策树、随机森林、逻辑回归临床决策分析NIPPV疗效和交叉验证
    全文链接:http://tecdat.cn/?p=32295原文出处:拓端数据部落公众号临床决策(clinical decision making)是医务人员在临床实践过程中,根据国内外医学科研的最新进展,不断提出新方案,与传统方案进行比较后,取其最优者付诸实施,从而提高疾病诊治水平的过程。在临床医疗实践中,许多事件......
  • 机器学习算法 随机森林学习 之决策树
    随机森林是基于集体智慧的一个机器学习算法,也是目前最好的机器学习算法之一。随机森林实际是一堆决策树的组合(正如其名,树多了就是森林了)。在用于分类一个新变量时,相关的检测数据提交给构建好的每个分类树。每个树给出一个分类结果,最终选择被最多的分类树支持的分类结果。回归则是不......
  • 决策树算法总结
    决策树(DecisionTree)决策树是一种树形结构,以信息熵为度量构造一棵熵值下降最快的树,它每个内部节点表示在某个特征上的分割使得分割前后熵值下降最快,到叶子结点处的熵值为零,此时每个叶结点中的样本都被归为同一类(训练时叶结点中数据的真实类别未必为同一类)。决策树算法递归的选择......
  • m基于ID3决策树算法的能量管理系统matlab仿真
    1.算法描述       ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成的决策树能完美分类训练样例。    ......
  • PYTHON银行机器学习:回归、随机森林、KNN近邻、决策树、高斯朴素贝叶斯、支持向量机SV
    全文下载链接:http://tecdat.cn/?p=26219最近我们被客户要求撰写关于银行机器学习的研究报告,包括一些图形和统计输出。该数据与银行机构的直接营销活动相关,营销活动基于电话。通常,需要与同一客户的多个联系人联系,以便访问产品(银行定期存款)是否会(“是”)或不会(“否”)订阅银行数据集我......
  • 基于决策树算法的病例类型诊断matlab仿真
    1.算法仿真效果matlab2022a仿真结果如下:2.算法涉及理论知识概要ID3算法是一种贪心算法,用来构造决策树。ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性的标准,即在每个节点选取还尚未被用来划分的具有最高信息增益的属性作为划分标准,然后继续这个过程,直到生成......
  • 数据分享|R语言决策树和随机森林分类电信公司用户流失churn数据和参数调优、ROC曲线可
    原文链接:http://tecdat.cn/?p=26868最近我们被客户要求撰写关于电信公司用户流失的研究报告,包括一些图形和统计输出。在本教程中,我们将学习覆盖决策树和随机森林。这些是可用于分类或回归的监督学习算法下面的代码将加载本教程所需的包和数据集。library(tidyverse)# 电信......
  • Chapter3 绘制决策树
    绘制决策树1.概述我们在上个博客已经学会使用代码来构造决策树了。但是,为了让构造出来的决策树具有可读性,我们还需要绘制决策树。2.设定样式#该代码的作用是设定节点和箭头的样式#该代码位于treePlotter.py文件中importmatplotlib.pyplotasplt'''在mat......
  • 数据分享|R语言用RFM、决策树模型顾客购书行为的数据预测|附代码数据
    全文链接:http://tecdat.cn/?p=30330最近我们被客户要求撰写关于RFM、决策树模型的研究报告,包括一些图形和统计输出。团队需要分析一个来自在线零售商的数据该数据包含了78周的购买历史。该数据文件中的每条记录包括四个字段。客户的ID(从1到2357不等),交易日期,购买的书籍数量,以及......
  • 决策树可视化Graphviz中文乱码
    输出svg时中文显示正常!!!fromsiximportStringIO#可视化dot_data=StringIO()tree.export_graphviz(clf,out_file=dot_data,feature_names=feature_name,class_names=target_name,filled=True,rounded=True,special_characte......