首页 > 编程语言 >机器学习算法(五):基于企鹅数据集的决策树分类预测

机器学习算法(五):基于企鹅数据集的决策树分类预测

时间:2023-03-25 22:45:04浏览次数:48  
标签:企鹅 ## 模型 predict 算法 test data 决策树

机器学习算法(五):基于企鹅数据集的决策树分类预测

1 逻决策树的介绍和应用

1.1 决策树的介绍

决策树是一种常见的分类模型,在金融风控、医疗辅助诊断等诸多行业具有较为广泛的应用。决策树的核心思想是基于树结构对数据进行划分,这种思想是人类处理问题时的本能方法。例如在婚恋市场中,女方通常会先询问男方是否有房产,如果有房产再了解是否有车产,如果有车产再看是否有稳定工作……最后得出是否要深入了解的判断。

主要应用:

由于决策树模型中自变量与因变量的非线性关系以及决策树简单的计算方法,使得它成为集成学习中最为广泛使用的基模型。梯度提升树(GBDT),XGBoost以及LightGBM等先进的集成模型都采用了决策树作为基模型,在广告计算、CTR预估、金融风控等领域大放异彩,成为当今与神经网络相提并论的复杂模型,更是数据挖掘比赛中的常客。在新的研究中,南京大学周志华教授提出一种多粒度级联森林模型,创造了一种全新的基于决策树的深度集成方法,为我们提供了决策树发展的另一种可能。

同时决策树在一些明确需要可解释性或者提取分类规则的场景中被广泛应用,而其他机器学习模型在这一点很难做到。例如在医疗辅助系统中,为了方便专业人员发现错误,常常将决策树算法用于辅助病症检测。例如在一个预测哮喘患者的模型中,医生发现测试的许多高级模型的效果非常差。在他们运行了一个决策树模型后发现,算法认为剧烈咳嗽的病人患哮喘的风险很小。但医生非常清楚剧烈咳嗽一般都会被立刻检查治疗,这意味着患有剧烈咳嗽的哮喘病人都会马上得到收治。用于建模的数据认为这类病人风险很小,是因为所有这类病人都得到了及时治疗,所以极少有人在此之后患病或死亡。

1.2 相关流程

  • 了解 决策树 的理论知识
  • 掌握 决策树 的 sklearn 函数调用并将其运用在企鹅数据集的预测中

Part1 Demo实践

  • Step1:库函数导入
  • Step2:模型训练
  • Step3:数据和模型可视化
  • Step4:模型预测

Part2 基于企鹅(penguins)数据集的决策树分类实践

  • Step1:库函数导入
  • Step2:数据读取/载入
  • Step3:数据信息简单查看
  • Step4:可视化描述
  • Step5:利用 决策树模型 在二分类上 进行训练和预测
  • Step6:利用 决策树模型 在三分类(多分类)上 进行训练和预测

3 算法实战

3.1Demo实践

Step1: 库函数导入

##  基础函数库
import numpy as np 

## 导入画图库
import matplotlib.pyplot as plt
import seaborn as sns

## 导入决策树模型函数
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree

Step2: 训练模型

##Demo演示LogisticRegression分类

## 构造数据集
x_fearures = np.array([[-1, -2], [-2, -1], [-3, -2], [1, 3], [2, 1], [3, 2]])
y_label = np.array([0, 1, 0, 1, 0, 1])

## 调用决策树回归模型
tree_clf = DecisionTreeClassifier()

## 调用决策树模型拟合构造的数据集
tree_clf = tree_clf.fit(x_fearures, y_label)

Step3: 数据和模型可视化(需要用到graphviz可视化库)

## 可视化构造的数据样本点
plt.figure()
plt.scatter(x_fearures[:,0],x_fearures[:,1], c=y_label, s=50, cmap='viridis')
plt.title('Dataset')
plt.show()

