首页 > 其他分享 >不平衡数据集的建模的技巧和策略

不平衡数据集的建模的技巧和策略

时间:2023-02-01 23:13:39浏览次数:54  
标签:non 技巧 float64 模型 示例 建模 284807 平衡 null

前言 本文讨论了处理不平衡数据集和提高机器学习模型性能的各种技巧和策略,涵盖的一些技术包括重采样技术、代价敏感学习、使用适当的性能指标、集成方法和其他策略。

作者:Emine Bozkuş
来源:DeepHub IMBA

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

不平衡数据集是指一个类中的示例数量与另一类中的示例数量显著不同的情况。例如在一个二元分类问题中,一个类只占总样本的一小部分,这被称为不平衡数据集。类不平衡会在构建机器学习模型时导致很多问题。

不平衡数据集的主要问题之一是模型可能会偏向多数类,从而导致预测少数类的性能不佳。这是因为模型经过训练以最小化错误率,并且当多数类被过度代表时,模型倾向于更频繁地预测多数类。这会导致更高的准确率得分,但少数类别得分较低。

另一个问题是,当模型暴露于新的、看不见的数据时,它可能无法很好地泛化。这是因为该模型是在倾斜的数据集上训练的,可能无法处理测试数据中的不平衡。

在本文中,我们将讨论处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。将涵盖的一些技术包括重采样技术、代价敏感学习、使用适当的性能指标、集成方法和其他策略。通过这些技巧,可以为不平衡的数据集构建有效的模型。

处理不平衡数据集的技巧

重采样技术是处理不平衡数据集的最流行方法之一。这些技术涉及减少多数类中的示例数量或增加少数类中的示例数量。

欠采样可以从多数类中随机删除示例以减小其大小并平衡数据集。这种技术简单易行,但会导致信息丢失,因为它会丢弃一些多数类示例。

过采样与欠采样相反,过采样随机复制少数类中的示例以增加其大小。这种技术可能会导致过度拟合,因为模型是在少数类的重复示例上训练的。

SMOTE是一种更高级的技术,它创建少数类的合成示例,而不是复制现有示例。这种技术有助于在不引入重复项的情况下平衡数据集。

代价敏感学习(Cost-sensitive learning)是另一种可用于处理不平衡数据集的技术。在这种方法中,不同的错误分类成本被分配给不同的类别。这意味着与错误分类多数类示例相比,模型因错误分类少数类示例而受到更严重的惩罚。

在处理不平衡的数据集时,使用适当的性能指标也很重要。准确性并不总是最好的指标,因为在处理不平衡的数据集时它可能会产生误导。相反,使用 AUC-ROC等指标可以更好地指示模型性能。

集成方法,例如 bagging 和 boosting,也可以有效地对不平衡数据集进行建模。这些方法结合了多个模型的预测以提高整体性能。Bagging 涉及独立训练多个模型并对它们的预测进行平均,而 boosting 涉及按顺序训练多个模型,其中每个模型都试图纠正前一个模型的错误。

重采样技术、成本敏感学习、使用适当的性能指标和集成方法是一些技巧和策略,可以帮助处理不平衡的数据集并提高机器学习模型的性能。

在不平衡数据集上提高模型性能的策略

收集更多数据是在不平衡数据集上提高模型性能的最直接策略之一。通过增加少数类中的示例数量,模型将有更多信息可供学习,并且不太可能偏向多数类。当少数类中的示例数量非常少时,此策略特别有用。

生成合成样本是另一种可用于提高模型性能的策略。合成样本是人工创建的样本,与少数类中的真实样本相似。这些样本可以使用 SMOTE等技术生成,该技术通过在现有示例之间进行插值来创建合成示例。生成合成样本有助于平衡数据集并为模型提供更多示例以供学习。

使用领域知识来关注重要样本也是一种可行的策略,通过识别数据集中信息量最大的示例来提高模型性能。例如,如果我们正在处理医学数据集,可能知道某些症状或实验室结果更能表明某种疾病。通过关注这些例子可以提高模型准确预测少数类的能力。

最后可以使用异常检测等高级技术来识别和关注少数类示例。这些技术可用于识别与多数类不同且可能是少数类示例的示例。这可以通过识别数据集中信息量最大的示例来帮助提高模型性能。

