首页 > 其他分享 >简单的XGBoost案例

简单的XGBoost案例

时间:2024-09-29 22:50:07浏览次数:11  
标签:样本 数据 模型 XGBoost 案例 简单 np data

一、前言

        今天我们来一起学习一个新的算法模型,XGboost算法

     1、XGBoost的特性

        XGBoost(Extreme Gradient Boosting)是一个高效的开源机器学习库,广泛应用于结构化数据的分类和回归问题。它基于梯度提升算法,利用决策树的集成方法来提高模型的性能和准确性。XGBoost 提供了以下几个主要特性:

        高效性:通过并行计算和缓存优化,XGBoost 在训练速度上显著快于传统的梯度提升算法。
        准确性:XGBoost 采用正则化手段来防止过拟合,并且支持自定义损失函数,使得模型在多种场景下表现出色。
        灵活性:支持多种数据类型和目标函数,适用于回归、分类、排序等多种任务。
        可解释性:可以生成特征重要性图,帮助分析模型的决策依据。

     2、XGBoost的使用场景

        金融欺诈检测: 通过分析交易数据,XGBoost 可以帮助识别潜在的欺诈行为,例如信用卡欺诈、保险欺诈等。

        客户流失预测: 在电信、银行等行业,通过客户的行为特征,预测哪些客户可能流失,从而制定相应的留存策略。

        异常检测:用于检测数据中的异常点,如信用卡欺诈、故障检测等。

        销售预测: 根据历史销售数据和市场特征,预测未来的销售趋势,帮助企业制定生产和营销策略。

        医疗健康: 通过患者的医疗记录和特征,预测疾病的发生风险,帮助医生做出更准确的诊断。

        图像分类与目标检测: 虽然 XGBoost 主要用于结构化数据,但在某些情况下,可以与图像特征提取结合使用,进行分类任务。

        推荐系统: 利用用户行为数据,构建推荐模型,提升用户的个性化体验。

      3、本文的案例分析

        本文将从异常检测入手,带大家通过一个简单的案例来了解XGBoost算法,我们创建一千个正常数据和五十个异常数据,通过训练来找出异常的数据。

二、具体代码

        本文的训练流程依旧是:创建数据集—>划分数据集—>创建模型—>训练模型—>预测和评估模型。

        先导入需要用到的库: 

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

      1、创建数据集:

np.random.seed(42)
# 正常数据
normal_data = np.random.normal(loc=0, scale=1, size=(1000, 2))
normal_labels = np.zeros(1000)

# 异常数据
anomaly_data = np.random.uniform(low=-6, high=6, size=(50, 2))
anomaly_labels = np.ones(50)

# 合并数据
X = np.vstack((normal_data, anomaly_data))
y = np.concatenate((normal_labels, anomaly_labels))

# 创建DataFrame
data = pd.DataFrame(X, columns=['Feature1', 'Feature2'])
data['Label'] = y

      np.random.seed(42): 通过设置随机种子,可以确保每次运行代码时生成的随机数相同,便于调试和复现结果。

        normal_data = np.random.normal(loc=0, scale=1, size=(1000, 2)):从正态分布中生成1000个样本。loc=0:均值为0。 scale=1:标准差为1。 size=(1000, 2):生成1000行,2列的数据,即每个样本有两个特征。

        normal_labels = np.zeros(1000):使用np.zeros(1000)生成1000个0,表示这些样本为正常数据。

        anomaly_data = np.random.uniform(low=-6, high=6, size=(50, 2)):从均匀分布中生成50个样本,范围更广。low=-6和high=6:生成的样本值范围在-6到6之间。 size=(50, 2):生成50行,2列的数据。

        anomaly_labels = np.ones(50):使用np.ones(50)生成50个1,表示这些样本为异常数据。

        np.vstack((normal_data, anomaly_data)):将正常数据和异常数据按垂直方向合并,形成一个新的数组X,包含1050个样本。

        合并标签:np.concatenate((normal_labels, anomaly_labels)):将正常标签和异常标签合并为一个新的标签数组y。

      2、数据集划分

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(data[['Feature1', 'Feature2']], data['Label'], test_size=0.2,
                                                    random_state=42)

     3、 创建模型

# 3. 创建DMatrix
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)