## 可视化决策树
import graphviz
dot_data = tree.export_graphviz(tree_clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("pengunis")
'pengunis.pdf'

Step4:模型预测

## 创建新样本
x_fearures_new1 = np.array([[0, -1]])
x_fearures_new2 = np.array([[2, 1]])

## 在训练集和测试集上分布利用训练好的模型进行预测
y_label_new1_predict = tree_clf.predict(x_fearures_new1)
y_label_new2_predict = tree_clf.predict(x_fearures_new2)

print('The New point 1 predict class:\n',y_label_new1_predict)
print('The New point 2 predict class:\n',y_label_new2_predict)
The New point 1 predict class:
 [1]
The New point 2 predict class:
 [0]

3.2 基于penguins_raw数据集的决策树实战

在实践的最开始,我们首先需要导入一些基础的函数库包括:numpy (Python进行科学计算的基础软件包),pandas(pandas是一种快速,强大,灵活且易于使用的开源数据分析和处理工具),matplotlib和seaborn绘图。

#下载需要用到的数据集
!wget https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/6tree/penguins_raw.csv
--2023-03-22 16:21:32--  https://tianchi-media.oss-cn-beijing.aliyuncs.com/DSW/6tree/penguins_raw.csv
正在解析主机 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)... 49.7.22.39
正在连接 tianchi-media.oss-cn-beijing.aliyuncs.com (tianchi-media.oss-cn-beijing.aliyuncs.com)|49.7.22.39|:443... 已连接。
已发出 HTTP 请求,正在等待回应... 200 OK
长度: 53098 (52K) [text/csv]
正在保存至: “penguins_raw.csv”

penguins_raw.csv    100%[===================>]  51.85K  --.-KB/s    in 0.04s   

2023-03-22 16:21:33 (1.23 MB/s) - 已保存 “penguins_raw.csv” [53098/53098])

Step1:函数库导入

##  基础函数库
import numpy as np 
import pandas as pd

## 绘图函数库
import matplotlib.pyplot as plt
import seaborn as sns

本次我们选择企鹅数据(palmerpenguins)进行方法的尝试训练,该数据集一共包含8个变量,其中7个特征变量,1个目标分类变量。共有150个样本,目标变量为 企鹅的类别 其都属于企鹅类的三个亚属,分别是(Adélie, Chinstrap and Gentoo)。包含的三种种企鹅的七个特征,分别是所在岛屿,嘴巴长度,嘴巴深度,脚蹼长度,身体体积,性别以及年龄。

变量 描述
species a factor denoting penguin species
island a factor denoting island in Palmer Archipelago, Antarctica
bill_length_mm a number denoting bill length
bill_depth_mm a number denoting bill depth
flipper_length_mm an integer denoting flipper length
body_mass_g an integer denoting body mass
sex a factor denoting penguin sex
year an integer denoting the study year

Step2:数据读取/载入

## 我们利用Pandas自带的read_csv函数读取并转化为DataFrame格式

data = pd.read_csv('./penguins_raw.csv')
## 为了方便我们仅选取四个简单的特征,有兴趣的同学可以研究下其他特征的含义以及使用方法
data = data[['Species','Culmen Length (mm)','Culmen Depth (mm)',
            'Flipper Length (mm)','Body Mass (g)']]

Step3:数据信息简单查看

## 利用.info()查看数据的整体信息
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 344 entries, 0 to 343
Data columns (total 5 columns):
 #   Column               Non-Null Count  Dtype  
---  ------               --------------  -----  
 0   Species              344 non-null    object 
 1   Culmen Length (mm)   342 non-null    float64
 2   Culmen Depth (mm)    342 non-null    float64
 3   Flipper Length (mm)  342 non-null    float64
 4   Body Mass (g)        342 non-null    float64
dtypes: float64(4), object(1)
memory usage: 13.6+ KB
## 进行简单的数据查看,我们可以利用 .head() 头部.tail()尾部
data.head()
Species Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
0 Adelie Penguin (Pygoscelis adeliae) 39.1 18.7 181.0 3750.0
1 Adelie Penguin (Pygoscelis adeliae) 39.5 17.4 186.0 3800.0
2 Adelie Penguin (Pygoscelis adeliae) 40.3 18.0 195.0 3250.0
3 Adelie Penguin (Pygoscelis adeliae) NaN NaN NaN NaN
4 Adelie Penguin (Pygoscelis adeliae) 36.7 19.3 193.0 3450.0

