首页 > 其他分享 >决策树

决策树

时间:2022-11-17 13:22:07浏览次数:27  
标签:结点 myTree plotTree dataSet bestFeat 决策树

一.什么是决策树
决策树算法是一种逼近离散函数值的方法。它是一种典型的分类方法,首先对数据进行处理,利用归纳算法生成可读的规则和决策树,然后使用决策对新数据进行分析。本质上决策树是通过一系列规则对数据进行分类的过程。

决策树算法构造决策树来发现数据中蕴涵的分类规则.如何构造精度高、规模小的决策树是决策树算法的核心内容。决策树构造可以分两步进行。第一步,决策树的生成:由训练样本集生成决策树的过程。一般情况下,训练样本数据集是根据实际需要有历史的、有一定综合程度的,用于数据分析处理的数据集。第二步,决策树的剪枝:决策树的剪枝是对上一阶段生成的决策树进行检验、校正和修下的过程,主要是用新的样本数据集(称为测试数据集)中的数据校验决策树生成过程中产生的初步规则,将那些影响预衡准确性的分枝剪除。

如ppt上样例,就是一个决策树
二.决策树的基本流程
(1)当前结点包含的样本全部属于同一类别C:
(2)当前属性集为空,或所有样本在所有属性上取值相同:
(3)当前结点包含的样本集合为空:
无需划分,叶子节点标记为类别C
当前结点标记为叶子节点,类别=该结点所含样本最
多的类别
当前结点标记为叶子节点,类别=该结点的父节点所
含样本最多的类别
三.创建决策树的算法
①递归操作:
a.选择属性 : 递归由上到下决定每一个节点的属性 , 依次递归构造决策树 ;
b.数据集划分 : 开始决策时 , 所有的数据都在树根 , 由树根属性来划分数据集 ;
c.属性离散化 : 如果属性的值是连续值 , 需要将连续属性值离散化 ; 如 : 100 分满分 , 将 60 分以下分为不及格数据 , 60 分以上分为及格数据 ;
②递归终止条件:
a.子树分类完成 : 节点上的子数据集都属于同一个类别 , 该节点就不再向下划分 , 称为叶子节点 ;
b.属性 ( 节点 ) 全部分配完毕 : 所有的属性都已经分配完毕 , 决策树的高度等于属性个数 ;
c.所有样本分类完毕 : 所有的样本数据集都分类完成 ;
四.构建决策树
1.输入数据集,先构建决策树所包含的数据

2.构建特征集,计算信息熵以获得所有数据集的判定输出

-- coding: UTF-8 --

import operator
from math import log

def createDataSet():
"""
创建数据集

:return: 数据集与特征集
"""
dataSet = [[706, 'hot', 'sunny', 'high', 'false', 'no'],
           [707, 'hot', 'sunny', 'high', 'true', 'no'],
           [708, 'hot', 'overcast', 'high', 'false', 'yes'],
           [709, 'cool', 'overcast', 'normal', 'false', 'yes'],
           [710, 'cool', 'overcast', 'normal', 'true', 'yes'],
           [713, 'mild', 'sunny', 'high', 'false', 'no'],
           [714, 'cool', 'sunny', 'normal', 'false', 'yes'],
           [715, 'mild', 'overcast', 'normal', 'false', 'yes'],
           [720, 'mild', 'sunny', 'normal', 'true', 'yes'],
           [722, 'mild', 'overcast', 'high', 'true', 'yes'],
           [721, 'hot', 'overcast', 'normal', 'false', 'yes'],
           [723, 'mild', 'sunny', 'high', 'true', 'no'],
           [727, 'cool', 'sunny', 'normal', 'true', 'no'],
           [730, 'mild', 'sunny', 'high', 'false', 'yes']]
labels = ['日期', '气候', '天气', '气温', '寒冷']
return dataSet, labels

def classCount(dataSet):
"""
获取每个特征出现的次数

:param dataSet: 数据集
:return:
"""

labelCount = {}
for one in dataSet:
    if one[-1] not in labelCount.keys():
        labelCount[one[-1]] = 0
    labelCount[one[-1]] += 1
return labelCount

def calcShannonEntropy(dataSet):
"""
计算系统信息熵

:param dataSet: 数据集
:return:
"""

labelCount = classCount(dataSet)
numEntries = len(dataSet)
Entropy = 0.0
for i in labelCount:
    prob = float(labelCount[i]) / numEntries
    Entropy -= prob * log(prob, 2)
return Entropy

