首页 > 其他分享 >Focal Loss详解及其pytorch实现

Focal Loss详解及其pytorch实现

时间:2024-08-20 09:22:25浏览次数:16  
标签:Loss log 0.25 样本 pytorch gamma alpha hat Focal

Focal Loss详解及其pytorch实现



文章目录


引言

Focal Loss是由何恺明等人在2017年的论文《Focal Loss for Dense Object Detection》中提出的。它是一种专门为解决目标检测中类别不平衡和难易样本不平衡问题而设计的损失函数。本文将详细介绍Focal Loss的基本概念、二分类和多分类的交叉熵损失函数,以及如何设置Focal Loss中的关键参数,并提供PyTorch的实现代码。

二分类与多分类的交叉熵损失函数

二分类交叉熵损失

在二分类的任务中,一般使用Sigmoid作为最后的激活函数,输出代表样本为正的概率值 y ^ \hat{y} y^​,二分类非正即负,所以样本为负的概率值为 1 − y ^ 1-\hat{y} 1−y^​。二分类交叉熵损失的计算公式为:

CEL = − y ⋅ log ⁡ ( y ^ ) − ( 1 − y ) ⋅ log ⁡ ( 1 − y ^ ) \text{CEL} = -y \cdot \log(\hat{y}) - (1-y) \cdot \log(1-\hat{y}) CEL=−y⋅log(y^​)−(1−y)⋅log(1−y^​)

其中 y y y 是实际标签,正样本为1,负样本为0, y ^ \hat{y} y^​ 是Sigmoid激活函数的输出值。

多分类交叉熵损失

在多分类的情况下,一般使用Softmax作为最后的激活函数,输出有多个值,对应每个分类的概率值,且这些值之和为1。多分类交叉熵损失的计算公式为:

CEL = − ∑ c = 1 C y c ⋅ log ⁡ ( y ^ c ) = − log ⁡ ( y ^ c ) \text{CEL} = -\sum_{c=1}^{C} y_c \cdot \log(\hat{y}_c) = -\log(\hat{y}_c) CEL=−c=1∑C​yc​⋅log(y^​c​)=−log(y^​c​)

其中 y ^ c \hat{y}_c y^​c​ 表示Softmax激活函数输出结果中第 c c c 类的对应的值, C C C 是类别的总数。

Focal Loss基础概念

关键点理解

要真正理解Focal Loss,有三个关键点需要明确:

  1. 二分类(Sigmoid)和多分类(Softmax)的交叉熵损失表达形式的区别
  2. 理解难分类样本与易分类样本
  3. Focal Loss中的超参数 α \alpha α 和 γ \gamma γ 的作用

什么是难分类样本和易分类样本?

  • 易分类样本:模型预测正确的概率较高,即 y ^ t \hat{y}_t y^​t​ 较大(通常 y ^ t > 0.5 \hat{y}_t > 0.5 y^​t​>0.5)。
  • 难分类样本:模型预测正确的概率较低,即 y ^ t \hat{y}_t y^​t​ 较小(通常 y ^ t < 0.5 \hat{y}_t < 0.5 y^​t​<0.5)。

其中 y ^ t \hat{y}_t y^​t​ 定义为:
y ^ t = { y ^ if  y = 1 1 − y ^ otherwise \hat{y}_t = \begin{cases} \hat{y} & \text{if } y = 1 \\ 1 - \hat{y} & \text{otherwise} \end{cases} y^​t​={y^​1−y^​​if y=1otherwise​

超参数 γ \gamma γ 的作用

超参数 γ \gamma γ 控制了难分类样本和易分类样本在损失函数中的比重。Focal Loss相对于原始的交叉熵损失增加了 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^​t​)γ 这一项,对原始交叉熵损失进行了衰减。当 γ \gamma γ 增大时,对易分类样本的损失衰减更加明显,从而使模型更加关注难分类样本。

超参数 α \alpha α 的作用

超参数 α \alpha α 用于调整正负样本之间的权重。在二分类中, α \alpha α 的值反映了样本数量较少的类的权重。通常情况下,正样本数量较少(在本文中正样本代表数量少的样本),因此 α \alpha α 值反映了正样本的权重。随着 γ \gamma γ 的增加, α \alpha α 应该稍微降低。这是因为:

  • 低 α \alpha α 对应高 γ \gamma γ。负样本通常容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^​t​)γ 显著降低,因此无需给正样本再增加额外过大的权重 α \alpha α。
  • 在Focal Loss中, γ \gamma γ 占主要地位,它确保了模型更加关注那些难以正确分类的样本。
  • 当处理负样本时, α \alpha α 的值通常为 1 − α 1 - \alpha 1−α,其中 α \alpha α 为正样本的权重。

