欢迎收藏Star我的Machine Learning Blog:https://github.com/purepisces/Wenqing-Machine_Learning_Blog。如果收藏star, 有问题可以随时与我交流, 谢谢大家!
处理类别不平衡
在欺诈检测、点击预测或垃圾邮件检测等机器学习用例中,通常会遇到标签不平衡的问题。根据具体用例,可以使用以下几种策略来处理不平衡。
在损失函数中使用类别权重
例如,在垃圾邮件检测问题中,非垃圾邮件数据占95%,而垃圾邮件数据仅占5%。我们希望对非垃圾邮件类别给予更高的惩罚。在这种情况下,可以通过权重修改熵损失函数。
// w0是类别0的权重,w1是类别1的权重
loss_function = -w0 * ylog(p) - w1*(1-y)*log(1-p)
情况1:正确分类(垃圾邮件)
- 真实标签: 垃圾邮件 (y = 1)
- 预测概率: p = 0.9 (对垃圾邮件的高置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是垃圾邮件且模型预测为垃圾邮件,这是正确分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1⋅y⋅log(p)−w0⋅(1−y)⋅log(1−p)
loss = − 20 ⋅ 1 ⋅ log ( 0.9 ) − 1.05 ⋅ 0 ⋅ log ( 0.1 ) \text{loss} = - 20 \cdot 1 \cdot \log(0.9) - 1.05 \cdot 0 \cdot \log(0.1) loss=−20⋅1⋅log(0.9)−1.05⋅0⋅log(0.1)
loss = − 20 ⋅ log ( 0.9 ) \text{loss} = - 20 \cdot \log(0.9) loss=−20⋅log(0.9)
loss ≈ − 20 ⋅ ( − 0.105 ) \text{loss} \approx - 20 \cdot (-0.105) loss≈−20⋅(−0.105)
loss ≈ 2.1 \text{loss} \approx 2.1 loss≈2.1
情况2:错误分类(垃圾邮件)
- 真实标签: 垃圾邮件 (y = 1)
- 预测概率: p = 0.3 (对垃圾邮件的低置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是垃圾邮件但模型预测为非垃圾邮件,这是错误分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1⋅y⋅log(p)−w0⋅(1−y)⋅log(1−p)
loss = − 20 ⋅ 1 ⋅ log ( 0.3 ) − 1.05 ⋅ 0 ⋅ log ( 0.7 ) \text{loss} = - 20 \cdot 1 \cdot \log(0.3) - 1.05 \cdot 0 \cdot \log(0.7) loss=−20⋅1⋅log(0.3)−1.05⋅0⋅log(0.7)
loss = − 20 ⋅ log ( 0.3 ) \text{loss} = - 20 \cdot \log(0.3) loss=−20⋅log(0.3)
loss ≈ − 20 ⋅ ( − 0.523 ) \text{loss} \approx - 20 \cdot (-0.523) loss≈−20⋅(−0.523)
loss ≈ 10.46 \text{loss} \approx 10.46 loss≈10.46
情况3:正确分类(非垃圾邮件)
- 真实标签: 非垃圾邮件 (y = 0)
- 预测概率: p = 0.1 (对垃圾邮件的低置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是非垃圾邮件且模型预测为非垃圾邮件,这是正确分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1⋅y⋅log(p)−w0⋅(1−y)⋅log(1−p)
loss = − 20 ⋅ 0 ⋅ log ( 0.1 ) − 1.05 ⋅ 1 ⋅ log ( 0.9 ) \text{loss} = - 20 \cdot 0 \cdot \log(0.1) - 1.05 \cdot 1 \cdot \log(0.9) loss=−20⋅0⋅log(0.1)−1.05⋅1⋅log(0.9)
loss = − 1.05 ⋅ log ( 0.9 ) \text{loss} = - 1.05 \cdot \log(0.9) loss=−1.05⋅log(0.9)
loss ≈ − 1.05 ⋅ ( − 0.105 ) \text{loss} \approx - 1.05 \cdot (-0.105) loss≈−1.05⋅(−0.105)
loss ≈ 0.11 \text{loss} \approx 0.11 loss≈0.11
情况4:错误分类(非垃圾邮件)
- 真实标签: 非垃圾邮件 (y = 0)
- 预测概率: p = 0.8 (对垃圾邮件的高置信度)
- 类别权重: w_0 = 1.05 (非垃圾邮件), w_1 = 20 (垃圾邮件)
由于真实标签是非垃圾邮件但模型预测为垃圾邮件,这是错误分类。
loss = − w 1 ⋅ y ⋅ log ( p ) − w 0 ⋅ ( 1 − y ) ⋅ log ( 1 − p ) \text{loss} = - w_1 \cdot y \cdot \log(p) - w_0 \cdot (1 - y) \cdot \log(1 - p) loss=−w1⋅y⋅log(p)−w0⋅(1−y)⋅log(1−p)
loss = − 20 ⋅ 0 ⋅ log ( 0.8 ) − 1.05 ⋅ 1 ⋅ log ( 0.2 ) \text{loss} = - 20 \cdot 0 \cdot \log(0.8) - 1.05 \cdot 1 \cdot \log(0.2) loss=−20⋅0⋅log(0.8)−1.05⋅1⋅log(0.2)
loss = − 1.05 ⋅ log ( 0.2 ) \text{loss} = - 1.05 \cdot \log(0.2) loss=−1.05⋅log(0.2)
loss ≈ − 1.05 ⋅ ( − 0.699 ) \text{loss} \approx - 1.05 \cdot (-0.699) loss≈−1.05⋅(−0.699)
loss ≈ 0.73 \text{loss} \approx 0.73 loss≈0.73
总结
在这些例子中,可以看到:
- 当模型正确分类邮件时,损失相对较低。
- 当模型错误分类邮件时,损失较高,尤其是垃圾邮件类别由于更高的权重。
在训练期间,错误分类的高损失促使模型调整其参数,以减少未来迭代中的高损失,从而提高整体性能,特别是对少数类别(垃圾邮件)。当错误分类垃圾邮件的损失更高时,这会提示模型在训练期间更加关注正确预测垃圾邮件。通过在数据集不平衡时平衡每个类别的影响,模型倾向于偏向多数类(在这种情况下是非垃圾邮件),因为它看到更多该类的例子。通过赋予少数类(垃圾邮件)更高的权重,我们确保对少数类的错误具有更大的影响,从而鼓励模型更多地关注正确分类少数类。提高少数类的召回率:在垃圾邮件检测等场景中,正确识别垃圾邮件(真阳性)通常比正确识别非垃圾邮件(真阴性)更为重要。通过惩罚模型将垃圾邮件错误分类为非垃圾邮件,可以提高垃圾邮件类的召回率,减少假阴性数量。
在垃圾邮件检测的上下文中:
真阳性(TP):实际是垃圾邮件且模型正确识别为垃圾邮件的邮件。真阴性(TN):实际不是垃圾邮件(非垃圾邮件)且模型正确识别为非垃圾邮件的邮件。假阳性(FP):实际不是垃圾邮件(非垃圾邮件)但模型错误识别为垃圾邮件的邮件。假阴性(FN):实际是垃圾邮件但模型错误识别为非垃圾邮件的邮件。
使用简单重采样
以一定比例重采样非垃圾邮件类别,以减少训练集中的不平衡。保持验证数据和测试数据不变(不重采样)非常重要。
简单重采样的类型
少数类的过采样:这涉及到复制少数类的例子,以增加其在训练数据集中的频率。
多数类的欠采样:这涉及到随机移除多数类的例子,以减少其在训练数据集中的频率。
重要考虑因素
验证和测试数据的完整性:保持验证和测试数据完整,不进行任何重采样,以确保性能指标反映模型对未见数据的泛化能力。重采样只应应用于训练数据。
过拟合风险(在过采样中):过采样可能导致过拟合,因为模型可能会记住少数类的重复实例。
信息丢失(在欠采样中):欠采样可能会导致多数类的有价值信息丢失,这可能会降低模型的性能。
import pandas as pd
from sklearn.utils import resample
# 创建一个示例数据集
data = {
'feature1': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
'feature2': [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
'label': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
}
# 将字典转换为pandas DataFrame
dataset = pd.DataFrame(data)
# 分离多数类和少数类
non_spam = dataset[dataset['label'] == 0]
spam = dataset[dataset['label'] == 1]
# 少数类的过采样
spam_oversampled = resample(spam, replace=True, n_samples=len(non_spam), random_state=42)
# 将多数类与过采样的少数类合并
oversampled_dataset = pd.concat([non_spam, spam_oversampled])
# 验证过采样后的新分布
print("过采样数据集类别分布:")
print(oversampled_dataset['label'].value_counts())
print(oversampled_dataset)
# 多数类的欠采样
non_spam_undersampled = resample(non_spam, replace=False, n_samples=len(spam), random_state=42)
# 将欠采样的多数类与少数类合并
undersampled_dataset = pd.concat([non_spam_undersampled, spam])
# 验证欠采样后的新分布
print("\n欠采样数据集类别分布:")
print(undersampled_dataset['label'].value_counts())
print(undersampled_dataset)
打印结果:
过采样数据集类别分布:
label
0 14
1 14
Name: count, dtype: int64
feature1 feature2 label
0 1 2 0
1 2 3 0
2 3 4 0
3 4 5 0
4 5 6 0
5 6 7 0
6 7 8 0
7 8 9 0
8 9 10 0
9 10 11 0
10 11 12 0
11 12 13 0
12 13 14 0
13 14 15 0
17 18 19 1
18 19 20 1
16 17 18 1
18 19 20 1
18 19 20 1
15 16 17 1
16 17 18 1
16 17 18 1
16 17 18 1
18 19 20 1
17 18 19 1
16 17 18 1
19 20 21 1
18 19 20 1
欠采样数据集类别分布:
label
0 6
1 6
Name: count, dtype: int64
feature1 feature2 label
9 10 11 0
11 12 13 0
0 1 2 0
12 13 14 0
5 6 7 0
8 9 10 0
14 15 16 1
15 16 17 1
16 17 18 1
17 18 19 1
18 19 20 1
19 20 21 1
使用合成重采样
合成少数过采样技术(SMOTE)包括基于现有少数类元素合成新元素。其工作原理是从少数类中随机选择一个点,并为该点计算k近邻。合成点添加在选择点及其邻居之间。由于实际原因,SMOTE不像其他方法那样广泛使用。
参考资料:
- Educative上的机器学习系统设计