首页 > 其他分享 >sklearn中的决策树(1)—— 分类树

sklearn中的决策树(1)—— 分类树

时间:2022-08-27 17:12:03浏览次数:59  
标签:0.0 clf 分类 value entropy samples sklearn class 决策树

 

sklearn中的决策树(1)—— 分类树

   

DecisionTreeClassifier

   

重要参数

   
  • Criterion: 不纯度,gini & entropy

    entropy对不纯度更加敏感,即对不纯度的惩罚更强,由于这种特性,决策树的生长会更加“精细”,对高维数据很容易过拟合

   

实例: 红酒数据集

  In [54]:
from sklearn import tree
from sklearn.datasets import load_wine 
from sklearn.model_selection import train_test_split 
  In [55]:
#字典
wine = load_wine()

# 用pandas整理数据
import pandas as pd
pd.concat([pd.DataFrame(wine.data),pd.DataFrame(wine.target)],axis=1)
  Out[55]:  
 01234567891011120
0 14.23 1.71 2.43 15.6 127.0 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065.0 0
1 13.20 1.78 2.14 11.2 100.0 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050.0 0
2 13.16 2.36 2.67 18.6 101.0 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185.0 0
3 14.37 1.95 2.50 16.8 113.0 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480.0 0
4 13.24 2.59 2.87 21.0 118.0 2.80 2.69 0.39 1.82 4.32 1.04 2.93 735.0 0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
173 13.71 5.65 2.45 20.5 95.0 1.68 0.61 0.52 1.06 7.70 0.64 1.74 740.0 2
174 13.40 3.91 2.48 23.0 102.0 1.80 0.75 0.43 1.41 7.30 0.70 1.56 750.0 2
175 13.27 4.28 2.26 20.0 120.0 1.59 0.69 0.43 1.35 10.20 0.59 1.56 835.0 2
176 13.17 2.59 2.37 20.0 120.0 1.65 0.68 0.53 1.46 9.30 0.60 1.62 840.0 2
177 14.13 4.10 2.74 24.5 96.0 2.05 0.76 0.56 1.35 9.20 0.61 1.60 560.0 2

178 rows × 14 columns

  In [56]:
# 查看标签
print(wine.feature_names)
print(wine.target_names)
   
['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']
['class_0' 'class_1' 'class_2']
  In [75]:
# 分出训练集和测试集
Xtrain,Xtest,Ytrain,Ytest = train_test_split(wine.data,wine.target,test_size=0.3)
print(Xtrain.shape)
   
(124, 13)
  In [76]:
# 实例化
clf = tree.DecisionTreeClassifier(criterion='entropy')
# 训练模型
clf = clf.fit(Xtrain,Ytrain)
# 衡量精确度
score = clf.score(Xtest,Ytest)
print(score)
   
0.9444444444444444
  In [77]:
# 画树
import graphviz
dot_data = tree.export_graphviz(clf
                                ,feature_names = wine.feature_names
                                ,class_names = ['Gin','Sherry','Vermouth']
                                ,filled=True
                               )
graph = graphviz.Source(dot_data)
graph
# 颜色越浅不纯度越高
  Out[77]: Tree 0 flavanoids <= 1.575 entropy = 1.569 samples = 124 value = [34, 49, 41] class = Sherry 1 color_intensity <= 3.725 entropy = 0.68 samples = 50 value = [0, 9, 41] class = Vermouth 0->1 True 6 alcohol <= 13.04 entropy = 0.995 samples = 74 value = [34, 40, 0] class = Sherry 0->6 False 2 entropy = 0.0 samples = 8 value = [0, 8, 0] class = Sherry 1->2 3 hue <= 0.97 entropy = 0.162 samples = 42 value = [0, 1, 41] class = Vermouth 1->3 4 entropy = 0.0 samples = 41 value = [0, 0, 41] class = Vermouth 3->4 5 entropy = 0.0 samples = 1 value = [0, 1, 0] class = Sherry 3->5 7 entropy = 0.0 samples = 37 value = [0, 37, 0] class = Sherry 6->7 8 magnesium <= 88.0 entropy = 0.406 samples = 37 value = [34, 3, 0] class = Gin 6->8 9 entropy = 0.0 samples = 3 value = [0, 3, 0] class = Sherry 8->9 10 entropy = 0.0 samples = 34 value = [34, 0, 0] class = Gin 8->10   In [79]:
# 不一定会用到全部特征,可以看到特征重要性
clf.feature_importances_
importance = [*zip(wine.feature_names,clf.feature_importances_)]
sorted(importance,key=lambda x:x[1],reverse=True)
  Out[79]:
[('flavanoids', 0.4467047938462895),
 ('alcohol', 0.30132438648683235),
 ('color_intensity', 0.13972701286171754),
 ('magnesium', 0.0772032916950798),
 ('hue', 0.03504051511008071),
 ('malic_acid', 0.0),
 ('ash', 0.0),
 ('alcalinity_of_ash', 0.0),
 ('total_phenols', 0.0),
 ('nonflavanoid_phenols', 0.0),
 ('proanthocyanins', 0.0),
 ('od280/od315_of_diluted_wines', 0.0),
 ('proline', 0.0)]
   

每次运行结果不同,需要通过集成的方法找出最优节点

  In [132]:
# random_state & splitter
clf_1 = tree.DecisionTreeClassifier(criterion="entropy"
                                  ,random_state=42
                                  ,splitter='random')
clf_1 = clf_1.fit(Xtrain,Ytrain)
score = clf_1.score(Xtest,Ytest)
print(score)
   
0.9444444444444444
  In [143]:
dot1_data = tree.export_graphviz(clf_1
                                 ,feature_names=wine.feature_names
                                 ,class_names = ['Gin','Sherry','Vermouth']
                                 ,filled=True
                                )
graph1 = graphviz.Source(dot1_data)
graph1.render('wine_tree')
graph1
  Out[143]: Tree 0 od280/od315_of_diluted_wines <= 2.123 entropy = 1.569 samples = 124 value = [34, 49, 41] class = Sherry 1 hue <= 0.96 entropy = 0.575 samples = 44 value = [0, 6, 38] class = Vermouth 0->1 True 8 color_intensity <= 6.0 entropy = 1.184 samples = 80 value = [34, 43, 3] class = Sherry 0->8 False 2 od280/od315_of_diluted_wines <= 1.785 entropy = 0.378 samples = 41 value = [0, 3, 38] class = Vermouth 1->2 7 entropy = 0.0 samples = 3 value = [0, 3, 0] class = Sherry 1->7 3 entropy = 0.0 samples = 28 value = [0, 0, 28] class = Vermouth 2->3 4 flavanoids <= 0.973 entropy = 0.779 samples = 13 value = [0, 3, 10] class = Vermouth 2->4 5 entropy = 0.0 samples = 10 value = [0, 0, 10] class = Vermouth 4->5 6 entropy = 0.0 samples = 3 value = [0, 3, 0] class = Sherry 4->6 9 proline <= 461.506 entropy = 1.15 samples = 69 value = [23, 43, 3] class = Sherry 8->9 28 entropy = 0.0 samples = 11 value = [11, 0, 0] class = Gin 8->28 10 entropy = 0.0 samples = 20 value = [0, 20, 0] class = Sherry 9->10 11 alcohol <= 13.214 entropy = 1.271 samples = 49 value = [23, 23, 3] class = Gin 9->11 12 flavanoids <= 0.626 entropy = 0.871 samples = 27 value = [3, 22, 2] class = Sherry 11->12 21 total_phenols <= 1.753 entropy = 0.53 samples = 22 value = [20, 1, 1] class = Gin 11->21 13 entropy = 0.0 samples = 2 value = [0, 0, 2] class = Vermouth 12->13 14 proline <= 1025.703 entropy = 0.529 samples = 25 value = [3, 22, 0] class = Sherry 12->14 15 proline <= 710.714 entropy = 0.258 samples = 23 value = [1, 22, 0] class = Sherry 14->15 20 entropy = 0.0 samples = 2 value = [2, 0, 0] class = Gin 14->20 16 entropy = 0.0 samples = 19 value = [0, 19, 0] class = Sherry 15->16 17 color_intensity <= 3.53 entropy = 0.811 samples = 4 value = [1, 3, 0] class = Sherry 15->17 18 entropy = 0.0 samples = 3 value = [0, 3, 0] class = Sherry 17->18 19 entropy = 0.0 samples = 1 value = [1, 0, 0] class = Gin 17->19 22 entropy = 0.0 samples = 1 value = [0, 0, 1] class = Vermouth 21->22 23 color_intensity <= 4.101 entropy = 0.276 samples = 21 value = [20, 1, 0] class = Gin 21->23 24 proline <= 734.781 entropy = 0.918 samples = 3 value = [2, 1, 0] class = Gin 23->24 27 entropy = 0.0 samples = 18 value = [18, 0, 0] class = Gin 23->27 25 entropy = 0.0 samples = 1 value = [0, 1, 0] class = Sherry 24->25 26 entropy = 0.0 samples = 2 value = [2, 0, 0] class = Gin 24->26    

