决策树
基础概念
计算
实战代码
import matplotlib.pyplot as plt
from math import log
import operator
from matplotlib import font_manager
font = font_manager.FontProperties(fname=r"c:\windows\fonts\SimHei.ttf")
def createDataSet():
# dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
# [1, 0, 1, 0, 0, 0, 'yes'],
# [1, 0, 0, 0, 0, 0, 'yes'],
# [0, 0, 1, 0, 0, 0, 'yes'],
# [2, 0, 0, 0, 0, 0, 'yes'],
# [0, 1, 0, 0, 1, 1, 'yes'],
# [1, 1, 0, 1, 1, 1, 'yes'],
# [1, 1, 0, 0, 1, 0, 'yes'],
# [1, 1, 1, 1, 1, 0, 'no'],
# [0, 2, 2, 0, 2, 1, 'no'],
# [2, 2, 2, 2, 2, 0, 'no'],
# [2, 0, 0, 2, 2, 1, 'no'],
# [0, 1, 0, 1, 0, 0, 'no'],
# [2, 1, 1, 1, 0, 0, 'no'],
# [1, 1, 0, 0, 1, 1, 'no'],
# [2, 0, 0, 2, 2, 0, 'no'],
# [0, 0, 1, 1, 1, 0, 'no']]
dataSet = [['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', 'yes'],
['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 'yes'],
['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'yes'],
['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', 'yes'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', 'yes'],
['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', 'no'],
['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', 'no'],
['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', 'no'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', 'no'],
['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', 'no'],
['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', 'no'],
['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', 'no'],
['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', 'no'],
['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', 'no']]
labels = ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
return dataSet, labels
def createTree(dataset,labels,featLabels):
"""创建根结点并实现其基本流程"""
classList = [example[-1] for example in dataset]
# 设置递归停止条件
# 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
if (classList.count(classList[0])) == len(classList):
return classList[0]
# 如果只剩下一列特征不能继续再划分时,停止递归
if len(dataset[0]) == 1:
return majorityCnt(classList)
# 计算信息熵,取出最优属性,得出其索引值
bestFeat = chooseBestFeatureToSplit(dataset)
bestFeatLabel = labels[bestFeat] # 最优根结点
featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
myTree = {bestFeatLabel:{}}
# 删除该属性
del labels[bestFeat]
featValue = [example[bestFeat] for example in dataset]
# 取出唯一值
uniqueVals = set(featValue)
# 利用其特征进行分叉
for value in uniqueVals:
sub_lables = labels[:]
# 递归分叉
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
return myTree
def createTree(dataset,labels,featLabels):
"""创建根结点并实现其基本流程"""
classList = [example[-1] for example in dataset]
# 设置递归停止条件
# 如果数据集很纯净,就返回当前类别yes和no,即一种特征就可以判断出来是好瓜好事坏瓜
if (classList.count(classList[0])) == len(classList):
return classList[0]
# 如果只剩下一列特征不能继续再划分时,停止递归
if len(dataset[0]) == 1:
return majorityCnt(classList)
# 计算信息熵,取出最优属性,得出其索引值
bestFeat = chooseBestFeatureToSplit(dataset)
bestFeatLabel = labels[bestFeat] # 最优根结点
featLabels.append(bestFeatLabel) # 将其添加到featLabels参数中,储存每个分支的特征标签,即对于每个节点featLabels中的一个元素对应于该节点所选择的特征
myTree = {bestFeatLabel:{}}
# 删除该属性
del labels[bestFeat]
featValue = [example[bestFeat] for example in dataset]
# 取出唯一值
uniqueVals = set(featValue)
# 利用其特征进行分叉
for value in uniqueVals:
sub_lables = labels[:]
# 递归分叉
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataset,bestFeat,value),sub_lables,featLabels)
return myTree
def splitDataSet(dataset, axis, val):
retDataSet = []
for featVec in dataset:
if featVec[axis] == val:
reduceFeatVec = featVec[:axis]
reduceFeatVec.extend(featVec[axis + 1:])
retDataSet.append(reduceFeatVec)
return retDataSet
# 用于切分数据集。函数的输入参数包括:数据集(dataset)、切分数据集的特征(axis)和需要返回的特征值(val)。
# 函数的输出是一个列表,其中包含了数据集中所有特征值为val的数据行,并且这些行已经去掉了特征值为axis的那一列。
# 函数通过遍历数据集中的每一行,判断该行的axis列是否等于val。
# 如果是,就把该行的axis列去掉,并将其它列组成一个新的列表redceFeatVec,然后将这个列表添加到retDataSet中。
# 最后,函数返回retDataSet。
def calcShannonEnt(dataset):
"""计算熵值"""
numexamples = len(dataset) # 总体数据
labelCounts = {}
# 取出yes和no便于后面计算概率
for featVec in dataset:
currentlabel = featVec[-1]
if currentlabel not in labelCounts.keys():labelCounts[currentlabel] = 0
labelCounts[currentlabel] += 1
shannonEnt = 0
# 计算熵值
for key in labelCounts:
prop = float(labelCounts[key])/numexamples
shannonEnt -= prop * log(prop,2)
return shannonEnt
def majorityCnt(classList):
"""计算多数类别是哪一个"""
classCount = {}
# 这个for循环是将classList中的yes和no统计出来保存在字典中
for vote in classList:
# 如果不在这个字典里面就将他的key设置为零
if vote not in classCount.keys():classCount[vote] = 0
# 在的话就+=1
classCount[vote] +=1
sortedclassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
# print(sortedclassCount)
return sortedclassCount[0][0]
def getNumLeafs(myTree):
numLeafs = 0
firstStr = next(iter(myTree))
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
numLeafs += getNumLeafs(secondDict[key])
else:
numLeafs += 1
return numLeafs
def getTreeDepth(myTree):
maxDepth = 0
firstStr = next(iter(myTree))
secondDict = myTree[firstStr]
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
thisDepth = 1 + getTreeDepth(secondDict[key])
else:
thisDepth = 1
if thisDepth > maxDepth: maxDepth = thisDepth
return maxDepth
# 这段代码是一个递归函数,计算输入的决策树中叶节点的数量。
# 1、函数获取决策树的根节点,并进入其子树。
# 2、遍历该子树的每个分支,如果当前节点是一个字典类型,则递归调用该函数,继续遍历其子节点。
# 如果当前节点是叶节点,即不再包含子节点,则将叶节点的数量加1。
# 3、函数返回整个决策树中叶节点的总数。
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
arrow_args = dict(arrowstyle="<-")
# font = FontProperties(fname=r"C:\Windows\Fonts\Corbel.ttf", size=14)
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args,fontproperties=font)
# 是用来计算一个嵌套字典的深度(也可以理解为树的深度)。
# 其中,输入参数`myTree`是一个嵌套字典,表示一个树形结构,每个节点都是一个字典。函数返回值是这个树的深度。
# 1. 初始化变量`maxDepth`为0,表示当前树的深度为0。
# 2. 从字典`myTree`中获取第一个键值对,即根节点。将根节点的值(一个字典)赋值给变量`secondDict`。
# 3. 遍历`secondDict`中的每个键,判断对应的值是否为字典。如果是字典,则递归调用`getTreeDepth`函数,计算以该节点为根的子树的深度。
# 否则,该节点为叶子节点,深度为1。
# 4. 将当前节点的深度`thisDepth`与当前最大深度`maxDepth`比较,更新`maxDepth`为较大值。
# 5. 返回最大深度`maxDepth`。
# def plotMidText(cntrPt, parentPt, txtString):
# xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
# yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
#
# createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotMidText(centerPt, parentPt, txtString):
xMid = (parentPt[0] - centerPt[0]) / 2.0 + centerPt[0]
yMid = (parentPt[1] - centerPt[1]) / 2.0 + centerPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30,fontproperties=font)
# 该函数用于在父节点和子节点之间画出标签,显示它们之间的关系。具体解释如下:
# - centerPt:当前节点的坐标
# - parentPt:父节点的坐标
# - txtString:标签内容
# - xMid:标签的x坐标,计算方式为父节点和子节点的x坐标的平均值
# - yMid:标签的y坐标,计算方式为父节点和子节点的y坐标的平均值
# - va:标签的垂直对齐方式,"center"表示居中对齐
# - ha:标签的水平对齐方式,"center"表示居中对齐
# - rotation:标签的旋转角度,30表示旋转30度
# - fontproperties:字体属性,用于设置标签的字体大小、颜色等。
def plotTree(myTree, parentPt, nodeTxt):
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
numLeafs = getNumLeafs(myTree)
depth = getTreeDepth(myTree)
firstStr = next(iter(myTree))
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decisionNode)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
# 代码用来绘制决策树的。
# 首先,定义了两个字典decisionNode和leafNode,分别表示决策节点和叶子节点的样式。
# 然后,获取决策树的叶子节点数量和深度,以及根节点的属性名firstStr。
# 接着,计算当前节点的中心位置cntrPt,并调用plotMidText和plotNode函数绘制节点的文本和样式。
# 然后,获取该节点的子节点secondDict,遍历其所有键值对,如果值是一个字典,则说明该节点不是叶子节点,需要递归调用plotTree函数来绘制其子节点。
# 如果值不是一个字典,则说明该节点是叶子节点,需要调用plotNode和plotMidText函数绘制该节点的样式和文本。
# 最后,更新plotTree的yOff值,以便绘制下一个节点。
def createPlot(inTree):
fig = plt.figure(1, facecolor='white') # 创建fig
fig.clf() # 清空fig
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops) # 去掉x、y轴
plotTree.totalW = float(getNumLeafs(inTree)) # 获取决策树叶结点数目
plotTree.totalD = float(getTreeDepth(inTree)) # 获取决策树层数
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0 # x偏移
plotTree(inTree, (0.5, 1.0), '') # 绘制决策树
plt.show()
# 创建并显示决策树的可视化图形的。
# 首先,创建一个白色背景的fig对象,并清空该对象。
# 然后,定义一个字典axprops,用于去掉x、y轴的刻度。
# 接着,创建一个子图ax1,并将其frameon属性设置为False,以便去掉边框。
# 接下来,获取决策树的叶子节点数量和深度,并初始化plotTree的xOff和yOff值。
# 最后,调用plotTree函数绘制决策树,并显示图形。
if __name__ == '__main__':
dataset, labels = createDataSet()
featLabels = []
myTree = createTree(dataset, labels, featLabels)
createPlot(myTree)
# 这段代码是用来测试决策树算法的。
# 首先,调用createDataSet函数生成一个简单的数据集和标签。
# 然后,定义一个空列表featLabels,用于存储决策树的属性标签。
# 接着,调用createTree函数生成决策树,并将属性标签存储在featLabels中。
# 最后,调用createPlot函数绘制决策树的可视化图形。
标签:myTree,no,plotTree,dataset,yes,节点,决策树
From: https://www.cnblogs.com/fyuan0206/p/17388123.html