一、决策树定义:
决策树是一种基于树结构的机器学习算法,用于建立一系列的规则来对数据进行分类或预测。
二、决策树特征选择
2.1 特征选择问题
在决策树的构建过程中,特征选择是一个关键的步骤,它决定了每个节点应该选择哪个特征来进行分裂。
2.2 信息增益
信息增益是一种常用的特征选择准则,它衡量了在特征给定的条件下,分类结果的不确定性减少的程度。
2.2.1 熵
熵是信息论中用来表示随机变量不确定性的度量,它的计算公式为:
[H(X) = -\sum_{i=1}^{n} P(x_i) \log_2 P(x_i)]
其中,(P(x_i)) 表示随机变量 (X) 取值为 (x_i) 的概率。
2.2.2 信息增益
信息增益表示在特征 (A) 给定的条件下,分类结果的熵减少的程度,它的计算公式为:
[Gain(A) = H(D) - H(D|A)]
其中,(H(D)) 表示数据集 (D) 的熵,(H(D|A)) 表示在特征 (A) 给定的条件下,数据集 (D) 的条件熵。
三、决策树的生成
决策树的生成是指根据训练数据集生成决策树的过程,常用的算法包括ID3算法和C4.5算法。
3.1 ID3算法
ID3算法是一种基于信息增益准则的决策树生成算法,其主要步骤包括:
3.1.1 理论推导
ID3算法的核心思想是选择能够最大化信息增益的特征进行分裂。
3.2 C4.5算法
C4.5算法是ID3算法的改进版,它使用信息增益比来解决ID3算法的缺陷。
3.2.1 理论推导
C4.5算法在选择特征时,使用信息增益比来代替信息增益,以解决特征取值数目较多时信息增益偏向于取值数目较多的特征的问题。
四、决策树的剪枝
决策树的剪枝是为了防止过拟合,常用的剪枝算法包括预剪枝和后剪枝。
4.1 原理
剪枝的原理是通过降低树的复杂度来提高泛化能力,从而减少过拟合的风险。
4.2 算法思路:
预剪枝在决策树构建过程中,在每次进行分裂前先进行判断,若判断结果表明分裂会导致过拟合,则停止分裂。
五、CART算法
CART(Classification and Regression Trees)算法是一种既能用于分类又能用于回归的决策树算法。
5.1 CART生成
CART生成算法包括分类树的生成和回归树的生成。
5.1.1 回归树的生成
回归树的生成目标是将数据集划分成尽可能纯的子集,使得每个子集内的样本的目标变量的方差尽可能小。
5.1.2 分类树的生成
分类树的生成目标是将数据集划分成尽可能纯的子集,使得每个子集内的样本属于同一类别。
5.1.3 CART生成算法
CART生成算法主要通过递归地进行二分来生成树,直到满足停止条件为止。
5.2 CART剪枝
CART剪枝通过调整树的复杂度来提高泛化能力,常用的剪枝方法包括代价复杂度剪枝(Cost Complexity Pruning)。
六、代码
6.1 代码
以下用分类树举例
以下是决策树的简单代码示例:
`#数据准备
from sklearn.datasets import load_breast_cancer
breast_cancer = load_breast_cancer()
分离数据
breast_cancer
x=breast_cancer.data
y=breast_cancer.target
训练数据
from sklearn.model_selection import train_test_split
x_train,x_test,y_train,y_test = train_test_split(x,y,random_state=33,test_size=0.3)
数据标准化
from sklearn.preprocessing import StandardScaler
breast_cancer_ss = StandardScaler()
x_train = breast_cancer_ss.fit_transform(x_train)
x_test = breast_cancer_ss.transform(x_test)
分类树
from sklearn.tree import DecisionTreeClassifier
dtc = DecisionTreeClassifier()
dtc.fit(x_train,y_train)
dtc_y_predict = dtc.predict(x_test)
from sklearn.metrics import classification_report
k=0
j=0
for i in y_test:
if i!=dtc_y_predict[j]:
k=k+1
j=j+1
print(k)
print('预测结果:\n:',dtc_y_predict)
print('真是结果:\n:',y_test)
print('Accuracy:',dtc.score(x_test,y_test))
print(classification_report(y_test,dtc_y_predict,target_names=['benign','malignant']))`
6.2结果