这里我们发现数据集中存在NaN,一般的我们认为NaN在数据集中代表了缺失值,可能是数据采集或处理时产生的一种错误。这里我们采用-1将缺失值进行填补,还有其他例如“中位数填补、平均数填补”的缺失值处理方法有兴趣的同学也可以尝试。

data = data.fillna(-1)
data.tail()
Species Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
339 Chinstrap penguin (Pygoscelis antarctica) 55.8 19.8 207.0 4000.0
340 Chinstrap penguin (Pygoscelis antarctica) 43.5 18.1 202.0 3400.0
341 Chinstrap penguin (Pygoscelis antarctica) 49.6 18.2 193.0 3775.0
342 Chinstrap penguin (Pygoscelis antarctica) 50.8 19.0 210.0 4100.0
343 Chinstrap penguin (Pygoscelis antarctica) 50.2 18.7 198.0 3775.0
## 其对应的类别标签为'Adelie Penguin', 'Gentoo penguin', 'Chinstrap penguin'三种不同企鹅的类别。
data['Species'].unique()
array(['Adelie Penguin (Pygoscelis adeliae)',
       'Gentoo penguin (Pygoscelis papua)',
       'Chinstrap penguin (Pygoscelis antarctica)'], dtype=object)
## 利用value_counts函数查看每个类别数量
pd.Series(data['Species']).value_counts()
Adelie Penguin (Pygoscelis adeliae)          152
Gentoo penguin (Pygoscelis papua)            124
Chinstrap penguin (Pygoscelis antarctica)     68
Name: Species, dtype: int64
## 对于特征进行一些统计描述
data.describe()
Culmen Length (mm) Culmen Depth (mm) Flipper Length (mm) Body Mass (g)
count 344.000000 344.000000 344.000000 344.000000
mean 43.660756 17.045640 199.741279 4177.319767
std 6.428957 2.405614 20.806759 861.263227
min -1.000000 -1.000000 -1.000000 -1.000000
25% 39.200000 15.500000 190.000000 3550.000000
50% 44.250000 17.300000 197.000000 4025.000000
75% 48.500000 18.700000 213.000000 4750.000000
max 59.600000 21.500000 231.000000 6300.000000

Step4:可视化描述

## 特征与标签组合的散点可视化
sns.pairplot(data=data, diag_kind='hist', hue= 'Species')
plt.show()

从上图可以发现,在2D情况下不同的特征组合对于不同类别的企鹅的散点分布,以及大概的区分能力。Culmen Lenth与其他特征的组合散点的重合较少,所以对于数据集的划分能力最好。

我们发现

'''为了方便我们将标签转化为数字
       'Adelie Penguin (Pygoscelis adeliae)'        ------0
       'Gentoo penguin (Pygoscelis papua)'          ------1
       'Chinstrap penguin (Pygoscelis antarctica)   ------2 '''

def trans(x):
    if x == data['Species'].unique()[0]:
        return 0
    if x == data['Species'].unique()[1]:
        return 1
    if x == data['Species'].unique()[2]:
        return 2

data['Species'] = data['Species'].apply(trans)
for col in data.columns:
    if col != 'Species':
        sns.boxplot(x='Species', y=col, saturation=0.5, palette='pastel', data=data)
        plt.title(col)
        plt.show()




利用箱型图我们也可以得到不同类别在不同特征上的分布差异情况。

# 选取其前三个特征绘制三维散点图
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10,8))
ax = fig.add_subplot(111, projection='3d')

data_class0 = data[data['Species']==0].values
data_class1 = data[data['Species']==1].values
data_class2 = data[data['Species']==2].values
# 'setosa'(0), 'versicolor'(1), 'virginica'(2)
ax.scatter(data_class0[:,0], data_class0[:,1], data_class0[:,2],label=data['Species'].unique()[0])
ax.scatter(data_class1[:,0], data_class1[:,1], data_class1[:,2],label=data['Species'].unique()[1])
ax.scatter(data_class2[:,0], data_class2[:,1], data_class2[:,2],label=data['Species'].unique()[2])
plt.legend()