在收集更多数据、生成合成样本、使用领域知识专注于重要样本以及使用异常检测等先进技术是一些可用于提高模型在不平衡数据集上的性能的策略。这些策略可以帮助平衡数据集,为模型提供更多示例以供学习,并识别数据集中信息量最大的示例。

不平衡数据集的练习

这里我们使用信用卡欺诈分类的数据集演示处理不平衡数据的方法

import pandas as pd  
 import numpy as np  
 from sklearn.preprocessing import RobustScaler  
 from sklearn.linear\_model import LogisticRegression  
 from sklearn.model\_selection import train\_test\_split  
 from sklearn.metrics import accuracy\_score  
 from sklearn.metrics import confusion\_matrix, classification\_report,f1\_score,recall\_score,roc\_auc\_score, roc\_curve  
 import matplotlib.pyplot as plt  
 import seaborn as sns  
 from matplotlib import rc,rcParams  
 import itertools  
   
 import warnings  
 warnings.filterwarnings\("ignore", category\=DeprecationWarning\)   
 warnings.filterwarnings\("ignore", category\=FutureWarning\)   
 warnings.filterwarnings\("ignore", category\=UserWarning\)

读取数据:

 df \= pd.read\_csv\("creditcard.csv"\)  
 df.head\(\)  
 print\("Number of observations : " ,len\(df\)\)  
 print\("Number of variables : ", len\(df.columns\)\)  
 #Number of observations :  284807  
 #Number of variables :  31

查看数据集信息:

 df.info\(\)  
 \<class 'pandas.core.frame.DataFrame'\>  
 RangeIndex: 284807 entries, 0 to 284806  
 Data columns \(total 31 columns\):  
  \#   Column  Non-Null Count   Dtype  
 \---  \------  \--------------   \-----    
  0   Time    284807 non\-null  float64  
  1   V1      284807 non\-null  float64  
  2   V2      284807 non\-null  float64  
  3   V3      284807 non\-null  float64  
  4   V4      284807 non\-null  float64  
  5   V5      284807 non\-null  float64  
  6   V6      284807 non\-null  float64  
  7   V7      284807 non\-null  float64  
  8   V8      284807 non\-null  float64  
  9   V9      284807 non\-null  float64  
  10  V10     284807 non\-null  float64  
  11  V11     284807 non\-null  float64  
  12  V12     284807 non\-null  float64  
  13  V13     284807 non\-null  float64  
  14  V14     284807 non\-null  float64  
  15  V15     284807 non\-null  float64  
  16  V16     284807 non\-null  float64  
  17  V17     284807 non\-null  float64  
  18  V18     284807 non\-null  float64  
  19  V19     284807 non\-null  float64  
  20  V20     284807 non\-null  float64  
  21  V21     284807 non\-null  float64  
  22  V22     284807 non\-null  float64  
  23  V23     284807 non\-null  float64  
  24  V24     284807 non\-null  float64  
  25  V25     284807 non\-null  float64  
  26  V26     284807 non\-null  float64  
  27  V27     284807 non\-null  float64  
  28  V28     284807 non\-null  float64  
  29  Amount  284807 non\-null  float64  
  30  Class   284807 non\-null  int64  
 dtypes: float64\(30\), int64\(1\)  
 memory usage: 67.4 MB

查看分类类别:

 f,ax\=plt.subplots\(1,2,figsize\=\(18,8\)\)  
 df\['Class'\].value\_counts\(\).plot.pie\(explode\=\[0,0.1\],autopct\='\%1.1f\%\%',ax\=ax\[0\],shadow\=True\)  
 ax\[0\].set\_title\('dağılım'\)  
 ax\[0\].set\_ylabel\(''\)  
 sns.countplot\('Class',data\=df,ax\=ax\[1\]\)  
 ax\[1\].set\_title\('Class'\)  
 plt.show\(\)
 rob\_scaler \= RobustScaler\(\)  
 df\['Amount'\] \= rob\_scaler.fit\_transform\(df\['Amount'\].values.reshape\(\-1,1\)\)  
 df\['Time'\] \= rob\_scaler.fit\_transform\(df\['Time'\].values.reshape\(\-1,1\)\)  
 df.head\(\)