超参数 α \alpha α 的详细解释

在Focal Loss中, α \alpha α 的作用是调整正负样本之间的权重。理论上,数量越少的类应该具有更大的权重。然而,在原论文作者的实验中,当 α = 0.25 \alpha = 0.25 α=0.25 和 γ = 2 \gamma = 2 γ=2 时,模型表现最好。这引发了一个问题:为什么正样本的权重( α = 0.25 \alpha = 0.25 α=0.25)反而比负样本的权重( 1 − α = 0.75 1 - \alpha = 0.75 1−α=0.75)要低,尤其是当负样本的数量远远多于正样本时?

这是因为Focal Loss的设计初衷是为了减少易分类样本的贡献,让模型更加关注难分类样本。随着 γ \gamma γ 的增加,难分类样本的权重实际上已经被显著提高了。此外,由于负样本通常更容易被正确分类,其权重已经被 ( 1 − y ^ t ) γ (1 - \hat{y}_t)^\gamma (1−y^​t​)γ 大幅降低,因此不需要再额外增加正样本的权重。这意味着,在Focal Loss中, γ \gamma γ 的作用更为关键,而 α \alpha α 的作用则相对次要。

实例计算

假设我们有一个正样本,模型预测的概率为0.8,取 γ = 2 \gamma = 2 γ=2。

  1. 计算 y ^ t \hat{y}_t y^​t​:
    y ^ t = y ^ = 0.8 \hat{y}_t = \hat{y} = 0.8 y^​t​=y^​=0.8

  2. 计算Focal Loss:
    FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ⁡ ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^​t​)=−αt​⋅(1−0.8)2⋅log(0.8)

若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 0.25 \alpha_t = 0.25 αt​=0.25,因此:
FL ( y ^ t ) = − 0.25 ⋅ ( 0.2 ) 2 ⋅ log ⁡ ( 0.8 ) ≈ − 0.25 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00223 \text{FL}(\hat{y}_t) = -0.25 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.25 \cdot 0.04 \cdot (-0.22314) \approx 0.00223 FL(y^​t​)=−0.25⋅(0.2)2⋅log(0.8)≈−0.25⋅0.04⋅(−0.22314)≈0.00223

负样本实例

假设我们有一个负样本,模型预测的概率为0.2,取 γ = 2 \gamma = 2 γ=2。

  1. 计算 y ^ t \hat{y}_t y^​t​:
    y ^ t = 1 − y ^ = 1 − 0.2 = 0.8 \hat{y}_t = 1 - \hat{y} = 1 - 0.2 = 0.8 y^​t​=1−y^​=1−0.2=0.8

  2. 计算Focal Loss:
    FL ( y ^ t ) = − α t ⋅ ( 1 − 0.8 ) 2 ⋅ log ⁡ ( 0.8 ) \text{FL}(\hat{y}_t) = -\alpha_t \cdot (1 - 0.8)^2 \cdot \log(0.8) FL(y^​t​)=−αt​⋅(1−0.8)2⋅log(0.8)

若取 α = 0.25 \alpha = 0.25 α=0.25,则 α t = 1 − 0.25 = 0.75 \alpha_t = 1 - 0.25 = 0.75 αt​=1−0.25=0.75,因此:
FL ( y ^ t ) = − 0.75 ⋅ ( 0.2 ) 2 ⋅ log ⁡ ( 0.8 ) ≈ − 0.75 ⋅ 0.04 ⋅ ( − 0.22314 ) ≈ 0.00669 \text{FL}(\hat{y}_t) = -0.75 \cdot (0.2)^2 \cdot \log(0.8) \approx -0.75 \cdot 0.04 \cdot (-0.22314) \approx 0.00669 FL(y^​t​)=−0.75⋅(0.2)2⋅log(0.8)≈−0.75⋅0.04⋅(−0.22314)≈0.00669