plt.show()

Step5:利用 决策树模型 在二分类上 进行训练和预测

## 为了正确评估模型性能,将数据划分为训练集和测试集,并在训练集上训练模型,在测试集上验证模型性能。
from sklearn.model_selection import train_test_split

## 选择其类别为0和1的样本 (不包括类别为2的样本)
data_target_part = data[data['Species'].isin([0,1])][['Species']]
data_features_part = data[data['Species'].isin([0,1])][['Culmen Length (mm)','Culmen Depth (mm)',
            'Flipper Length (mm)','Body Mass (g)']]

## 测试集大小为20%, 80%/20%分
x_train, x_test, y_train, y_test = train_test_split(data_features_part, data_target_part, test_size = 0.2, random_state = 2020)
## 从sklearn中导入决策树模型
from sklearn.tree import DecisionTreeClassifier
from sklearn import tree
## 定义 决策树模型 
clf = DecisionTreeClassifier(criterion='entropy')
# 在训练集上训练决策树模型
clf.fit(x_train, y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='entropy',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
## 可视化
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph.render("penguins")
'penguins.pdf'

## 在训练集和测试集上分布利用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)
from sklearn import metrics

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))

## 查看混淆矩阵 (预测值和真实值的各类情况统计矩阵)
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
The accuracy of the Logistic Regression is: 0.9954545454545455
The accuracy of the Logistic Regression is: 1.0
The confusion matrix result:
 [[31  0]
 [ 0 25]]

我们可以发现其准确度为1,代表所有的样本都预测正确了。

Step6:利用 决策树模型 在三分类(多分类)上 进行训练和预测

## 测试集大小为20%, 80%/20%分
x_train, x_test, y_train, y_test = train_test_split(data[['Culmen Length (mm)','Culmen Depth (mm)',
            'Flipper Length (mm)','Body Mass (g)']], data[['Species']], test_size = 0.2, random_state = 2020)
