首页 > 其他分享 >深度学习笔记: 详解处理类别不平衡

深度学习笔记: 详解处理类别不平衡

时间:2024-05-29 11:29:31浏览次数:24  
标签:loss 20 log cdot 笔记 垃圾邮件 类别 1.05 详解

欢迎收藏Star我的Machine Learning Blog:https://github.com/purepisces/Wenqing-Machine_Learning_Blog。如果收藏star, 有问题可以随时与我交流, 谢谢大家!

处理类别不平衡

在欺诈检测、点击预测或垃圾邮件检测等机器学习用例中,通常会遇到标签不平衡的问题。根据具体用例,可以使用以下几种策略来处理不平衡。

在损失函数中使用类别权重

例如,在垃圾邮件检测问题中,非垃圾邮件数据占95%,而垃圾邮件数据仅占5%。我们希望对非垃圾邮件类别给予更高的惩罚。在这种情况下,可以通过权重修改熵损失函数。

// w0是类别0的权重,w1是类别1的权重
loss_function = -w0 * ylog(p) - w1*(1-y)*log(1-p)

情况1:正确分类(垃圾邮件)

  • 真实标签: 垃圾邮件 (y = 1)
  • 预测概率: p = 0.9 (对垃圾邮件的高置信度)
  • 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)

由于真实标签是垃圾邮件且模型预测为垃圾邮件,这是正确分类。

loss = − w 1 ⋅ y ⋅ log ⁡ ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ⁡ ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1​⋅y⋅log(p)−w0​⋅(1−y)⋅log(1−p)

loss = − 20 ⋅ 1 ⋅ log ⁡ ( 0.9 ) − 1.05 ⋅ 0 ⋅ log ⁡ ( 0.1 ) \text{loss} = - 20 \cdot 1 \cdot \log(0.9) - 1.05 \cdot 0 \cdot \log(0.1) loss=−20⋅1⋅log(0.9)−1.05⋅0⋅log(0.1)

loss = − 20 ⋅ log ⁡ ( 0.9 ) \text{loss} = - 20 \cdot \log(0.9) loss=−20⋅log(0.9)

loss ≈ − 20 ⋅ ( − 0.105 ) \text{loss} \approx - 20 \cdot (-0.105) loss≈−20⋅(−0.105)

loss ≈ 2.1 \text{loss} \approx 2.1 loss≈2.1

情况2:错误分类(垃圾邮件)

  • 真实标签: 垃圾邮件 (y = 1)
  • 预测概率: p = 0.3 (对垃圾邮件的低置信度)
  • 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)

由于真实标签是垃圾邮件但模型预测为非垃圾邮件,这是错误分类。

loss = − w 1 ⋅ y ⋅ log ⁡ ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ⁡ ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1​⋅y⋅log(p)−w0​⋅(1−y)⋅log(1−p)

loss = − 20 ⋅ 1 ⋅ log ⁡ ( 0.3 ) − 1.05 ⋅ 0 ⋅ log ⁡ ( 0.7 ) \text{loss} = - 20 \cdot 1 \cdot \log(0.3) - 1.05 \cdot 0 \cdot \log(0.7) loss=−20⋅1⋅log(0.3)−1.05⋅0⋅log(0.7)

loss = − 20 ⋅ log ⁡ ( 0.3 ) \text{loss} = - 20 \cdot \log(0.3) loss=−20⋅log(0.3)

loss ≈ − 20 ⋅ ( − 0.523 ) \text{loss} \approx - 20 \cdot (-0.523) loss≈−20⋅(−0.523)

loss ≈ 10.46 \text{loss} \approx 10.46 loss≈10.46

情况3:正确分类(非垃圾邮件)

  • 真实标签: 非垃圾邮件 (y = 0)
  • 预测概率: p = 0.1 (对垃圾邮件的低置信度)
  • 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)

由于真实标签是非垃圾邮件且模型预测为非垃圾邮件,这是正确分类。

loss = − w 1 ⋅ y ⋅ log ⁡ ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ⁡ ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1​⋅y⋅log(p)−w0​⋅(1−y)⋅log(1−p)

loss = − 20 ⋅ 0 ⋅ log ⁡ ( 0.1 ) − 1.05 ⋅ 1 ⋅ log ⁡ ( 0.9 ) \text{loss} = - 20 \cdot 0 \cdot \log(0.1) - 1.05 \cdot 1 \cdot \log(0.9) loss=−20⋅0⋅log(0.1)−1.05⋅1⋅log(0.9)

loss = − 1.05 ⋅ log ⁡ ( 0.9 ) \text{loss} = - 1.05 \cdot \log(0.9) loss=−1.05⋅log(0.9)

loss ≈ − 1.05 ⋅ ( − 0.105 ) \text{loss} \approx - 1.05 \cdot (-0.105) loss≈−1.05⋅(−0.105)

