首页 > 编程语言 >CART——Classification And Regression Tree在python下的实现

CART——Classification And Regression Tree在python下的实现

时间:2023-06-08 13:03:56浏览次数:58  
标签:返回 特征值 Classification 方差 python Tree 最佳 划分


分类与回归树(CART——Classification And Regression Tree)) 是一种非参数分类和回归方法,它通过构建二叉树达到预测目的。

示例:

1.样本数据集 

2.运行结果-cart决策树的字典

max_n_feats = 3时
tree_dict = {
          house :{
                  yes :  agree
                  no :{
                          working : {'yes': 'agree', 'no': 'refuse'}
                          }
                  }
          }


3.运行结果-决策树的绘制图形

max_n_feats = 3时

CART——Classification And Regression Tree在python下的实现_python

4.核心代码讲解

核心代码是:类class CCartTree(object)中的work()接口和create_tree()接口

    work()是cart算法,生成最优特征,最优切分点,最优叶节点等等

    create_tree()是递归生成cart决策树字典

树的限制递归生成阈值:

    max_n_feats,当剩下的样本集的特征数少于max_n_feats,将不再进行继续生成。

    也可以提供gini阈值~

具体:

一、CART ( Classification And Regression Tree) 分类回归树

1、基尼指数:

在分类问题中,假设有KK 个类,样本点属于第kk 类的概率为PkPk ,则概率分布的基尼指数定义为: 
Gini(P)=∑k=1KPk(1−Pk)=1−∑k=1KP2kGini(P)=∑k=1KPk(1−Pk)=1−∑k=1KPk2

在CART 分类问题中,基尼指数作为特征选择的依据:选择基尼指数最小的特征及切分点做为最优特征和最优切分点。

2、在回归问题中,特征选择及最佳划分特征值的依据是:划分后样本的均方差之和最小!

二、算法分析:

CART 主要包括特征选择、回归树的生成、剪枝三部分

数据特征停止划分的条件: 
1、当前数据集中的标签相同,返回当前的标签 
2、划分前后的总方差差距很小,数据不划分,返回的属性为空,返回的最佳划分值为当前所有标签的均值。 
3、划分后的左右两个数据集的样本数量较小,返回的属性为空,返回的最佳划分值为当前所有标签的均值。

若满足上述三个特征停止划分的条件,则返回的最佳特征为空,返回的最佳划分特征值会作为叶子结点。

注:CART是一棵二叉树。 在生成CART回归树过程中,一个特征可能会被使用不止一次,所以,不存在当前属性集为空的情况;

1、特征选择(依据:总方差最小)

输入:数据集、op = [m,n] 
输出:最佳特征、最佳划分特征值

m表示剪枝前总方差与剪枝后总方差差值的最小值; n: 数据集划分为左右两个子数据集后,子数据集中的样本的最少数量;

1、判断数据集中所有的样本标签是否相同,是:返回当前标签; 
2、遍历所有的样本特征,遍历每一个特征的特征值。计算出每一个特征值下的数据总方差,找出使总方差最小的特征、特征值 
3、比较划分前和划分后的总方差大小;若划分后总方差减少较小,则返回的最佳特征为空,返回的最佳划分特征值会为当前数据集标签的平均值。 
4、比较划分后的左右分支数据集样本中的数量,若某一分支数据集中样本少于指定数量op[1],则返回的最佳特征为空, 
返回的最佳划分特征值会为当前数据集标签的平均值。 
5、否则,返回使总方差最小的特征、特征值

二、回归树的生成函数 createTree 
输入:数据集 
输出:生成回归树 
1、得到当前数据集的最佳划分特征、最佳划分特征值 
2、若返回的最佳特征为空,则返回最佳划分特征值(作为叶子节点) 
3、声明一个字典,用于保存当前的最佳划分特征、最佳划分特征值 
4、执行二元切分;根据最佳划分特征、最佳划分特征值,将当前的数据划分为两部分 
5、在左子树中调用createTree 函数, 在右子树调用createTree 函数。 
6、返回树。