创建基类模型:

 X \= df.drop\("Class", axis\=1\)  
 y \= df\["Class"\]  
 X\_train, X\_test, y\_train, y\_test \= train\_test\_split\(X, y, test\_size\=0.20, random\_state\=123456\)  
 model \= LogisticRegression\(random\_state\=123456\)  
 model.fit\(X\_train, y\_train\)  
 y\_pred \= model.predict\(X\_test\)  
 accuracy \= accuracy\_score\(y\_test, y\_pred\)  
 print\("Accuracy: \%.3f"\%\(accuracy\)\)

我们创建的模型的准确率评分为0.999。我们可以说我们的模型很完美吗?混淆矩阵是一个用来描述分类模型的真实值在测试数据上的性能的表。它包含4种不同的估计值和实际值的组合。

 def plot\_confusion\_matrix\(cm, classes,  
  title\='Confusion matrix',  
  cmap\=plt.cm.Blues\):  
   
  plt.rcParams.update\(\{'font.size': 19\}\)  
  plt.imshow\(cm, interpolation\='nearest', cmap\=cmap\)  
  plt.title\(title,fontdict\=\{'size':'16'\}\)  
  plt.colorbar\(\)  
  tick\_marks \= np.arange\(len\(classes\)\)  
  plt.xticks\(tick\_marks, classes, rotation\=45,fontsize\=12,color\="blue"\)  
  plt.yticks\(tick\_marks, classes,fontsize\=12,color\="blue"\)  
  rc\('font', weight\='bold'\)  
  fmt \= '.1f'  
  thresh \= cm.max\(\)  
  for i, j in itertools.product\(range\(cm.shape\[0\]\), range\(cm.shape\[1\]\)\):  
  plt.text\(j, i, format\(cm\[i, j\], fmt\),  
  horizontalalignment\="center",  
  color\="red"\)  
   
  plt.ylabel\('True label',fontdict\=\{'size':'16'\}\)  
  plt.xlabel\('Predicted label',fontdict\=\{'size':'16'\}\)  
  plt.tight\_layout\(\)  
   
 plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred\=y\_pred\), classes\=\['Non Fraud','Fraud'\],  
  title\='Confusion matrix'\)
  • 非欺诈类共进行了56875次预测,其中56870次(TP)正确,5次(FP)错误。
  • 欺诈类共进行了87次预测,其中31次(FN)错误,56次(TN)正确。

该模型可以预测欺诈状态,准确率为0.99。但当检查混淆矩阵时,欺诈类的错误预测率相当高。也就是说该模型正确地预测了非欺诈类的概率为0.99。但是非欺诈类的观测值的数量高于欺诈类的观测值的数量,这拉搞了我们对准确率的计算,并且我们更加关注的是欺诈类的准确率,所以我们需要一个指标来衡量它的性能。

选择正确的指标

在处理不平衡数据集时,选择正确的指标来评估模型的性能非常重要。传统指标,如准确性、精确度和召回率,可能不适用于不平衡的数据集,因为它们没有考虑数据中类别的分布。

经常用于不平衡数据集的一个指标是 F1 分数。F1 分数是精确率和召回率的调和平均值,它提供了两个指标之间的平衡。计算如下:

F1 = 2 * (precision * recall) / (precision + recall)

另一个经常用于不平衡数据集的指标是 AUC-ROC。AUC-ROC 衡量模型区分正类和负类的能力。它是通过绘制不同分类阈值下的TPR与FPR来计算的。AUC-ROC 值的范围从 0.5(随机猜测)到 1.0(完美分类)。

 print\(classification\_report\(y\_test, y\_pred\)\)  
   
  precision   recall   f1\-score   support  
   
  0       1.00      1.00      1.00     56875  
  1       0.92      0.64      0.76        87  
   
  accuracy                           1.00     56962  
  macro avg       0.96      0.82      0.88     56962  
 weighted avg       1.00      1.00      1.00     56962

返回对0(非欺诈)类的预测有多少是正确的。查看混淆矩阵,56870 + 31 = 56901个非欺诈类预测,其中56870个预测正确。0类的精度值接近1 (56870 / 56901)