## 定义 决策树模型 
clf = DecisionTreeClassifier()
# 在训练集上训练决策树模型
clf.fit(x_train, y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')
## 在训练集和测试集上分布利用训练好的模型进行预测
train_predict = clf.predict(x_train)
test_predict = clf.predict(x_test)

## 由于决策树模型是概率预测模型(前文介绍的 p = p(y=1|x,\theta)),所有我们可以利用 predict_proba 函数预测其概率
train_predict_proba = clf.predict_proba(x_train)
test_predict_proba = clf.predict_proba(x_test)

print('The test predict Probability of each class:\n',test_predict_proba)
## 其中第一列代表预测为0类的概率,第二列代表预测为1类的概率,第三列代表预测为2类的概率。

## 利用accuracy(准确度)【预测正确的样本数目占总预测样本数目的比例】评估模型效果
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_train,train_predict))
print('The accuracy of the Logistic Regression is:',metrics.accuracy_score(y_test,test_predict))
The test predict Probability of each class:
 [[0. 0. 1.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 1. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [0. 1. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [0. 0. 1.]
 [1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]
 [1. 0. 0.]]
The accuracy of the Logistic Regression is: 0.9963636363636363
The accuracy of the Logistic Regression is: 0.9565217391304348
## 查看混淆矩阵
confusion_matrix_result = metrics.confusion_matrix(test_predict,y_test)
print('The confusion matrix result:\n',confusion_matrix_result)

# 利用热力图对于结果进行可视化
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix_result, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
The confusion matrix result:
 [[30  1  0]
 [ 0 23  0]
 [ 2  0 13]]

3.3 重要知识点

3.3.1 决策树构建的伪代码

输入: 训练集D={($x_1$,$y_1$),($x_2$,$y_2$),....,($x_m$,$y_m$)};
特征集A={$a_1$,$a_2$,....,$a_d$}

输出: 以node为根节点的一颗决策树

过程:函数TreeGenerate($D$,$A$)

  1. 生成节点node
  2. $if$ $D$中样本全书属于同一类别$C$ $then$:
  3. ----将node标记为$C$类叶节点;$return$
  4. $if$ $A$ = 空集 OR D中样本在$A$上的取值相同 $then$:
  5. ----将node标记为叶节点,其类别标记为$D$中样本数最多的类;$return$
  6. 从 $A$ 中选择最优划分属性 $a_*$;
  7. $for$ $a_$ 的每一个值 $a_^v$ $do$:
  8. ----为node生成一个分支,令$D_v$表示$D$中在$a_$上取值为$a_^v$的样本子集;
  9. ----$if$ $D_v$ 为空 $then$:
  10. --------将分支节点标记为叶节点,其类别标记为$D$中样本最多的类;$then$
  11. ----$else$:
  12. --------以 TreeGenerate($D_v$,$A${$a_*$})为分支节点

决策树的构建过程是一个递归过程。函数存在三种返回状态:(1)当前节点包含的样本全部属于同一类别,无需继续划分;(2)当前属性集为空或者所有样本在某个属性上的取值相同,无法继续划分;(3)当前节点包含的样本集合为空,无法划分。

3.3.2 划分选择

从上述伪代码中我们发现,决策树的关键在于line6.从$A$中选择最优划分属性$

标签:企鹅,##,模型,predict,算法,test,data,决策树
From: https://www.cnblogs.com/ting1/p/17255819.html

相关文章

  • 算法—二分法详解
    二分法详解目录二分法详解1.二分法2.引论:猜数游戏3.整数域二分1、在单调递增序列中找x或者x的后继2、在单调递增序列中查找x或者x的前驱3.简易二分模板4.浮点数二......
  • KMP字符串匹配算法
     KMP算法的要点是避免回溯和Next[]数组,其中,Next[]数组中存的是最长公共前后缀的长度. 1.KMP模板例题:HDU2087剪花布条 intNext[N],cnt;//构建Next[]数组voidg......
  • 用C语言实现ElGamal算法
    缘起是我侄子问的题目,参考了书籍、博客,花了一些时间完成的,丢掉可惜了,记录下来吧。这个程序还有些缺陷,数值太大时计算结果会溢出代码#include<stdio.h>#include<time.h>......
  • 算法总结--ST表
    声明(叠甲):鄙人水平有限,本文为作者的学习总结,仅供参考。1.RMQ介绍在开始介绍ST表前,我们先了解以下它以用的场景RMQ问题。RMQ(RangeMinimum/MaximumQuery)问......
  • m基于分段蚁群算法优化SVM的数据预测matlab仿真
    1.算法描述支持向量机(supportvectormachines,SVM)是二分类算法,所谓二分类即把具有多个特性(属性)的数据分为两类,目前主流机器学习算法中,神经网络等其他机器学习模型已经能......
  • m基于分段蚁群算法优化SVM的数据预测matlab仿真
    1.算法描述       支持向量机(supportvectormachines,SVM)是二分类算法,所谓二分类即把具有多个特性(属性)的数据分为两类,目前主流机器学习算法中,神经网络等其他机器......
  • 知识图谱推荐算法-基于嵌入的推荐方法
    基于嵌入的方法使用知识图谱中的信息来丰富用户或项目的表示,通过知识图谱嵌入将知识图谱中的实体和关系表征为低维向量,保留了知识图谱原有的结构。知识图谱通常存在链......
  • 2022牛客寒假算法基础集训营2 签到题7题
    1、C小沙的杀球如果你能够杀球但不杀球,虽然回复了体力,但你后续可能会没有机会继续杀球,并且杀球次数相同,那么回复的体力是相同的,所以在同等条件下,我们应该尽可能多的杀球。......
  • 2022牛客寒假算法基础集训营1 签到题7题
    1、L.牛牛学走路恭喜你签到成功#include<bits/stdc++.h>usingnamespacestd;intmain(){intT;cin>>T;while(T--){intn;cin>>n;strin......
  • 2022牛客寒假算法基础集训营3 签到题7题(附基础集训营1-3签到题总结)
    1、A-智乃的HelloXXXX签到#include<bits/stdc++.h>usingnamespacestd;intmain(){cout<<"hellochino\n";return0;}2、B-智乃买瓜背包#include<bits/stdc+......