首页 > 其他分享 >【scikit-learn基础】--『监督学习』之 随机森林分类

【scikit-learn基础】--『监督学习』之 随机森林分类

时间:2024-01-10 12:33:45浏览次数:27  
标签:分类 -- scikit 算法 随机 learn 森林 reg 决策树

随机森林分类算法是一种基于集成学习(ensemble learning)的机器学习算法,
它的基本原理是通过对多个决策树的预测结果进行平均或投票,以产生最终的分类结果。

随机森林算法可用于回归分类问题。
关于随机森林算法在回归问题上的应用可参考:TODO

随机森林分类算法可以应用于各种需要进行分类或预测的问题,如垃圾邮件识别信用卡欺诈检测疾病预测等,
它也可以与其他机器学习算法进行结合,以进一步提高预测准确率。

1. 算法概述

随机森林的基本原理是构建多棵决策树,每棵树都是基于原始训练数据的一个随机子集进行训练。在构建每棵树时,算法会随机选择一部分特征进行考虑,而不是考虑所有的特征。

然后,对于一个新的输入样本,每棵树都会进行分类预测,并将预测结果提交给“森林”进行最终的分类决策。
一般来说,森林会选择出现次数最多的类别作为最终的分类结果。

理论上来看,随机森林分类应该比决策树分类有更加好的准确度,特别是在高维度的数据情况下。

2. 创建样本数据

为了后面比较随机森林分类算法和决策树算法的准确性,创建分类多一些(8个分类标签)的样本数据。

import matplotlib.pyplot as plt
from sklearn.datasets import make_classification

# 分类数据的样本生成器
X, y = make_classification(
    n_samples=1000, n_classes=8, n_clusters_per_class=2, n_informative=6
)
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=25)

plt.show()

image.png

3. 模型训练

首先,分割训练集测试集

from sklearn.model_selection import train_test_split

# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1)

这次按照9:1的比例来划分训练集和测试集。

决策树分类模型来训练数据:

from sklearn.tree import DecisionTreeClassifier

reg_names = [
    "ID3算法",
    "C4.5算法",
    "CART算法",
]

# 定义
regs = [
    DecisionTreeClassifier(criterion="entropy"),
    DecisionTreeClassifier(criterion="log_loss"),
    DecisionTreeClassifier(criterion="gini"),
]

# 训练模型
for reg in regs:
    reg.fit(X_train, y_train)

# 在测试集上进行预测
y_preds = []
for reg in regs:
    y_pred = reg.predict(X_test)
    y_preds.append(y_pred)

for i in range(len(y_preds)):
    correct_pred = np.sum(y_preds[i] == y_test)
    print("决策树【{}】 预测正确率:{:.2f}%".format(reg_names[i], correct_pred / len(y_pred) * 100))

# 运行结果
决策树【ID3算法】 预测正确率:43.00%
决策树【C4.5算法】 预测正确率:42.00%
决策树【CART算法】 预测正确率:42.00%

随机森林分类模型来训练数据:

from sklearn.ensemble import RandomForestClassifier

reg_names = [
    "ID3算法",
    "C4.5算法",
    "CART算法",
]

# 定义
regs = [
    RandomForestClassifier(criterion="entropy"),
    RandomForestClassifier(criterion="log_loss"),
    RandomForestClassifier(criterion="gini"),
]

# 训练模型
for reg in regs:
    reg.fit(X_train, y_train)

# 在测试集上进行预测
y_preds = []
for reg in regs:
    y_pred = reg.predict(X_test)
    y_preds.append(y_pred)

for i in range(len(y_preds)):
    correct_pred = np.sum(y_preds[i] == y_test)
    print("随机森林【{}】 预测正确率:{:.2f}%".format(reg_names[i], correct_pred / len(y_pred) * 100))

# 运行结果
随机森林【ID3算法】 预测正确率:64.00%
随机森林【C4.5算法】 预测正确率:63.00%
随机森林【CART算法】 预测正确率:69.00%

可以看出,随机森林分类的准确性确实比决策树分类提高了。
不过,运行过程中也可以发现,随机森林的训练时间会比决策树长一些。

4. 总结

随机森林分类算法的优势在于:

  1. 抗过拟合能力强:由于采用随机选择特征的方式,可以有效地避免过拟合问题。
  2. 泛化能力强:通过对多个决策树的结果进行投票或平均,可以获得更好的泛化性能。
  3. 对数据特征的选取具有指导性:在构建决策树时会对特征进行选择,这可以为后续的特征选择提供指导。
  4. 适用于大规模数据集:可以有效地处理大规模数据集,并且训练速度相对较快。