def majorityClass(dataSet):
"""
找到对应结果最多的特征

:param dataSet: 数据集
:return:
"""
labelCount = classCount(dataSet)
sortedLabelCount = sorted(labelCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedLabelCount[0][0]

def splitDataSet(dataSet, i, value):
"""
非数值型特征划分
将 dataset 以第 i 个特征值为 value 作为基准划分为多个部分

:param dataSet: 数据集
:param i: 特征索引
:param value: 划分基准值
:return:
"""

subDataSet = []
for one in dataSet:
    if one[i] == value:
        reduceData = one[:i]
        reduceData.extend(one[i + 1:])
        subDataSet.append(reduceData)
return subDataSet

def splitContinuousDataSet(dataSet, i, value, direction):
"""
数值型特征划分
将 dataset 以第 i 个特征值为 value 作为基准划分为多个部分

:param dataSet: 数据集
:param i: 特征索引
:param value: 划分基准值
:param direction: 0. 左侧, 1. 右侧
:return:
"""

subDataSet = []
for one in dataSet:
    if direction == 0:
        if one[i] > value:
            reduceData = one[:i]
            reduceData.extend(one[i + 1:])
            subDataSet.append(reduceData)
    if direction == 1:
        if one[i] <= value:
            reduceData = one[:i]
            reduceData.extend(one[i + 1:])
            subDataSet.append(reduceData)
return subDataSet

def chooseBestFeat(dataSet, labels):
"""
获取最佳特征与特征对应的最佳划分值

:param dataSet: 数据集
:param labels: 特征集
:return:
"""

global bestSplit
""" 计算划分前系统的信息熵 """
baseEntropy = calcShannonEntropy(dataSet)
bestFeat = 0
baseGainRatio = -1
numFeats = len(dataSet[0]) - 1
bestSplitDic = {}

""" 遍历每个特征 """
for i in range(numFeats):
    """ 获取该特征所有值 """
    featVals = [example[i] for example in dataSet]
    uniVals = sorted(set(featVals))
    if type(featVals[0]).__name__ == 'float' or type(featVals[0]).__name__ == 'int':

        """ 用于区分的坐标值 """
        splitList = []
        for j in range(len(uniVals) - 1):
            splitList.append((uniVals[j] + uniVals[j + 1]) / 2.0)

        """ 计算信息增益比,找到最佳划分属性与划分阈值 """
        for j in range(len(splitList)):

            """ 该划分情况下熵值 """
            newEntropy = 0.0
            splitInfo = 0.0
            value = splitList[j]

            """ 划分出左右两侧数据集 """
            subDataSet0 = splitContinuousDataSet(dataSet, i, value, 0)
            subDataSet1 = splitContinuousDataSet(dataSet, i, value, 1)

            """ 计算划分后系统信息熵 """
            prob0 = float(len(subDataSet0)) / len(dataSet)
            newEntropy -= prob0 * calcShannonEntropy(subDataSet0)
            prob1 = float(len(subDataSet1)) / len(dataSet)
            newEntropy -= prob1 * calcShannonEntropy(subDataSet1)

            """ 获取惩罚参数 """
            splitInfo -= prob0 * log(prob0, 2)
            splitInfo -= prob1 * log(prob1, 2)

            """ 计算信息增益比 """
            gainRatio = float(baseEntropy - newEntropy) / splitInfo

            if gainRatio > baseGainRatio:
                baseGainRatio = gainRatio
                bestSplit = j
                bestFeat = i

        bestSplitDic[labels[i]] = splitList[bestSplit]
    else:
        splitInfo = 0.0
        newEntropy = 0.0
        for value in uniVals:
            """ 划分数据集 """
            subDataSet = splitDataSet(dataSet, i, value)

            """ 计算划分后系统信息熵 """
            prob = float(len(subDataSet)) / len(dataSet)
            newEntropy -= prob * calcShannonEntropy(subDataSet)

            """ 获取惩罚参数 """
            splitInfo -= prob * log(prob, 2)

        """ 计算信息增益比 """
        gainRatio = float(baseEntropy - newEntropy) / splitInfo
        if gainRatio > baseGainRatio:
            bestFeat = i
            baseGainRatio = gainRatio

bestFeatValue = None
if type(dataSet[0][bestFeat]).__name__ == 'float' or type(dataSet[0][bestFeat]).__name__ == 'int':
    bestFeatValue = bestSplitDic[labels[bestFeat]]
if type(dataSet[0][bestFeat]).__name__ == 'str':
    bestFeatValue = labels[bestFeat]
return bestFeat, bestFeatValue

def createTree(dataSet, labels):
"""
递归创建决策树

:param dataSet: 数据集
:param labels: 特征指标集
:return: 决策树字典结构
"""
classList = [example[-1] for example in dataSet]

if len(set(classList)) == 1:
    return classList[0]

if len(dataSet[0]) == 1:
    return majorityClass(dataSet)

""" 找到当前的最佳划分属性与划分阈值 """
bestFeat, bestFeatLabel = chooseBestFeat(dataSet, labels)

myTree = {labels[bestFeat]: {}}
subLabels = labels[:bestFeat]
subLabels.extend(labels[bestFeat + 1:])

if type(dataSet[0][bestFeat]).__name__ == 'str':
    featVals = [example[bestFeat] for example in dataSet]
    uniqueVals = set(featVals)

    """ 递归创建左右子树 """
    for value in uniqueVals:
        """ 获取去除该特征数据集 """
        reduceDataSet = splitDataSet(dataSet, bestFeat, value)
        myTree[labels[bestFeat]][value] = createTree(reduceDataSet, subLabels)

if type(dataSet[0][bestFeat]).__name__ == 'int' or type(dataSet[0][bestFeat]).__name__ == 'float':
    value = bestFeatLabel

    """ 划分数据集 """
    greaterDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 0)
    smallerDataSet = splitContinuousDataSet(dataSet, bestFeat, value, 1)

    """ 递归创建左右子树 """
    myTree[labels[bestFeat]]['>' + str(value)] = createTree(greaterDataSet, subLabels)
    myTree[labels[bestFeat]]['<=' + str(value)] = createTree(smallerDataSet, subLabels)