返回对1 (欺诈)类的预测有多少是正确的。查看混淆矩阵,5 + 56 = 61个欺诈类别预测,其中56个被正确估计。0类的精度为0.92 (56 / 61),可以看到差别还是很大的

过采样

通过复制少数类样本来稳定数据集。

随机过采样:通过添加从少数群体中随机选择的样本来平衡数据集。如果数据集很小,可以使用这种技术。可能会导致过拟合。randomoverampler方法接受sampling_strategy参数,当sampling_strategy = ' minority '被调用时,它会增加minority类的数量,使其与majority类的数量相等。

我们可以在这个参数中输入一个浮点值。例如,假设我们的少数群体人数为1000人,多数群体人数为100人。如果我们说sampling_strategy = 0.5,少数类将被添加到500

 y\_train.value\_counts\(\)  
 0    227440  
 1       405  
 Name: Class, dtype: int64  
   
 from imblearn.over\_sampling import RandomOverSampler  
 oversample \= RandomOverSampler\(sampling\_strategy\='minority'\)  
 X\_randomover, y\_randomover \= oversample.fit\_resample\(X\_train, y\_train\)

采样后训练

 model.fit\(X\_randomover, y\_randomover\)  
 y\_pred \= model.predict\(X\_test\)  
   
 plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred\=y\_pred\), classes\=\['Non Fraud','Fraud'\],  
  title\='Confusion matrix'\)

应用随机过采样后,训练模型的精度值为0.97,出现了下降。但是从混淆矩阵来看,模型的欺诈类的正确估计率有所提高。

SMOTE 过采样:从少数群体中随机选取一个样本。然后,为这个样本找到k个最近的邻居。从k个最近的邻居中随机选取一个,将其与从少数类中随机选取的样本组合在特征空间中形成线段,形成合成样本。

 from imblearn.over\_sampling import SMOTE  
 oversample = SMOTE\(\)  
 X\_smote, y\_smote = oversample.fit\_resample\(X\_train, y\_train\)

使用SMOTE后的数据训练

 model.fit\(X\_smote, y\_smote\)  
 y\_pred = model.predict\(X\_test\)  
   
 accuracy = accuracy\_score\(y\_test, y\_pred\)  
 plot\_confusion\_matrix\(confusion\_matrix\(y\_test, y\_pred=y\_pred\), classes=\['Non Fraud','Fraud'\],  
  title='Confusion matrix'\)

可以看到与基线模型相比,欺诈的准确率有所提高,但是比随机过采样有所下降,这可能是数据集的原因,因为SMOTE采样会生成心的数据,所以并不适合所有的数据集。

总结

在这篇文章中,我们讨论了处理不平衡数据集和提高机器学习模型性能的各种技巧和策略。不平衡的数据集可能是机器学习中的一个常见问题,并可能导致在预测少数类时表现不佳。

本文介绍了一些可用于平衡数据集的重采样技术,如欠采样、过采样和SMOTE。还讨论了成本敏感学习和使用适当的性能指标,如AUC-ROC,这可以提供更好的模型性能指示。

处理不平衡的数据集是具有挑战性的,但通过遵循本文讨论的技巧和策略,可以建立有效的模型准确预测少数群体。重要的是要记住最佳方法将取决于特定的数据集和问题,为了获得最佳结果,可能需要结合各种技术。因此,试验不同的技术并使用适当的指标评估它们的性能是很重要的。