多分类实例

假设我们有三个类别(猫、狗、兔子),模型预测的概率分别为 [ 0.2 , 0.5 , 0.3 ] [0.2, 0.5, 0.3] [0.2,0.5,0.3],实际标签是狗(one-hot编码为[0, 1, 0]),取 γ = 2 \gamma = 2 γ=2。

  1. 计算 y ^ c \hat{y}_c y^​c​:
    y ^ c = y ^ 2 = 0.5 \hat{y}_c = \hat{y}_2 = 0.5 y^​c​=y^​2​=0.5

  2. 计算Focal Loss:
    FL ( y ^ 2 ) = − α 2 ⋅ ( 1 − 0.5 ) 2 ⋅ log ⁡ ( 0.5 ) \text{FL}(\hat{y}_2) = -\alpha_2 \cdot (1 - 0.5)^2 \cdot \log(0.5) FL(y^​2​)=−α2​⋅(1−0.5)2⋅log(0.5)

若取 α 2 = 0.25 \alpha_2 = 0.25 α2​=0.25,则:
FL ( y ^ 2 ) = − 0.25 ⋅ ( 0.5 ) 2 ⋅ log ⁡ ( 0.5 ) ≈ − 0.25 ⋅ 0.25 ⋅ ( − 0.69315 ) ≈ 0.04332 \text{FL}(\hat{y}_2) = -0.25 \cdot (0.5)^2 \cdot \log(0.5) \approx -0.25 \cdot 0.25 \cdot (-0.69315) \approx 0.04332 FL(y^​2​)=−0.25⋅(0.5)2⋅log(0.5)≈−0.25⋅0.25⋅(−0.69315)≈0.04332

PyTorch实现

二分类Focal Loss

import torch

class FocalLossBinary(torch.nn.Module):
    """
    二分类Focal Loss
    """
    def __init__(self, alpha=0.25, gamma=2):
        super(FocalLossBinary, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, preds, labels):
        """
        preds: sigmoid的输出结果
        labels: 标签
        """
        eps = 1e-7
        loss_1 = -1 * self.alpha * torch.pow((1 - preds), self.gamma) * torch.log(preds + eps) * labels
        loss_0 = -1 * (1 - self.alpha) * torch.pow(preds, self.gamma) * torch.log(1 - preds + eps) * (1 - labels)
        loss = loss_0 + loss_1
        return torch.mean(loss)

多分类Focal Loss

import torch

class FocalLossMultiClass(torch.nn.Module):
    def __init__(self, weight=None, gamma=2):
        super(FocalLossMultiClass, self).__init__()
        self.gamma = gamma
        self.weight = weight
    
    def forward(self, preds, labels):
        """
        preds: softmax输出结果
        labels: 真实值
        """
        eps = 1e-7
        y_pred = preds.view((preds.size()[0], preds.size()[1], -1))  # B*C*H*W->B*C*(H*W)
        
        target = labels.view(y_pred.size())  # B*C*H*W->B*C*(H*W)
        
        ce = -1 * torch.log(y_pred + eps) * target
        floss = torch.pow((1 - y_pred), self.gamma) * ce
        if self.weight is not None:
            floss = torch.mul(floss, self.weight)
        floss = torch.sum(floss, dim=1)
        return torch.mean(floss)

结论

Focal Loss通过引入两个超参数 α \alpha α 和 γ \gamma γ,有效地解决了类别不平衡和难易样本不平衡的问题。通过调整这些超参数,可以使模型更加关注那些难以正确分类的样本,从而提高整体性能。在实际应用中,可以通过实验来确定最佳的 α \alpha α 和 γ \gamma γ 值。

参考文献