当然,随机森林分类算法也存在一些劣势:

  1. 需要大量的内存和计算资源:由于需要构建多个决策树,因此需要更多的内存和计算资源。
  2. 需要调整参数:性能很大程度上取决于参数的设置,如树的数量、每个节点的最小样本数等,这些参数的设置需要一定的经验和实验。
  3. 对新样本的预测性能不稳定:由于是通过投票或平均多个决策树的结果来进行预测,因此对新样本的预测性能可能会受到影响。

标签:分类,--,scikit,算法,随机,learn,森林,reg,决策树
From: https://www.cnblogs.com/wang_yb/p/17956232

相关文章

  • 图论
    目录深搜入门leetcode797广搜入门leetcode200深搜和广搜6951020深搜入门leetcode797因为也是二刷,推的比较快刷题之后的感悟,其实就是先把模板写上去了之后再在里面缝缝补补出题目要求都比较模板,变通一下思路就能做出来classSolution{public:vector<vector<int>>resu......
  • HarmonyOS (ArkTS)状态管理
    一、状态管理分为:页面级变量的状态管理  (主要用于单页面,同一个页面内不同组件之间的状态管理。)应用级变量的状态管理(主要用于多页面,同一个应用内,不同页面之间的状态管理。例如:A页面和B页面实现数据共享)1、页面级变量的状态管理@State、@Prop、@Link、@Provi......
  • 药品不良反应智能监测系统源码,ADR智能监测系统全套源码,
    ADR智能监测系统全套源码,药品不良反应智能监测系统源码ADR智能监测上报系统是基于医院临床数据中心而建立,运用信息技术实现药品不良反应的智能监测、报告管理、知识库查询、统计分析等功能。药品不良反应智能监测系统自动提取不良反应报告数据,主动实时监测临床发生的不良反应,第一时......
  • 大模型时代下的开发范式探索
    在大数据和深度学习技术的推动下,大模型已成为AI领域的主流趋势。这些庞大的模型拥有数亿甚至数十亿的参数,能够处理复杂的任务并实现令人惊叹的性能。然而,随着模型规模的扩大,开发、训练和部署的难度也急剧增加。如何在这样的时代背景下破茧重生,探索新的开发范式,成为摆在我们面前的重......
  • 你还在“垃圾”调优?快来看看JDK17的ZGC如何解放双手 | 京东云技术团队
    1、前言不要犹豫了,GC最大停顿时间小于1ms,支持16TB内存,这么高的性能提升,也不需要复杂的调优,节省了这个时间,你去陪对象不香嘛。上篇文章给大家带来了JDK11升级JDK17的最全实践,相信大家阅读后对于升级JDK17有了基本的了解。同时我们也会比较好奇,ZGC的原理是啥样的,怎么做到停顿时间那么......
  • 「云渲染知识」建筑效果图用什么软件制作?
    高品质的建筑效果图需要利用插件来模拟复杂的场景、光线照射和天气变化。然而,许多专业人士可能不清楚有哪些软件可以实现这样的效果。下面将介绍一些常用的软件来帮助实现高品质的建筑效果图。一、建筑效果图必备软件1、三维建模工具Autodesk3dsMax:强大的建模工具,用于创建复......
  • SciTech-Github-解决git push时的 Error: hasDotgit: contains '.git'
    AbaelsMacBookPro:pelicanabaelhe$gitpushEnumeratingobjects:6872,done.Countingobjects:100%(6872/6872),done.Deltacompressionusingupto8threadsCompressingobjects:100%(4305/4305),done.remote:error:object93c3f3e6d30672571d972693d0842a......
  • str 系列字符串操作函数
    str系列字符串操作函数主要包括strlen、strcpy、strcmp、strcat等。strlen函数用于统计字符串长度,strcpy函数用于将某个字符串复制到字符数组中,strcmp函数用于比较两个字符串的大小,strcat函数用于将两个字符串连接到一起。各个函数的具体格式如下所示:1#include<string.......
  • 8、SpringBoot2之打包及运行
    为了演示高级启动时动态配置参数的使用,本文在SpringBoot2之配置文件的基础上进行8.1、概述普通的web项目,会被打成一个war包,然后再将war包放到tomcat的webapps目录中;当tomcat启动时,在webapps目录中的war包会自动解压,此时便可访问该web项目的资源或服务;因为......
  • php 数据安全性(过滤提交的数据)
    1.在common.php公共方法加入/***过滤sql与php文件操作的关键字*/functionfilter_keyword($string){$keyword='select|insert|update|delete|\'|\/\*|\*|\.\.\/|\.\/|union|into|load_file|outfile';$arr=explode(......