注:在生成的回归树模型中,划分特征、特征值、左节点、右节点均有相应的关键词对应。

三、(后)剪枝:(CART 树一定是二叉树,所以,如果发生剪枝,肯定是将两个叶子节点合并)

输入:树、测试集 
输出:树

1、判断测试集是否为空,是:对树进行塌陷处理 
2、判断树的左右分支是否为树结构,是:根据树当前的特征值、划分值将测试集分为Lset、Rset两个集合; 
3、判断树的左分支是否是树结构:是:在该子集递归调用剪枝过程; 
4、判断树的右分支是否是树结构:是:在该子集递归调用剪枝过程; 
5、判断当前树结构的两个节点是否为叶子节点: 
是: 
a、根据当前树结构,测试集划分为Lset,Rset两部分; 
b、计算没有合并时的总方差NoMergeError,即:测试集在Lset 和 Rset 的总方差之和; 
c、合并后,取叶子节点值为原左右叶子结点的均值。求取测试集在该节点处的总方差MergeError,; 
d、比较合并前后总方差的大小;若NoMergeError > MergeError,返回合并后的节点;否则,返回原来的树结构; 
否: 
返回树结构。

代码实现:数据集

#-*- coding:utf-8 -*-
from numpy import *
import numpy as np

# 三大步骤:

'''
1、特征的选择:标准:总方差最小
2、回归树的生成:停止划分的标准
3、剪枝:
'''

# 导入数据集
def loadData(filaName):
    dataSet = []
    fr = open(filaName)
    for line in fr.readlines():
        curLine = line.strip().split('\t')
        theLine = map(float, curLine)                 # map all elements to float()
        dataSet.append(theLine)
    return dataSet

# 特征选择:输入:       输出:最佳特征、最佳划分值
'''
1、选择标准
遍历所有的特征Fi:遍历每个特征的所有特征值Zi;找到Zi,划分后总的方差最小
停止划分的条件:
1、当前数据集中的标签相同,返回当前的标签
2、划分前后的总方差差距很小,数据不划分,返回的属性为空,返回的最佳划分值为当前所有标签的均值。
3、划分后的左右两个数据集的样本数量较小,返回的属性为空,返回的最佳划分值为当前所有标签的均值。
当划分的数据集满足上述条件之一,返回的最佳划分值作为叶子节点;
当划分后的数据集不满足上述要求时,找到最佳划分的属性,及最佳划分特征值
'''

# 计算总的方差
def GetAllVar(dataSet):
    return var(dataSet[:,-1])*shape(dataSet)[0]

# 根据给定的特征、特征值划分数据集
def dataSplit(dataSet,feature,featNumber):
    dataL =  dataSet[nonzero(dataSet[:,feature] > featNumber)[0],:]
    dataR = dataSet[nonzero(dataSet[:,feature] <= featNumber)[0],:]
    return dataL,dataR

# 特征划分
def choseBestFeature(dataSet,op = [1,4]):          # 三个停止条件可否当作是三个预剪枝操作
    if len(set(dataSet[:,-1].T.tolist()[0]))==1:     # 停止条件 1
        regLeaf = mean(dataSet[:,-1])         
        return None,regLeaf                   # 返回标签的均值作为叶子节点
    Serror = GetAllVar(dataSet)
    BestFeature = -1; BestNumber = 0; lowError = inf
    m,n = shape(dataSet) # m 个样本, n -1 个特征
    for i in range(n-1):    # 遍历每一个特征值
        for j in set(dataSet[:,i].T.tolist()[0]):
            dataL,dataR = dataSplit(dataSet,i,j)
            if shape(dataR)[0]<op[1] or shape(dataL)[0]<op[1]: continue  # 如果所给的划分后的数据集中样本数目甚少,则直接跳出
            tempError = GetAllVar(dataL) + GetAllVar(dataR)
            if tempError < lowError:
                lowError = tempError; BestFeature = i; BestNumber = j
    if Serror - lowError < op[0]:               # 停止条件 2   如果所给的数据划分前后的差别不大,则停止划分
        return None,mean(dataSet[:,-1])         
    dataL, dataR = dataSplit(dataSet, BestFeature, BestNumber)
    if shape(dataR)[0] < op[1] or shape(dataL)[0] < op[1]:        # 停止条件 3
        return None, mean(dataSet[:, -1])
    return BestFeature,BestNumber