Focal Loss的理解以及在多分类任务上的使用(Pytorch) -
GHZhao_GIS_RS - CSDN

标签:Loss,log,0.25,样本,pytorch,gamma,alpha,hat,Focal
From: https://blog.csdn.net/weixin_51524504/article/details/141337347

相关文章

  • 深度学习加速秘籍:PyTorch torch.backends.cudnn 模块全解析
    标题:深度学习加速秘籍:PyTorchtorch.backends.cudnn模块全解析在深度学习领域,计算效率和模型性能是永恒的追求。PyTorch作为当前流行的深度学习框架之一,提供了一个强大的接口torch.backends.cudnn,用于控制CUDA深度神经网络库(cuDNN)的行为。本文将深入探讨torch.backends.cu......
  • 深度学习-pytorch-basic-001
    importtorchimportnumpyasnptorch.manual_seed(1234)<torch._C.Generatorat0x21c1651e190>defdescribe(x):print("Type:{}".format(x.type()))print("Shape/Size:{}".format(x.shape))print("Values:{}"......
  • PyTorch深度学习实战(18)—— 可视化工具
    在训练神经网络时,通常希望能够更加直观地了解训练情况,例如损失函数曲线、输入图片、输出图片等信息。这些信息可以帮助读者更好地监督网络的训练过程,并为参数优化提供方向和依据。最简单的办法就是打印输出,这种方式只能打印数值信息,不够直观,同时无法查看分布、图片、声音等......
  • 零基础学习人工智能—Python—Pytorch学习(五)
    前言上文有一些文字打错了,已经进行了修正。本文主要介绍训练模型和使用模型预测数据,本文使用了一些numpy与tensor的转换,忘记的可以第二课的基础一起看。线性回归模型训练结合numpy使用首先使用datasets做一个数据X和y,然后结合之前的内容,求出y_predicted。#pipinstallmatp......
  • PyTorch--双向长短期记忆网络(BiRNN)在MNIST数据集上的实现与分析
    文章目录前言完整代码代码解析1.导入库2.设备配置3.超参数设置4.数据集加载5.数据加载器6.定义BiRNN模型7.实例化模型并移动到设备8.损失函数和优化器9.训练模型10.测试模型11.保存模型常用函数前言本代码实现了一个基于PyTorch的双向长短期记忆网络(BiRNN),用于对MNI......
  • 用pytorch实现LeNet-5网络
     上篇讲述了LeNet-5网络的理论,本篇就试着搭建LeNet-5网络。但是搭建完成的网络还存在着问题,主要是训练的准确率太低,还有待进一步探究问题所在。是超参数的调节有问题?还是网络的结构有问题?还是哪里搞错了什么1.库的导入dataset:datasets.MNIST()函数,该函数作用是导入MNIST数......
  • PyTorch--实现循环神经网络(RNN)模型
    文章目录前言完整代码代码解析导入必要的库设备配置超参数设置数据集加载数据加载器定义RNN模型实例化模型并移动到设备损失函数和优化器训练模型测试模型保存模型小改进神奇的报错ValueError:LSTM:Expectedinputtobe2Dor3D,got4Dinstead前言首先,这篇......
  • 掌握 PyTorch 张量乘法:八个关键函数与应用场景对比解析
    PyTorch提供了几种张量乘法的方法,每种方法都是不同的,并且有不同的应用。我们来详细介绍每个方法,并且详细解释这些函数有什么区别:1、torch.matmultorch.matmul是PyTorch中用于矩阵乘法的函数。它能够处理各种不同维度的张量,并根据张量的维度自动调整其操作方式。torch......
  • pytorch 3 计算图
    计算图结构分析:起始节点ab=5-3ac=2b+3d=5b+6e=7c+d^2f=2e最终输出g=3f-o(其中o是另一个输入)前向传播前向传播按照上述顺序计算每个节点的值。反向传播过程反向传播的目标是计算损失函数(这里假设为g)对每个中间变量和输入的偏导数。从右向左......
  • 【锂电池SOC估计】【PyTorch】基于Basisformer时间序列锂离子电池SOC预测研究(python代
     ......