首页 > 其他分享 >【机器学习实战】用sklearn玩转随机森林,分类准确率提升秘籍!

【机器学习实战】用sklearn玩转随机森林,分类准确率提升秘籍!

时间:2024-09-13 12:52:40浏览次数:3  
标签:search 准确率 train grid 玩转 test import best sklearn

在机器学习的世界里,随机森林算法以其出色的分类和回归能力而闻名。我们将深入sklearn库中的随机森林,探索如何通过实战提升模型的分类准确率。

一 随机森林算法简介

随机森林是一种集成学习方法,通过构建多个决策树并综合它们的预测结果来提高预测准确性。每个决策树都是在训练数据的一个随机子集上构建的,这种方法减少了模型间的相关性,从而增强了整体模型的泛化能力。

理论详情,请查看往期文章:揭秘Bagging与随机森林:构建更强大预测模型的秘密

二 sklean实战

在 SKLearn 中,随机森林算法被封装在RandomForestClassifierRandomForestRegressor两个类中,分别用于分类和回归问题。这两个类提供了丰富的参数和方法,使得我们可以轻松地构建和调优随机森林模型。

1. 导入库和数据

首先,我们需要导入必要的库,并加载数据集。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 将数据划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

2. 创建随机森林分类器

创建一个随机森林分类器实例,并设定随机种子以保证结果的可重复性。

# 创建随机森林分类器实例
rf = RandomForestClassifier(random_state=42)

3. 定义超参数网格

定义一个超参数网格,用于GridSearchCV进行搜索。

# 定义超参数网格
param_grid = {
    'n_estimators': [50, 100, 200],  # 树的数量
    'max_depth': [None, 10, 20, 30],  # 树的最大深度
    'min_samples_split': [2, 5, 10],  # 内部节点再划分所需最小样本数
    'min_samples_leaf': [1, 2, 4]     # 叶子节点所需的最小样本数
}

4. 创建并执行网格搜索

使用GridSearchCV创建一个搜索对象,并执行网格搜索。

# 创建GridSearchCV对象
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1)

# 执行网格搜索
grid_search.fit(X_train, y_train)

5. 查看最佳参数组合

查看网格搜索后找到的最佳参数组合,并评估最佳模型的性能。

# 查看最佳参数组合
print("Best parameters found: ", grid_search.best_params_)
print("Best cross-validation score: %.3f" % grid_search.best_score_)

# 获取最佳模型
best_rf = grid_search.best_estimator_

# 在测试集上进行预测
y_pred = best_rf.predict(X_test)

# 打印分类报告
print(classification_report(y_test, y_pred))

完整代码

将上述步骤整合成一个完整的代码块:

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

# 加载数据
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建随机森林分类器
rf = RandomForestClassifier(random_state=42)

# 定义超参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20, 30],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4]
}

# 创建GridSearchCV对象
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=5, scoring='accuracy', n_jobs=-1)

# 执行网格搜索
grid_search.fit(X_train, y_train)

# 查看最佳参数组合
print("Best parameters found: ", grid_search.best_params_)
print("Best cross-validation score: %.3f" % grid_search.best_score_)

# 获取最佳模型
best_rf = grid_search.best_estimator_

# 在测试集上进行预测
y_pred = best_rf.predict(X_test)

# 打印分类报告
print(classification_report(y_test, y_pred))

标签:search,准确率,train,grid,玩转,test,import,best,sklearn
From: https://blog.csdn.net/u011026329/article/details/142051191

相关文章

  • 通义灵码用户说:“人工编写测试用例需要数十分钟,通义灵码以毫秒级的速度生成测试代码,且
    通过一篇文章,详细跟大家分享一下我在使用通义灵码过程中的感受。一、定义通义灵码,是一个智能编码助手,它基于通义大模型,提供代码智能生成、研发智能问答能力。在体验过程中有任何问题均可点击下面的连接前往了解和学习。通义灵码官网通义灵码安装教程通义灵码产品手册......
  • 每天五分钟玩转深度学习框架PyTorch:获取神经网络模型的参数
    本文重点当我们定义好神经网络之后,这个网络是由多个网络层构成的,每层都有参数,我们如何才能获取到这些参数呢?我们将再下面介绍几个方法来获取神经网络的模型参数,此文我们是为了学习第6步(优化器)。获取所有参数Parametersfromtorchimportnnnet=nn.Sequential(nn.Linear(4......
  • 深度学习基础案例4--运用动态学习率构建CNN卷积神经网络实现的运动鞋识别(测试集的准
    ......
  • 每天五分钟玩转深度学习框架PyTorch:将nn的神经网络层连接起来
    本文重点前面我们学习pytorch中已经封装好的神经网络层,有全连接层,激活层,卷积层等等,我们可以直接使用。如代码所示我们直接使用了两个nn.Linear(),这两个linear之间并没有组合在一起,所以forward的之后,分别调用了,在实际使用中我们常常将几个神经层组合在一起,这样不仅操作方便,而且......
  • 8G 显存玩转书生大模型 Demo
    8G显存玩转书生大模型Demo首先第一步依旧是创建我们的开发机,选择上我们需要选择10%的开发机,镜像选择为Cuda-12.2。在输入开发机名称后,点击创建开发机。这里就不放创建的流程图了环境配置#创建环境condacreate-ndemopython=3.10-y#激活环境condaactivate......
  • 全能AI神器!工作效率提升80倍!Zmo.ai带你玩转AI做图!
    今天,我要给大家介绍一款神器:Zmo.ai。这个平台简直是做图神器,集多种功能于一身,让你像专业人士一样轻松创建和编辑图像,不需要任何美术与设计基础,真的非常适合我们这些“手残党”!我们只需单击按钮即可从文本或图像生成令人惊叹的AI艺术、图像、动漫和逼真的照片,最关键的是......
  • 【零基础玩转树莓派】03-USB摄像头和CSI摄像头的使用
    摄像头USB摄像头的使用环境搭建FSWebcam是一个简洁明了的网络摄像头应用程序,软件安装命令如下:sudoaptinstallfswebcam添加用户权限:sudousermod-a-Gvideo示例:添加pi用户权限到群组中:sudousermod-a-Gvideopi检查用户是否已正确添加到群组中:groups查看USB......
  • 带你1分钟玩转AI大模型微调推理,更有限时福利等你领
    本文分享自华为云开发者联盟微信公众号《如何1分钟玩转AI大模型微调推理?(文末有福利)》想要低成本用好大模型,必然离不开对它的微调(FineTuning)。那么,为什么大模型需要微调呢?举个例子:一个通用大模型涵盖了许多语言信息,我们和它可以进行流畅的对话。但是如果想要它正确回答“布......
  • 数据库上云有多轻松?华为云技术专家带你玩转云数据库API
    本文分享自华为云开发者联盟微信公众号《DTSETechTalk|第65期:智能数据底座使能千行百业,华为云数据库服务API揭秘与实践探索》华为云GaussDB是一款软硬全栈协同的企业级原生分布式数据库,支持x86和Kunpeng硬件架构,基于Share-nothing架构,具备高可用、高安全、高性能、......
  • CamoTeacher:玩转半监督伪装物体检测,双一致性动态调整样本权重 | ECCV 2024
    论文提出了第一个端到端的半监督伪装目标检测模型CamoTeacher。为了解决半监督伪装目标检测中伪标签中存在的大量噪声问题,包括局部噪声和全局噪声,引入了一种名为双旋转一致性学习(DRCL)的新方法,包括像素级一致性学习(PCL)和实例级一致性学习(ICL)。DRCL帮助模型缓解噪音问题,有效利用伪......