# 决策树生成
def createTree(dataSet,op=[1,4]):
    bestFeat,bestNumber = choseBestFeature(dataSet,op)
    if bestFeat==None: return bestNumber
    regTree = {}
    regTree['spInd'] = bestFeat
    regTree['spVal'] = bestNumber
    dataL,dataR = dataSplit(dataSet,bestFeat,bestNumber)
    regTree['left'] = createTree(dataL,op)
    regTree['right'] = createTree(dataR,op)
    return  regTree

# 后剪枝操作
# 用于判断所给的节点是否是叶子节点
def isTree(Tree):
    return (type(Tree).__name__=='dict' )

# 计算两个叶子节点的均值
def getMean(Tree):
    if isTree(Tree['left']): Tree['left'] = getMean(Tree['left'])
    if isTree(Tree['right']):Tree['right'] = getMean(Tree['right'])
    return (Tree['left']+ Tree['right'])/2.0

# 后剪枝
def pruneTree(Tree,testData):
    if shape(testData)[0]==0: return getMean(Tree)
    if isTree(Tree['left'])or isTree(Tree['right']):
        dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])
    if isTree(Tree['left']):
        Tree['left'] = pruneTree(Tree['left'],dataL)
    if isTree(Tree['right']):
        Tree['right'] = pruneTree(Tree['right'],dataR)
    if not isTree(Tree['left']) and not isTree(Tree['right']):
        dataL,dataR = dataSplit(testData,Tree['spInd'],Tree['spVal'])
        errorNoMerge = sum(power(dataL[:,-1] - Tree['left'],2)) + sum(power(dataR[:,-1] - Tree['right'],2))
        leafMean = getMean(Tree)
        errorMerge = sum(power(testData[:,-1]-  leafMean,2))
        if errorNoMerge > errorMerge:
            print"the leaf merge"
            return leafMean
        else:
            return Tree
    else:
        return Tree

# 预测
def forecastSample(Tree,testData):
    if not isTree(Tree): return float(tree)
    # print"选择的特征是:" ,Tree['spInd']
    # print"测试数据的特征值是:" ,testData[Tree['spInd']]
    if testData[0,Tree['spInd']]>Tree['spVal']:
        if isTree(Tree['left']):
            return forecastSample(Tree['left'],testData)
        else:
            return float(Tree['left'])
    else:
        if isTree(Tree['right']):
            return forecastSample(Tree['right'],testData)
        else:
            return float(Tree['right'])

def TreeForecast(Tree,testData):
    m = shape(testData)[0]
    y_hat = mat(zeros((m,1)))
    for i in range(m):
        y_hat[i,0] = forecastSample(Tree,testData[i])
    return y_hat

if __name__=="__main__":
    print "hello world"
    dataMat = loadData("ex2.txt")
    dataMat = mat(dataMat)
    op = [1,6]    # 参数1:剪枝前总方差与剪枝后总方差差值的最小值;参数2:将数据集划分为两个子数据集后,子数据集中的样本的最少数量;        
    theCreateTree =  createTree(dataMat,op)
   # 测试数据
    dataMat2 = loadData("ex2test.txt")
    dataMat2 = mat(dataMat2)
    #thePruneTree =  pruneTree(theCreateTree, dataMat2)
    #print"剪枝后的后树:\n",thePruneTree
    y = dataMat2[:, -1]
    y_hat = TreeForecast(theCreateTree,dataMat2)
    print corrcoef(y_hat,y,rowvar=0)[0,1]              # 用预测值与真实值计算相关系数

 