loss ≈ 0.11 \text{loss} \approx 0.11 loss≈0.11

情况4:错误分类(非垃圾邮件)

  • 真实标签: 非垃圾邮件 (y = 0)
  • 预测概率: p = 0.8 (对垃圾邮件的高置信度)
  • 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)

由于真实标签是非垃圾邮件但模型预测为垃圾邮件,这是错误分类。

loss = − w 1 ⋅ y ⋅ log ⁡ ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ⁡ ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1​⋅y⋅log(p)−w0​⋅(1−y)⋅log(1−p)

loss = − 20 ⋅ 0 ⋅ log ⁡ ( 0.8 ) − 1.05 ⋅ 1 ⋅ log ⁡ ( 0.2 ) \text{loss} = - 20 \cdot 0 \cdot \log(0.8) - 1.05 \cdot 1 \cdot \log(0.2) loss=−20⋅0⋅log(0.8)−1.05⋅1⋅log(0.2)

loss = − 1.05 ⋅ log ⁡ ( 0.2 ) \text{loss} = - 1.05 \cdot \log(0.2) loss=−1.05⋅log(0.2)

loss ≈ − 1.05 ⋅ ( − 0.699 ) \text{loss} \approx - 1.05 \cdot (-0.699) loss≈−1.05⋅(−0.699)

loss ≈ 0.73 \text{loss} \approx 0.73 loss≈0.73

总结

在这些例子中,可以看到:

  • 当模型正确分类邮件时,损失相对较低。
  • 当模型错误分类邮件时,损失较高,尤其是垃圾邮件类别由于更高的权重。

在训练期间,错误分类的高损失促使模型调整其参数,以减少未来迭代中的高损失,从而提高整体性能,特别是对少数类别(垃圾邮件)。当错误分类垃圾邮件的损失更高时,这会提示模型在训练期间更加关注正确预测垃圾邮件。通过在数据集不平衡时平衡每个类别的影响,模型倾向于偏向多数类(在这种情况下是非垃圾邮件),因为它看到更多该类的例子。通过赋予少数类(垃圾邮件)更高的权重,我们确保对少数类的错误具有更大的影响,从而鼓励模型更多地关注正确分类少数类。提高少数类的召回率:在垃圾邮件检测等场景中,正确识别垃圾邮件(真阳性)通常比正确识别非垃圾邮件(真阴性)更为重要。通过惩罚模型将垃圾邮件错误分类为非垃圾邮件,可以提高垃圾邮件类的召回率,减少假阴性数量。

在垃圾邮件检测的上下文中:

真阳性(TP):实际是垃圾邮件且模型正确识别为垃圾邮件的邮件。真阴性(TN):实际不是垃圾邮件(非垃圾邮件)且模型正确识别为非垃圾邮件的邮件。假阳性(FP):实际不是垃圾邮件(非垃圾邮件)但模型错误识别为垃圾邮件的邮件。假阴性(FN):实际是垃圾邮件但模型错误识别为非垃圾邮件的邮件。

使用简单重采样

以一定比例重采样非垃圾邮件类别,以减少训练集中的不平衡。保持验证数据和测试数据不变(不重采样)非常重要。

简单重采样的类型

少数类的过采样:这涉及到复制少数类的例子,以增加其在训练数据集中的频率。

多数类的欠采样:这涉及到随机移除多数类的例子,以减少其在训练数据集中的频率。

重要考虑因素

验证和测试数据的完整性:保持验证和测试数据完整,不进行任何重采样,以确保性能指标反映模型对未见数据的泛化能力。重采样只应应用于训练数据。

过拟合风险(在过采样中):过采样可能导致过拟合,因为模型可能会记住少数类的重复实例。

信息丢失(在欠采样中):欠采样可能会导致多数类的有价值信息丢失,这可能会降低模型的性能。

import pandas as pd
from sklearn.utils import resample
# 创建一个示例数据集
data = {
    'feature1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
    'feature2': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
    'label': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
}

# 将字典转换为pandas DataFrame
dataset = pd.DataFrame(data)

# 分离多数类和少数类
non_spam = dataset[dataset['label'] == 0]
spam = dataset[dataset['label'] == 1]

# 少数类的过采样
spam_oversampled = resample(spam, replace=True, n_samples=len(non_spam), random_state=42)

# 将多数类与过采样的少数类合并
oversampled_dataset = pd.concat([non_spam, spam_oversampled])

# 验证过采样后的新分布
print("过采样数据集类别分布:")
print(oversampled_dataset['label'].value_counts())
print(oversampled_dataset)

# 多数类的欠采样
non_spam_undersampled = resample(non_spam, replace=False, n_samples=len(spam), random_state=42)

# 将欠采样的多数类与少数类合并
undersampled_dataset = pd.concat([non_spam_undersampled, spam])