作者:Emine Bozkuş

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。 【技术文档】《从零搭建pytorch模型教程》122页PDF下载 QQ交流群:444129970。群内有大佬负责解答大家的日常学习、科研、代码问题。 模型部署交流群:732145323。用于计算机视觉方面的模型部署、高性能计算、优化加速、技术学习等方面的交流。 其它文章 U-Net在2022年相关研究的论文推荐 用少于256KB内存实现边缘训练,开销不到PyTorch千分之一 PyTorch 2.0 重磅发布:一行代码提速 30% Hinton 最新研究:神经网络的未来是前向-前向算法 聊聊计算机视觉入门 FRNet:上下文感知的特征强化模块 DAMO-YOLO | 超越所有YOLO,兼顾模型速度与精度 《医学图像分割》综述,详述六大类100多个算法 如何高效实现矩阵乘?万文长字带你从CUDA初学者的角度入门 近似乘法对卷积神经网络的影响 BT-Unet:医学图像分割的自监督学习框架 语义分割该如何走下去? 轻量级模型设计与部署总结 从CVPR22出发,聊聊CAM是如何激活我们文章的热度! 入门必读系列(十六)经典CNN设计演变的关键总结:从VGGNet到EfficientNet 入门必读系列(十五)神经网络不work的原因总结 入门必读系列(十四)CV论文常见英语单词总结 入门必读系列(十三)高效阅读论文的方法 入门必读系列(十二)池化各要点与各方法总结 TensorRT教程(三)TensorRT的安装教程 TensorRT教程(一)初次介绍TensorRT TensorRT教程(二)TensorRT进阶介绍 计算机视觉中的高效阅读论文的方法总结 计算机视觉中的神经网络可视化工具与项目 计算机视觉中的transformer模型创新思路总结 计算机视觉中的传统特征提取方法总结 计算机视觉中的数据预处理与模型训练技巧总结 计算机视觉中的图像标注工具总结 计算机视觉中的数据增强方法总结 计算机视觉中的注意力机制技术总结 计算机视觉中的特征金字塔技术总结 计算机视觉中的池化技术总结 计算机视觉中的高效阅读论文的方法总结 计算机视觉中的论文创新的常见思路总结 神经网络中的归一化方法总结 神经网络的初始化方法总结

 

标签:non,技巧,float64,模型,示例,建模,284807,平衡,null
From: https://www.cnblogs.com/wxkang/p/17084435.html

相关文章

  • 3D建模零代码平台
    近几年,随着国内外文化产业的迅猛发展,3D建模行业迎来黄金发展期。尤其是在元宇宙时代及数字体验经济时代的大背景下,越来越多的实时、可交互的3D内容将出现在人们的生活中。......
  • LabVIEW|小技巧1:秒破加密的vi文件方法
    可以利用专门的网站进行vi文件的解密,网站如下:​​https://www.hmilch.net/h/labview​​步骤:点击-选择按钮->空白框里输入“YES"->点击-提交;成功后下载解密的vi文件(注:此vi文......
  • ThreadPoolExecutor线程池参数设置技巧
    一、ThreadPoolExecutor的重要参数1、corePoolSize:核心线程数*核心线程会一直存活,及时没有任务需要执行*当线程数小于核心线程数时,即使有线程空闲,线程池也会优先创建......
  • 关于unity使用导入 .unitypackage报错的解决技巧
    在开发过程中,大家难免会下载网络demo用来参考,但是网络上的demo,有很大一部分demo在导入时是报错的,此时就需要修复它让它可以顺畅地运行起来。 在使用demo时,就遇到了一种报......
  • SourceForge文件无法下载问题的技巧(20230131有效)
    曾经多么辉煌的开源网站SourceForge如今还有多少人记得(不小心暴露了年纪),最近有找到早期的一个开源项目下载其中的文件发现只有SourceForge上有保存,于是到SourceForge上下载......
  • css技巧篇(一)
    虚线:css自带的dashed虚线样式非常的有限。往往是不满足UI设计稿的需求的,这里可以使用渐变的方式实现:/**使用渐变来自定义虚线*/background:linear-gra......
  • 小技巧[维护ing]
    记录一些平时遇见的问题,便于后续遇到相同问题时查看1powershell界面按上键不能显示上次的命令出现这种问题可能是因为命令记录的缓冲区已经满了,方法1:可以打开属性->选......
  • 数学建模学习——Day04
    一、灰色关联分析1.基本思想:根据序列曲线几何形状的相似程度来判断其联系是否紧密。曲线越接近,相应序列之间的关联度就越大,反之就越小。2.应用1)进行系统分析: ·1.画......
  • 基于MPPT的PV光伏发电simulink建模和仿真
    1.算法描述       MPPT控制器的全称是“最大功率点跟踪”(MaximumPowerPointTracking)太阳能控制器,是传统太阳能充放电控制器的升级换代产品。MPPT控制器能够实时......
  • 3D场景建模零代码平台
    3D场景建模软件(零基础、零代码、**),是指用来制作场景的软件,分为2D建模和3D建模,二者使用的技术及原理不同。2D软件:它是用3维几何图形绘制出三维图形的软件,其主要功能是利用......