剪枝参数

  In [145]:
# 我们对训练集的拟合程度如何?
score_train = clf.score(Xtrain,Ytrain)
score_train 
  Out[145]:
1.0
   
  • 参数1:max_dapth

    超过设定深度的树枝都会被剪掉,从3开始试验

  • 参数2:min_sample_leaf & min_sample_split

    min_sample_leaf 规定节点在分之后的子节点至少有min_sample_leaf个训练样本,一般从5 开始试,设置得太小会过拟合,太大不利于学习

    min_sample_split 限定每个节点至少要有这么min_sample_split个训练样本,否则该节点不被允许分支

  • 参数3:max_features & min_impurity_decrease

    超过max_features个数的特征会被舍弃

    小于min_impurity_decrease的分支不会发生(建议)

   

确定最优的剪枝参数——超参数曲线

  In [153]:
import matplotlib.pyplot as plt
test = []
for i in range(10):
    clf = tree.DecisionTreeClassifier(max_depth=i+1
                                      ,criterion="gini"
                                      ,random_state=42
                                    )
    clf = clf.fit(Xtrain,Ytrain)
    score = clf.score(Xtest,Ytest)
    test.append(score)

plt.plot(range(1,11),test,'-r.',label='max_depth')
plt.legend()
plt.show()
# 所以max_depth 选择 3 就好
   

标签:0.0,clf,分类,value,entropy,samples,sklearn,class,决策树
From: https://www.cnblogs.com/-simon-/p/16630927.html

相关文章

  • 综合案例-黑马旅游网_分类数据展示缓存优化
    综合案例-黑马旅游网_分类数据展示缓存优化分析发现,分类的数据在每一次页面加载后都会重写请求数据库来加载对数据库压力比较大而且分类的数据不会经常产生变化所有可......
  • 综合案例-黑马旅游网_分类数据展示功能
    综合案例-黑马旅游网_分类数据展示功能分析效果   后端代码实现CategoryDao接口packagecom.bai.dao;importcom.bai.domain.Category;importjava.util.......
  • bug的几种类型分类
    类型名称描述备注类型名称描述备注业务逻辑主要的业务流程走不通,出现错误,比如新增保存不成功........ 功能操作一些功能按钮无法进行操作,没反......
  • 决策树与集成
    DecisionTree目录DecisionTreeClassificationTreeRegressionTreeRegularizationProsandconsAssembleMethodBaggingBoostingTakeawayGreedy,Top-down,Recurrent......
  • C#中锁的使用分类
    1互斥锁lock(基于Monitor实现)定义:privatestaticreadonlyobjectLock=newobject();使用:lock(Lock){//todo}作用:将会锁住代码块的内容,并阻止其他线程进入该代......
  • 天煞NLP之我要毕业——实战第一课:新闻文本分类
    赛题之题目:(说实话,题目我都看不太懂,艹) 评测标准:第二个就是评测标准,沃土现在,我只知道f1的值(也就是f1_score越大越好)--_--|  看了几分钟,大概看懂了什么意思:就是......
  • 75. 颜色分类
    75.颜色分类给定一个包含红色、白色和蓝色、共 n个元素的数组 nums ,原地对它们进行排序,使得相同颜色的元素相邻,并按照红色、白色、蓝色顺序排列。我们使用整数0......
  • 4. 基础实战——FashionMNIST时装分类
    importosimportnumpyasnpimportpandasaspdimporttorchimporttorch.nnasnnimporttorch.optimasoptimfromtorch.utils.dataimportDataset,DataLoad......
  • 随笔分类 - Microsoft Dynamices CRM(2013, 2011)
    随笔分类-MicrosoftDynamicesCRM(2013,2011)MicrosoftDynamicsCRM数据库连接存储位置在哪里是在注册表里摘要:MicrosoftDynamicsCRM数据库连接存储......
  • 10--DSL查询文档-查询分类和基本语法
    elasticsearch的查询依然是基于JSON风格的DSL来实现的。 DSL查询分类Elasticsearch提供了基于JSON的DSL(DomainSpecificLanguage)来定义查询。常见的查询类型包括:(1)......