Decision Trees in sklearn (1): Classification Trees
DecisionTreeClassifier
Key Parameters
- criterion: the impurity measure, either 'gini' or 'entropy'. Entropy is more sensitive to impurity, i.e. it penalizes impurity more strongly; because of this the tree grows in finer detail, which makes it easy to overfit high-dimensional data (see the sketch below).
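For reference, the two measures for a node's class distribution are Gini = 1 - sum(p_i^2) and entropy = -sum(p_i * log2(p_i)). The snippet below is a minimal sketch (not part of the original post) comparing them on toy distributions:

import numpy as np

def gini(p):
    # Gini impurity: 1 minus the sum of squared class probabilities
    p = np.asarray(p)
    return 1 - np.sum(p ** 2)

def entropy(p):
    # Shannon entropy in bits; zero-probability classes contribute nothing
    p = np.asarray(p)
    p = p[p > 0]
    return -np.sum(p * np.log2(p))

# a maximally impure two-class node vs. a nearly pure one
print(gini([0.5, 0.5]), entropy([0.5, 0.5]))   # 0.5, 1.0
print(gini([0.9, 0.1]), entropy([0.9, 0.1]))   # 0.18, ~0.469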
Example: the wine dataset
In [54]:
from sklearn import tree
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

In [55]:
# wine is a dict-like Bunch object
wine = load_wine()
# organize the data with pandas
import pandas as pd
pd.concat([pd.DataFrame(wine.data), pd.DataFrame(wine.target)], axis=1)

Out[55]:
| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 0 (target) |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.80 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065.0 | 0 |
| 1 | 13.20 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.40 | 1050.0 | 0 |
| 2 | 13.16 | 2.36 | 2.67 | 18.6 | 101.0 | 2.80 | 3.24 | 0.30 | 2.81 | 5.68 | 1.03 | 3.17 | 1185.0 | 0 |
| 3 | 14.37 | 1.95 | 2.50 | 16.8 | 113.0 | 3.85 | 3.49 | 0.24 | 2.18 | 7.80 | 0.86 | 3.45 | 1480.0 | 0 |
| 4 | 13.24 | 2.59 | 2.87 | 21.0 | 118.0 | 2.80 | 2.69 | 0.39 | 1.82 | 4.32 | 1.04 | 2.93 | 735.0 | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 173 | 13.71 | 5.65 | 2.45 | 20.5 | 95.0 | 1.68 | 0.61 | 0.52 | 1.06 | 7.70 | 0.64 | 1.74 | 740.0 | 2 |
| 174 | 13.40 | 3.91 | 2.48 | 23.0 | 102.0 | 1.80 | 0.75 | 0.43 | 1.41 | 7.30 | 0.70 | 1.56 | 750.0 | 2 |
| 175 | 13.27 | 4.28 | 2.26 | 20.0 | 120.0 | 1.59 | 0.69 | 0.43 | 1.35 | 10.20 | 0.59 | 1.56 | 835.0 | 2 |
| 176 | 13.17 | 2.59 | 2.37 | 20.0 | 120.0 | 1.65 | 0.68 | 0.53 | 1.46 | 9.30 | 0.60 | 1.62 | 840.0 | 2 |
| 177 | 14.13 | 4.10 | 2.74 | 24.5 | 96.0 | 2.05 | 0.76 | 0.56 | 1.35 | 9.20 | 0.61 | 1.60 | 560.0 | 2 |

178 rows × 14 columns
In [56]:
# inspect the feature names and target labels
print(wine.feature_names)
print(wine.target_names)

['alcohol', 'malic_acid', 'ash', 'alcalinity_of_ash', 'magnesium', 'total_phenols', 'flavanoids', 'nonflavanoid_phenols', 'proanthocyanins', 'color_intensity', 'hue', 'od280/od315_of_diluted_wines', 'proline']
['class_0' 'class_1' 'class_2']
In [75]:
# split into training and test sets
Xtrain, Xtest, Ytrain, Ytest = train_test_split(wine.data, wine.target, test_size=0.3)
print(Xtrain.shape)

(124, 13)
In [76]:
# instantiate the classifier
clf = tree.DecisionTreeClassifier(criterion='entropy')
# fit the model on the training set
clf = clf.fit(Xtrain, Ytrain)
# measure accuracy on the test set
score = clf.score(Xtest, Ytest)
print(score)

0.9444444444444444
In [77]:
# draw the tree
import graphviz
dot_data = tree.export_graphviz(clf
                                , feature_names=wine.feature_names
                                , class_names=['Gin', 'Sherry', 'Vermouth']
                                , filled=True
                                )
graph = graphviz.Source(dot_data)
graph  # the lighter a node's color, the higher its impurity

Out[77]:
[Rendered decision tree: the root splits on flavanoids <= 1.575 (entropy = 1.569, 124 samples, value = [34, 49, 41]); the left subtree splits on color_intensity and hue, the right on alcohol and magnesium; all leaves reach entropy = 0.0.]
In [79]:
# not every feature is necessarily used; inspect the feature importances
clf.feature_importances_
importance = [*zip(wine.feature_names, clf.feature_importances_)]
sorted(importance, key=lambda x: x[1], reverse=True)

Out[79]:
[('flavanoids', 0.4467047938462895),
 ('alcohol', 0.30132438648683235),
 ('color_intensity', 0.13972701286171754),
 ('magnesium', 0.0772032916950798),
 ('hue', 0.03504051511008071),
 ('malic_acid', 0.0),
 ('ash', 0.0),
 ('alcalinity_of_ash', 0.0),
 ('total_phenols', 0.0),
 ('nonflavanoid_phenols', 0.0),
 ('proanthocyanins', 0.0),
 ('od280/od315_of_diluted_wines', 0.0),
 ('proline', 0.0)]
The result differs from run to run: a single tree cannot guarantee the optimal split nodes, so sklearn takes an ensemble-style approach, building different trees across runs so the best one can be chosen. Fixing random_state makes a given tree reproducible, as the next cell shows.
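To see this instability directly, here is a minimal sketch (not from the original post, reusing Xtrain/Xtest from above): fitting the same configuration twice without a fixed random_state can yield different trees and scores.

# two identical configurations, but no fixed random_state
clf_a = tree.DecisionTreeClassifier(criterion='entropy').fit(Xtrain, Ytrain)
clf_b = tree.DecisionTreeClassifier(criterion='entropy').fit(Xtrain, Ytrain)
# the chosen splits, and possibly the scores, may differ between the two runs
print(clf_a.score(Xtest, Ytest), clf_b.score(Xtest, Ytest))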
In [132]:
# random_state & splitter
clf_1 = tree.DecisionTreeClassifier(criterion="entropy"
                                    , random_state=42
                                    , splitter='random')
clf_1 = clf_1.fit(Xtrain, Ytrain)
score = clf_1.score(Xtest, Ytest)
print(score)

0.9444444444444444
In [143]:
dot1_data = tree.export_graphviz(clf_1
                                 , feature_names=wine.feature_names
                                 , class_names=['Gin', 'Sherry', 'Vermouth']
                                 , filled=True
                                 )
graph1 = graphviz.Source(dot1_data)
graph1.render('wine_tree')
graph1

Out[143]:
[Rendered decision tree: with splitter='random' the tree is noticeably deeper and bushier — the root splits on od280/od315_of_diluted_wines <= 2.123 and the tree grows to 29 nodes, reusing features such as proline, alcohol, flavanoids and color_intensity at several levels.]
Pruning Parameters
In [145]:
# how well do we fit the training set?
score_train = clf.score(Xtrain, Ytrain)
score_train

Out[145]:
1.0

A perfect score on the training set alongside a lower test score is a sign of overfitting, which the pruning parameters below are designed to rein in.
- Parameter 1: max_depth
  Branches deeper than the set depth are pruned away; start experimenting from 3.
- Parameter 2: min_samples_leaf & min_samples_split
  min_samples_leaf requires every child node produced by a split to contain at least min_samples_leaf training samples; a value around 5 is a common starting point. Setting it too small overfits; setting it too large prevents the tree from learning.
  min_samples_split requires a node to contain at least min_samples_split training samples before it is allowed to split at all.
- Parameter 3: max_features & min_impurity_decrease
  At each split, features beyond max_features are discarded from consideration.
  A split whose impurity decrease falls below min_impurity_decrease will not be made (this is the recommended parameter, replacing the deprecated min_impurity_split).

A sketch combining these parameters follows this list.
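Putting the pruning parameters together, here is a minimal sketch (the values 3, 5 and 10 are illustrative starting points, not tuned for this dataset):

clf_pruned = tree.DecisionTreeClassifier(criterion='entropy'
                                         , random_state=42
                                         , max_depth=3           # illustrative value
                                         , min_samples_leaf=5    # illustrative value
                                         , min_samples_split=10  # illustrative value
                                         )
clf_pruned = clf_pruned.fit(Xtrain, Ytrain)
# a pruned tree should close the gap between training and test accuracy
print(clf_pruned.score(Xtrain, Ytrain), clf_pruned.score(Xtest, Ytest))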
Finding the Optimal Pruning Parameters: the Hyperparameter Curve
In [153]:
import matplotlib.pyplot as plt

test = []
for i in range(10):
    clf = tree.DecisionTreeClassifier(max_depth=i+1
                                      , criterion="gini"
                                      , random_state=42
                                      )
    clf = clf.fit(Xtrain, Ytrain)
    score = clf.score(Xtest, Ytest)
    test.append(score)
plt.plot(range(1, 11), test, '-r.', label='max_depth')
plt.legend()
plt.show()
# so choosing max_depth = 3 is good enough
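The same curve technique works for the other pruning parameters as well. Here is a sketch for min_samples_leaf (not from the original post; the search range 1 to 20 is an assumed illustrative choice, and Xtrain/Xtest are reused from above):

import matplotlib.pyplot as plt

test = []
leaf_range = range(1, 21)  # assumed illustrative search range
for i in leaf_range:
    clf = tree.DecisionTreeClassifier(min_samples_leaf=i
                                      , criterion="gini"
                                      , random_state=42
                                      )
    clf = clf.fit(Xtrain, Ytrain)
    test.append(clf.score(Xtest, Ytest))
plt.plot(leaf_range, test, '-b.', label='min_samples_leaf')
plt.legend()
plt.show()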