# 验证欠采样后的新分布
print("\n欠采样数据集类别分布:")
print(undersampled_dataset['label'].value_counts())
print(undersampled_dataset)

打印结果:

过采样数据集类别分布:
label
0    14
1    14
Name: count, dtype: int64
    feature1  feature2  label
0          1         2      0
1          2         3      0
2          3         4      0
3          4         5      0
4          5         6      0
5          6         7      0
6          7         8      0
7          8         9      0
8          9        10      0
9         10        11      0
10        11        12      0
11        12        13      0
12        13        14      0
13        14        15      0
17        18        19      1
18        19        20      1
16        17        18      1
18        19        20      1
18        19        20      1
15        16        17      1
16        17        18      1
16        17        18      1
16        17        18      1
18        19        20      1
17        18        19      1
16        17        18      1
19        20        21      1
18        19        20      1

欠采样数据集类别分布:
label
0    6
1    6
Name: count, dtype: int64
    feature1  feature2  label
9         10        11      0
11        12        13      0
0          1         2      0
12        13        14      0
5          6         7      0
8          9        10      0
14        15        16      1
15        16        17      1
16        17        18      1
17        18        19      1
18        19        20      1
19        20        21      1

使用合成重采样

合成少数过采样技术(SMOTE)包括基于现有少数类元素合成新元素。其工作原理是从少数类中随机选择一个点,并为该点计算k近邻。合成点添加在选择点及其邻居之间。由于实际原因,SMOTE不像其他方法那样广泛使用。
在这里插入图片描述

参考资料:

  • Educative上的机器学习系统设计

标签:loss,20,log,cdot,笔记,垃圾邮件,类别,1.05,详解
From: https://blog.csdn.net/weixin_53765658/article/details/139290725

相关文章

  • SpringBoot如何使用日志Logback,及日志等级详解
    SpringBoot默认已经集成了SLF4J(SimpleLoggingFacadeforJava)作为日志的接口,以及Logback作为日志的实现。这意味着在大多数情况下,你无需做额外的配置即可开始记录日志。下面是一个简要的指南,包括如何在SpringBoot应用中使用SLF4J和Logback,以及一些实际的代码示例。默......
  • VUE学习笔记(十一)-登录和状态管理
    登录和状态管理src/auth/views/UserLogin.vue<template><divclass="login"><divclass="body"><divclass="container"><h2>用户登陆</h2><el-......
  • VUE学习笔记(十二)-axios拦截器的配置
    axios拦截器的配置src/api/api_config.jsimportaxiosfrom"axios";import{getToken}from"@/auth/auth.service";import{ElMessage}from'element-plus'axios.defaults.baseURL="http://localhost:8080/api";axios.defa......
  • VUE学习笔记(十三)-token过期时间处理
    token过期时间处理添加jwt指令yarnaddjsonwebtoken或者npminstalljsonwebtoken-Syarnaddnode-polyfill-webpack-pluginsrc/auth/auth.service.jsimportaxiosfrom"@/api/api_config"importrouterfrom'@/router'import*asjwtfrom'jsonwe......
  • VUE学习笔记(十四)-调整axios拦截器
    调整axios拦截器src/api/api_config.jsimportaxiosfrom"axios";import{getToken}from"@/auth/auth.service";import{ElMessage}from'element-plus'axios.defaults.baseURL="http://localhost:8080/api";axios.defau......
  • VUE学习笔记(十五)-退出功能
    退出功能src/views/LayoutView.vue<template><el-containerclass="layout-container-demo"><el-asidewidth="200px"><el-scrollbar><divclass="mb-2logo">Vue+WEBAPI</div>......
  • VUE学习笔记(八)
    登录页设计src下新建auth文件夹,新建auth.service.js,auth文件夹下新建views文件夹,view文件夹下新建UserLogin.vueUserLogin.vue<template><divclass="login"><divclass="body"><divclass="container">......
  • VUE学习笔记(九)
    登录数据数据验证,学习elementplus组件种页面数据验证UserLogin.vue页面<template><divclass="login"><divclass="body"><divclass="container"><h2>用户登陆</h2>......
  • JavaScript中的let关键字详解
    在JavaScript中,let关键字用于声明局部变量,它与传统的var关键字类似,但引入了几个关键的区别和改进,主要体现在作用域规则、重复声明限制以及引入了“暂时性死区”等概念。下面将详细介绍let的特点及其与var的不同之处。块级作用域vs函数作用域var声明的变量:其作用域是函数......
  • 《第二节》一、FreeRTOS学习笔记-任务创建和删除
    FreeRTOS的任务创建和删除1,任务创建和删除的API函数(熟悉)任务的创建和删除本质就是调用FreeRTOS的API函数一、任务创建动态创建任务:任务的任务控制块以及任务的栈空间所需的内存,均由FreeRTOS从FreeRTOS管理的堆中分配静态创建任务:任务的任务控制块以及任务的栈空间所需......