return myTree

if name == 'main':
dataSet, labels = createDataSet()
print(createTree(dataSet, labels))

3.输出结果

五.展示决策树
1.先安装Graphviz绘图软件,我安装在了d盘,安装好后配置环境变量,在cmd中查看

2.根据上面的创建的决策树的数据集写展示决策树代码

-- coding: UTF-8 --

import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties

def getNumLeafs(myTree):
"""
获取决策树叶子结点的数目

:param myTree: 决策树
:return: 决策树的叶子结点的数目
"""
numLeafs = 0  # 初始化叶子
firstStr = list(myTree.keys())[0]
secondDict = myTree[firstStr]  # 获取下一组字典
for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
        numLeafs += getNumLeafs(secondDict[key])
    else:
        numLeafs += 1
return numLeafs

def getTreeDepth(myTree):
"""
获取决策树的层数

:param myTree: 决策树
:return: 决策树的层数
"""
maxDepth = 0  # 初始化决策树深度
firstStr = next(iter(
    myTree))  # python3中myTree.keys()返回的是dict_keys,不在是list,所以不能使用myTree.keys()[0]的方法获取结点属性,可以使用list(myTree.keys())[0]
secondDict = myTree[firstStr]  # 获取下一个字典
for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
        thisDepth = 1 + getTreeDepth(secondDict[key])
    else:
        thisDepth = 1
    if thisDepth > maxDepth: maxDepth = thisDepth  # 更新层数
return maxDepth

def plotNode(nodeTxt, centerPt, parentPt, nodeType):

arrow_args = dict(arrowstyle="<-")  # 定义箭头格式
font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc", size=14)  # 设置中文字体
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',  # 绘制结点
                        xytext=centerPt, textcoords='axes fraction',
                        va="center", ha="center", bbox=nodeType, arrowprops=arrow_args, FontProperties=font)

def plotMidText(cntrPt, parentPt, txtString):

xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]  # 计算标注位置
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)

def plotTree(myTree, parentPt, nodeTxt):

decisionNode = dict(boxstyle="sawtooth", fc="0.8")  # 设置结点格式
leafNode = dict(boxstyle="round4", fc="0.8")  # 设置叶结点格式
numLeafs = getNumLeafs(myTree)  # 获取决策树叶结点数目,决定了树的宽度
firstStr = next(iter(myTree))  # 下个字典
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)  # 中心位置
plotMidText(cntrPt, parentPt, nodeTxt)  # 标注有向边属性值
plotNode(firstStr, cntrPt, parentPt, decisionNode)  # 绘制结点
secondDict = myTree[firstStr]  # 下一个字典,也就是继续绘制子结点
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD  # y偏移
for key in secondDict.keys():
    if type(secondDict[key]).__name__ == 'dict':  # 测试该结点是否为字典,如果不是字典,代表此结点为叶子结点
        plotTree(secondDict[key], cntrPt, str(key))  # 不是叶结点,递归调用继续绘制
    else:  # 如果是叶结点,绘制叶结点,并标注有向边属性值
        plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
        plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
        plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD

def createPlot(inTree):

fig = plt.figure(1, facecolor='white')  # 创建 fig
fig.clf()  # 清空 fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)  # 去掉 x、y 轴
plotTree.totalW = float(getNumLeafs(inTree))  # 获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree))  # 获取决策树层数
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0  # x偏移
plotTree(inTree, (0.5, 1.0), '')  # 绘制决策树
plt.show()  # 显示绘制结果

if name == 'main':
myTree = {'日期':
{'>728.5': 'yes',
'<=728.5':
{'寒冷':
{'false':
{'气温':
{'high':
{'气候':
{'hot':
{'天气':
{'sunny': 'no',
'overcast': 'yes'
}
},

                                             }
                                        },
                            
                                  }
                              },
                         'true':
                             {'气温':
                                  {'high':
                                       {'气候':
                                            {'hot': 'no',
                                             'mild':
                                                 {'天气':
                                                      {'sunny': 'no',
                                                       'overcast': 'yes'}
                                                  }
                                             }                                            
                                        },
                                   'normal':
                                       {'气候':
                                            {'mild': 'yes',
                                             'cool':
                                                 {'天气':
                                                      {'sunny': 'no',
                                                       'overcast': 'yes'
                                                       }
                                                  }
                                             }
                                        }
                                   }
                              }
                         }
                    }
               }
          }
print(myTree)
createPlot(myTree)
# }}}

设置阈值为创建决策树时算出的728.5,然后决策树进行条件判断,从而得到分类和展示
3.运行结果

标签:结点,myTree,plotTree,dataSet,bestFeat,决策树
From: https://www.cnblogs.com/lh123456789/p/16899179.html

相关文章