我正在研究语言任务的多类分类(4 类),并使用 BERT 模型进行分类任务。我正在关注这篇博文
NLP 的迁移学习:微调 BERT 用于文本分类
我的 BERT 微调模型返回
nn.LogSoftmax(dim=1)
我的数据非常不平衡,所以我使用了| ||计算类别的权重并使用损失中的权重。
sklearn.utils.class_weight.compute_class_weight
我的结果不太好,所以我想到用
class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float)
cross_entropy = nn.NLLLoss(weight=weights)
进行实验,并有一个焦点损失的代码。
Focal Loss
我现在有 3 个问题。首先也是最重要的是
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
BCE_loss = nn.CrossEntropyLoss()(inputs, targets)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
我应该使用带有焦点损失的类权重吗?
- 如果我必须在这个
-
中实现权重,我可以在
Focal Loss
里面使用weights
参数nn.CrossEntropyLoss()
- 如果这个实现不正确,包括权重在内的正确代码应该是什么(如果可能)
当然,让我们分解如何在 PyTorch 中使用类权重和焦点损失来处理不平衡的多类分类问题。
了解问题
当某些类在数据集中比其他类出现的频率更高时,就会出现类不平衡。这可能会导致模型偏向多数类,在少数类上的表现不佳。类权重和焦点损失是解决此问题并在类之间实现更公平训练的两种技术。
1. 类权重
类权重为每个类分配一个权重,与该类的频率成反比。这意味着在计算损失时,少数类获得更高的权重,这有助于平衡多数类的影响。
已经在使用
sklearn.utils.class_weight.compute_class_weight
正确计算类权重。在 PyTorch 中,通常将其与
nn.CrossEntropyLoss
一起使用:
class_weights = compute_class_weight('balanced', np.unique(train_labels), train_labels)
weights= torch.tensor(class_weights,dtype=torch.float).to(device) # 将权重发送到设备
cross_entropy = nn.CrossEntropyLoss(weight=weights)
请确保的
device
与的输入数据和模型(CPU 或 GPU)相同。
2. 焦点损失
焦点损失(Lin 等人,2017 年)通过降低易分类样本(即模型有信心地正确分类的样本)的损失来扩展交叉熵损失。它侧重于对模型具有挑战性的样本。
下面是 PyTorch 中焦点损失的一种常见实现:
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduce = reduce
def forward(self, inputs, targets):
ce_loss = nn.CrossEntropyLoss(reduction='none')(inputs, targets)
pt = torch.exp(-ce_loss) # 获取每个样本的概率
F_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
组合类权重和焦点损失
是的,可以组合类权重和焦点损失来解决类不平衡问题。实际上,这可能非常有效。可以修改焦点损失实现以包含类权重:
class FocalLossWithWeights(nn.Module):
def __init__(self, alpha=1, gamma=2, weights=None, reduce=True):
super(FocalLossWithWeights, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.weights = weights
self.reduce = reduce
def forward(self, inputs, targets):
ce_loss = nn.CrossEntropyLoss(weight=self.weights, reduction='none')(inputs, targets)
pt = torch.exp(-ce_loss)
F_loss = self.alpha * (1 - pt)**self.gamma * ce_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
使用修改后的焦点损失
# ...(的代码用于数据加载和模型定义)
# 假设已经计算了 class_weights
focal_loss_with_weights = FocalLossWithWeights(alpha=1, gamma=2, weights=weights.to(device))
# 在的训练循环中:
# ...
logits = model(input_ids, attention_mask)
loss = focal_loss_with_weights(logits, labels)
# ...
解释
-
在组合的实现中,我们将
nn.CrossEntropyLoss(reduction='none')
用于逐个样本损失,而不会聚合。 -
然后,我们将类权重直接应用于
nn.CrossEntropyLoss
中的逐个样本损失。 -
alpha
和gamma
参数控制焦点损失行为。alpha
调节类权重,而gamma
控制对错误分类样本的关注程度。
通过试验不同的
alpha
和
gamma
值以及类权重计算方法(例如,
'balanced'
或自定义权重),找到适合的特定数据集和问题的最佳组合非常重要。