标签:返回,特征值,Classification,方差,python,Tree,最佳,划分
From: https://blog.51cto.com/guog/6439029

相关文章

  • centos执行python脚本
    CentOS下载pyhon当pip下载失败,应该是版本太低了此时需要升级pip:#pip3执行pip3install--upgradepip#pip执行pipinstall--upgradepip#如果上面升级失败,可以试试python-mpipinstall--upgrade--forcepip解决方法1如果在升级过程中报标题中的错误,则通过g......
  • python selenium 浏览器操作 鼠标操作 键盘操作
    窗口截屏#截图driver.get_screenshot_as_file("C:\\Users\\95744\\Desktop\\test01\\test.png")关闭浏览器webdriver.quit()获取当前urldriver.current_url浏览器前进、后退、刷新#后退driver.back()#前进driver.forward()#刷新driver.refresh()......
  • python Qt实现最简单的程序
    1、创建一个程序,实例一个对象2、让这个对象跑起来3、创建组件4、设置标题5、展示出来点击查看代码fromPySide2.QtWidgetsimportQApplication,QMessageBoxfromPySide2.QtUiToolsimportQUiLoaderif__name__=="__main__":app=QApplication(sys.argv)w......
  • python 日志
    在自动化测试中,可以使用以下几种方式记录日志:1.使用内置的`print()`函数:#在需要记录日志的地方使用print()函数输出日志信息print("这是一条日志信息")2.使用标准库中的`logging`模块:importlogging#配置日志输出格式和级别logging.basicConfig(level=logging.INFO......
  • 【python基础】循环语句-break关键字
    1.break关键字break关键字,其作用是在循环中的代码块遇到此关键字,立刻跳出整个循环,执行循环外的下一条语句。其在while和for循环中的作用示意图如下:1.1break在while循环中的使用1.1.1不加else语句比如我们通过键盘输入单词,输出刚才的单词,编写程序如下所示:我们发现当我们输......
  • Python+Redis学习笔记
    首先,通过pip来安装操作redis的相关包,pipinstallredis然后导入我们要使用的模块,formredis.ClientimportRedis然后,通过docker启动redis,fromredis.clientimportRedisr=Redis(host="0.0.0.0",port=6379,db=0,password="")#r.set("kol_height",187)res=r.......
  • python 解析HTML和XML文档
    一、BeautifulSoupBeautifulSoup是一个Python包,用于解析HTML和XML文档。它可以快速而方便地从网页中提取信息,并以易于使用的方式对其进行处理。它支持各种解析器,包括内置的Python解析器和第三方解析器,例如lxml和html5lib。二、对标签提取代码示列以下是使用BeautifulSoup解析H......
  • 初步了解的python的正则表达式
    Python正则表达式|菜鸟教程(runoob.com)Python正则表达式 regex正则表达式是一个特殊的字符序列,它能帮助你方便的检查一个字符串是否与某种模式匹配。Python自1.5版本起增加了re模块,它提供Perl风格的正则表达式模式。re模块使Python语言拥有全部的正则表达式功能......
  • #yyds干货盘点#用Python实现简单的图像识别
    在这篇文章中,我们将使用Python和TensorFlow来实现一个简单的图像识别系统。我们将使用经典的MNIST数据集,这是一个包含手写数字的数据集,用于训练和测试图像识别系统。一、准备环境首先,我们需要安装所需的库。在这里,我们将使用TensorFlow和Keras。您可以使用以下命令安装这些库:pip......
  • Python程序与设计
    2-27在命令行窗口中启动的Python解释器中实现在Python自带的IDLE中实现print("Helloworld")编码规范每个import语句只导入一个模块,尽量避免一次导入多个模块不要在行尾添加分号“:”,也不要用分号将两条命令放在同一行建议每行不超过80个字符使用必要的空行可以增加代码的可读性运算......