# 4. 设置参数
params = {
    'objective': 'binary:logistic',  # 二分类任务
    'max_depth': 3,
    'eta': 0.1,
    'eval_metric': 'logloss'
}

        objective: 定义了学习任务的目标类型。'binary:logistic':表示这是一个二分类任务,模型输出的是每个样本属于正类的概率。

        max_depth: 指定树的最大深度。 3:树的最大深度设置为3,限制了模型的复杂度,有助于防止过拟合。

        eta: 学习率,也称为步长。 0.1:学习率设置为0.1,表示每次迭代更新时会缩小权重的更新量。较小的学习率通常能提升模型的性能,但训练时间会变长。

         eval_metric: 评估指标。 'logloss':用于评估模型的损失,通常在二分类问题中使用。log loss越小,模型的表现越好。

      4、模型训练

# 训练模型
num_boost_round = 100
model = xgb.train(params, dtrain, num_boost_round)

        num_boost_round: 这是XGBoost模型训练的迭代次数,也称为 boosting rounds。 在这个例子中,设置为 100,表示模型将进行100次迭代,每次迭代都会根据之前的模型更新学习到的权重。

        xgb.train(): 这是用于训练XGBoost模型的函数。 第一个参数是 params,包含了模型的设置(如目标函数、最大深度等)。 第二个参数是 dtrain,这是之前创建的训练数据DMatrix,包含特征和标签。 第三个参数是 num_boost_round,表示模型训练的轮数。

      5、模型的评估与预测

# 进行预测
y_pred_probs = model.predict(dtest)
y_pred = (y_pred_probs > 0.5).astype(int)  # 使用0.5作为阈值进行分类

# 评估模型
print(classification_report(y_test, y_pred))

           

        这是模型评估结果的分类报告,包含了每个类别的精确度(precision)、召回率(recall)、F1分数和支持度(support)。下面逐一解释这些指标。

        1. 精确度(Precision) 定义: 正确预测为正类的样本占所有预测为正类的样本的比例。 0.0 类别: 0.98,表示被预测为正常的样本中,有98%是正确的。 1.0 类别: 1.00,表示被预测为异常的样本中,100%是正确的。

        2. 召回率(Recall) 定义: 正确预测为正类的样本占所有实际为正类的样本的比例。 0.0 类别: 1.00,表示所有实际正常的样本都被正确识别。 1.0 类别: 0.77,表示77%的实际异常样本被正确识别。

        3. F1分数(F1-score) 定义: 精确度和召回率的调和平均,综合了这两个指标的性能。 0.0 类别: 0.99,表示对正常样本的检测性能很高。 1.0 类别: 0.87,表示对异常样本的检测性能较好,但相较于正常样本稍低。

        4. 支持度(Support) 定义: 每个类别在测试集中的真实样本数量。 0.0 类别: 197,表示正常样本数量。 1.0 类别: 13,表示异常样本数量。

         5. 整体准确率(Accuracy) 0.99,表示模型的整体预测准确率为99%,即210个样本中有99%被正确分类。

        6. 宏平均(Macro Average) 这些指标的简单平均,适用于类别不平衡的情况。 精确度: 0.99,召回率: 0.88,F1分数: 0.93,表示整体性能良好,但召回率相对较低。

        7. 加权平均(Weighted Average) 考虑了每个类别的样本数量的加权平均。 精确度: 0.99,召回率: 0.99,F1分数: 0.98,表明模型在整体性能上表现优异。       

        到这里,我们的代码就正式结束了,看评估效果很不错。

三、结语

        最后,我们将预测的结果进行可视化,还记得我们20%的测试集嘛,我们将测试数据和真实数据都显示出来,看看效果如何。

# 可视化结果
plt.figure(figsize=(10, 6))

# 真实数据点
plt.scatter(data['Feature1'], data['Feature2'], c=data['Label'], cmap='coolwarm', alpha=0.5, label='真实数据')

# 预测异常数据
predicted_anomalies = data.iloc[X_test.index][y_pred == 1]
plt.scatter(predicted_anomalies['Feature1'], predicted_anomalies['Feature2'], edgecolor='red', facecolor='none', s=100,
            label='预测异常数据')

plt.title('Anomaly Detection')
plt.xlabel('Feature1')
plt.ylabel('Feature2')
plt.axhline(0, color='gray', linewidth=0.5, linestyle='--')
plt.axvline(0, color='gray', linewidth=0.5, linestyle='--')
plt.legend()
plt.show()

        可以看到我们用红色的小球代表异常值,红色的外环岱庙预测的异常值,正好是10个测试值,准确率还是蛮好的。

         最后,大家如果觉得对您有帮助的话麻烦点点赞,如果有错误或者纰漏希望您能在评论区指出,帮助大家能更好地理解和掌握相关知识!

标签:样本,数据,模型,XGBoost,案例,简单,np,data
From: https://blog.csdn.net/m0_62800009/article/details/142639350

相关文章

  • WPF下使用FreeRedis操作RedisStream实现简单的消息队列
    RedisStream简介RedisStream是随着5.0版本发布的一种新的Redis数据类型:高效消费者组:允许多个消费者组从同一数据流的不同部分消费数据,每个消费者组都能独立地处理消息,这样可以并行处理和提高效率。阻塞操作:消费者可以设置阻塞操作,这样它们会在流中有新数据添加时被唤醒并开始......
  • unity常见的两种简单易上手的移动方式
    第一,使用transform的translate进行移动。使用方法:对象.transform.translate(方向向量*normalized*Time.deltaTime*speed);normalized是将这个方向向量归一化,即模长等于1,这是为了控制速度等于后面的speed,如果不加也能够实现移动,但是速度不便于控制。Time.deltaTime是每一......
  • 【玩转Linux】如何简单快速理解权限?
     学习编程就得循环渐进,扎实基础,勿在浮沙筑高台   循环渐进Forward-CSDN博客Hello,这里是kiki,今天更新Linux部分,我们继续来扩充我们的知识面,我希望能努力把抽象繁多的知识讲的生动又通俗易懂,今天要讲的是权限~目录 循环渐进Forward-CSDN博客shell命令以及运行......
  • ✨简简单单写程序
    每个伟大的梦想,都有一个微不足道的开始。程序的设计目标和流程设计一个程序是为了让计算机始终不渝地遵循指令,以完成特定的任务。为了能让计算机听懂指令,我们编写程序来与计算机交流。编程方法使用IDE(集成编辑环境)例如:DevC++/CodeBlocks。使用洛谷在线编程https:......
  • 【FPGA开发】一文轻松入门Modelsim的简单操作
    Modelsim仿真的步骤    (1)创建新的工程。    (2)在弹出的窗口中,确定项目名和工作路径,库保持为work不变。    (3)添加已经存在的文件(rtl代码和tb代码)。    如果这里关闭后,还想继续添加,也可以直接在界面空白处右键进行添加。    加错......
  • 实验1_C语言输入输出和简单程序应用编程
    task.1//打印一个字符小人#include<stdio.h>intmain(){printf("O\n");printf("<H>\n");printf("II\n");return0;}  task.1-1&1-2#include<stdio.h>intmain(){printf(&qu......
  • C++实现简单的tcp协议
    Server.cpp#include<iostream>#include<winsock2.h>#include<ws2tcpip.h>#pragmacomment(lib,"ws2_32.lib")constintPORT=8888;constintBUFFER_SIZE=1024;intmain(){WSADATAwsaData;intiResult=WSAStartu......
  • 【Ruby】ruby on rails两行命令搭建简单的学生管理系统
    【Ruby】rubyonrails两行命令搭建简单的学生管理系统本文主要是让大家体验一下rubyonrails开发网站的快速,ruby和rails的安装以及一些细节的介绍请看本人的另一篇文章【Ruby】Web框架rubyonrails初识(MVC架构初理解)我们只需要两条命令,就可以搭建出一个简单的学生......
  • 实验1 C语言输入输出和简单程序编写
    task11#include<stdio.h>2intmain()3{4printf("0\n");5printf("<H>\n");6printf("II\n");7return0;8}   task1_1.c1#include<stdio.h>2intmain()3{4int......
  • 快手:数据库升级实践,实现PB级数据的高效管理|OceanBase案例
    本文作者:胡玉龙,快手技术专家快手在较初期采用了OceanBase 3.1版本成功替换了多个核心业务、数百套的MySQL集群。至2023年,快手的数据量已突破800TB大关,其中最大集群的数据量更是达到了数百TB级别。为此,快手将数据库系统升级至OceanBase4.x版本,从而显著提升了业务的稳定性和......