首页 > 其他分享 >聊聊损失函数1. 噪声鲁棒损失函数简析 & 代码实现

聊聊损失函数1. 噪声鲁棒损失函数简析 & 代码实现

时间:2023-01-01 15:12:26浏览次数:72  
标签:loss torch 函数 标签 self labels beta 损失 简析

今天来聊聊非常规的损失函数。在常用的分类交叉熵,以及回归均方误差之外,针对训练样本可能存在的数据长尾,标签噪声,数据不均衡等问题,我们来聊聊适用不同场景有针对性的损失函数。第一章我们介绍,当标注标签存在噪声时可以尝试的损失函数,这里的标签噪声主要指独立于特征分布的标签噪声。代码详见pytorch, Tensorflow

Symmetric Loss Function

paper: Making Risk Minimization Tolerant to Label Noise

这里我们用最基础的二分类问题,和一个简化的假设"标注噪声和标签独立且均匀分布",来解释下什么是对标注噪声鲁棒的损失函数。假设整体误标注的样本占比为\(\eta\),则在真实标签y=0和y=1中均有\(\eta\)比例的误标注,1被标成0,0被标称1。带噪声的损失函数如下

\[\begin{align} L(f(x), y_{noise}) &= (1-\eta)*L(f(x), y) + \eta * L(f(x), 1-y) \\ & = (1-2\eta)*L(f(x),y) + \eta*[L(f(x),y)+L(f(x),1-y)] \\ & = (1-2\eta)*L(f(x),y) + \eta K \\ \end{align} \]

因此如果损失函数满足\(L(f(x),y)+L(f(x),1-y)=constant\),则带噪声的损失函数会和不带噪声的\(L(f(x),y)\)收敛到相同的解。作者认为这样的损失函数就是symmetric的。

那有哪些常见的损失函数是symmetric loss呢?

MAE就是!对于二分类的softmax的输出层\(L(f(x),y)+L(f(x),1-y)=|y-f(x)| + |1-y-f(x)| = 1\)

敲黑板!记住这一点,因为后面的GCE和SCE其实都和MAE有着脱不开的关系。这里对symmetric loss的论证做了简化,细节详见论文~

Generalized Cross Entropy(GCE)

paper:Generalized Cross Entropy Loss for Training Deep Neural Networks with Noisy Labels

话接上文,MAE虽然是一种noise robust的损失函数,但是在深度学习中,因为MAE的梯度不是1就是-1,所有样本梯度scale都相同,缺乏对样本难易程度和模型置信度的刻画,因此MAE很难收敛。

作者提出了一种融合MAE和Cross Entropy的方案,话不多说直接上Loss

\[L_{q}(f(x),y_j) = \frac{1-f_j(x)^q}{q} \]

作者使用了negative box-cox来作为损失函数,乍看和MAE没啥关系。不过改变q的取值,就会发现玄妙所在

  • q->1: \(L=1-f_j(x)\), 就是MAE Loss
  • q->0: 根据洛必达法则,对分子分母同时求导,就会得到\(L=-log(f_j(x))\), 就是Cross Entropy

所以GCE损失函数通过控制q的取值,在MAE和CrossEntropy中寻找折中点。这个和Huber Loss的设计有些相似,只不过Huber是显式的用alpha权重来融合RMSE和MAE,而GCE是隐式的融合。q->1, 对噪声的鲁棒性更好,但更难收敛。作者还提出了截断GCE,对过大的loss进行截断,这里就不细说了~

pytorch实现如下,TF实现见文首链接

class GeneralizeCrossEntropy(nn.Module):
    def __init__(self, q=0.7):
        super(GeneralizeCrossEntropy, self).__init__()
        self.q = q

    def forward(self, logits, labels):
        # Negative box cox: (1-f(x)^q)/q
        labels = torch.nn.functional.one_hot(labels, num_classes=logits.shape[-1])
        probs = F.softmax(logits, dim=-1)
        loss = (1 - torch.pow(torch.sum(labels * probs, dim=-1), self.q)) / self.q
        loss = torch.mean(loss)
        return loss

Symmetric Cross Entropy(SCE)

Symmetric Cross Entropy for Robust Learning with Noisy Labels

作者是从交叉熵的另一个含义出发, 最小化交叉熵实际是为了最小化预测分布和真实分布的KL散度, 二者关联如下,其中H(y)是真实标签的信息熵是个常数

\[\begin{align} KL(y||f(x)) &= \sum ylog(f(x)) - \sum ylog(y) \\ & = H(y, f(x)) - H(y) = CrossEntropy(y, f(x)) - H(y) \end{align} \]

考虑KL散度是非对称的,KL(y||f(x))!=KL(f(x)||y), 前者度量的是使用预测分布对数据进行编码导致的信息损失。然而当y本身存在噪声时,y可能不是正确标签,f(x)才是,这时就需要考虑另一个方向KL散度KL(f(x)||y)。于是作者使用对称KL对应的对称交叉熵(SCE)作为损失函数

\[SCE =CE + RCE = H(y,f(x)) + H(f(x),y) \\ = \sum_j y_jlog(f_j(x)) + \sum_j f_j(x)log(y_j) \]

看到这里多少会有一种作者又拍脑袋了的感觉>.<.不过只需要对RCE的部分做下变换就豁然开朗了。以二分类为例,log(0)无法计算用常数A代替

\[RCE= H(f(x),y) = f_1(x) * log(1) + (1-f_1(x)) *log(0) = A(1-f_1(x)) \]

RCE的部分就是一个MAE!所以SCE本质上是显式的融合交叉熵和MAE!pytorch实现如下,TF实现见文首链接

class SymmetricCrossEntropy(nn.Module):
    def __init__(self, alpha=0.1, beta=1):
        super(SymmetricCrossEntropy, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.epsilon = 1e-10

    def forward(self, logits, labels):
        # KL(p|q) + KL(q|p)
        labels = torch.nn.functional.one_hot(labels, num_classes=logits.shape[-1])
        probs = F.softmax(logits, dim=-1)
        # KL
        y_true = torch.clip(labels, self.eps, 1.0 - self.eps)
        y_pred = probs
        ce = -torch.mean(torch.sum(y_true * torch.log(y_pred), dim=-1))

        # reverse KL
        y_true = probs
        y_pred = torch.clip(labels, self.eps, 1.0 - self.eps)
        rce = -torch.mean(torch.sum(y_true * torch.log(y_pred), dim=-1))

        return self.alpha * ce + self.beta * rce

Peer Loss

  • Peer Loss Functions:Learning from Noisy Labels without Knowning Noise Rates
  • NLNL: Negative Learning for Noisy Labels

Peer Loss相比GCE和SCE只适用于Cross Entropy, 它的设计更加灵活。每个样本的损失函数由常规loss和随机label的loss加权得到,权重为alpha,这里的loss支持任意的分类损失函数。随机label作者通过打乱一个batch里面的label顺序得到~

原理上感觉Peer Loss和NLNL很是相似都是negative learning的思路。对比下二者的损失函数,PL是最小化带噪标签y的损失的同时,最大化模型在随机标签上的损失。NL是直接最大化模型在非真实标签y上的损失。本质上都是negative learning,模型学习的不是x是什么,而是x不是什么,通过推动所有不正确分类的p->0,来得到正确的标签。从这个逻辑上说感觉Peer Loss和NLNL在高维的多分类场景下应该有更好的表现~

\[PL(f(x),y) = L(f(x),y) - \alpha L(f(x),\tilde{y}) \]

\[NL(f(x),y) = L(1-f(x), \tilde{y}) \]

pytorch实现如下,TF实现见文首链接

class PeerLoss(nn.Module):
    def __init__(self, alpha=0.5, loss):
        super(PeerLoss, self).__init__()
        self.alpha = alpha
        self.loss = loss

    def forward(self, preds, labels):
        index = list(range(labels.shape[0]))
        rand_index = random.shuffle(index)
        rand_labels = labels[rand_index]
        loss_true = self.loss(preds, labels)
        loss_rand = self.loss(preds, rand_labels)
        loss = loss_true - self.alpha * loss_rand
        return loss

Bootstrap Loss

Training Deep Neural Networks on Noisy Labels with Bootstrapping

Bootstrap Loss是从预测一致性的角度来降低噪声标签对模型的影响,作者给了soft和hard两种损失函数。

soft Bootstrap是在Cross Entropy的基础上加上预测熵值,在最小化预测误差的同时最小化概率熵值,推动概率趋近于0/1,得到更置信的预测。这里其实用到了之前在半监督时提到的最小熵原则(小样本利器3. 半监督最小熵正则)也就是推动分类边界远离高密度区。

对噪声标签,模型初始预估的熵值会较大(p->0.5), 因为加入了熵正则项,模型即便不去拟合噪声标签,而是向正确标签移动(提高预测置信度降低熵值),也会降低损失函数.不过这里感觉熵正则的引入也有可能使得模型预测置信度过高而导致过拟合

\[L_{soft} = \sum (\beta y_i + (1-\beta) p_i) log(p_i) \]

而Hard Bootstrap是把以上的预测概率值替换为预测概率最大的分类,Hard相比Soft更加类似label smoothing。举个栗子:当真实标签为y=0,噪声标签y=1,预测概率为[0.7,0.3]时,\(\beta=0.9\)时Bootstrap拟合的y实际为[0.1,0.9], 会降低错误标签的置信度,给模型学习其他标签的机会。而当模型预测和标签一致时y值不变,所以不会对正确有样本有太多影响,效果上作者评估也是Hard Bootstrap的效果要显著更好~

\[L_{hard} = \sum (\beta y_i + (1-\beta) argmx(p_i)) log(p_i) \]

pytorch实现如下,TF实现见文首链接

class BootstrapCrossEntropy(nn.Module):
    def __init__(self, beta=0.95, is_hard=0):
        super(BootstrapCrossEntropy, self).__init__()
        self.beta = beta
        self.is_hard = is_hard

    def forward(self, logits, labels):
        # (beta * y + (1-beta) * p) * log(p)
        labels = F.one_hot(labels, num_classes=logits.shape[-1])
        probs = F.softmax(logits, dim=-1)
        probs = torch.clip(probs, self.eps, 1 - self.eps)

        if self.is_hard:
            pred_label = F.one_hot(torch.argmax(probs, dim=-1), num_classes=logits.shape[-1])
        else:
            pred_label = probs
        loss = torch.sum((self.beta * labels + (1 - self.beta) * pred_label) * torch.log(probs), dim=-1)
        loss = torch.mean(- loss)
        return loss

对更多降噪loss感兴趣的朋友望过来https://github.com/subeeshvasu/Awesome-Learning-with-Label-Noise

又到年末填坑时间,争取把今年写了一半的草稿都补完,冲鸭!


Reference

  1. https://zhuanlan.zhihu.com/p/147371861
  2. https://blog.csdn.net/suredied/article/details/113528384
  3. https://zhuanlan.zhihu.com/p/370775044
  4. https://zhuanlan.zhihu.com/p/569526954
  5. https://zhuanlan.zhihu.com/p/299404214

标签:loss,torch,函数,标签,self,labels,beta,损失,简析
From: https://www.cnblogs.com/gogoSandy/p/17018065.html

相关文章

  • 欧拉函数的实现
    目录欧拉函数的定义欧拉函数的一般解法(试除法)线性筛欧拉函数的线性筛法参考资料欧拉函数的定义对于正整数\(n\),欧拉函数\(\varphi(n)\)是小于等于\(n\)的正整数中与......
  • C++用finally函数实现当前函数运行结束自动执行一段代码
    我们的需求可能有这样的需求,fun(){    xx;    xx;    xx;    //希望在这里能自动执行一段设定好的代码,实现一些自动清除啥啥啥的操作}核心......
  • 函数入门
    函数的作用:函数是组织好的,可重复使用的,用来实现单一或相关联功能的代码段。定义一个函数函数代码块以def关键词开头,后接函数标识符名称和圆括号()。任何传入参数......
  • P2398 GCD SUM——欧拉函数
    此题可以拓展为\(\sum\limits^n_{i=1}\sum\limits^m_{j=1}\gcd(i,j)\)结论是\(\sum\limits^{\min(n,m)}_{d=1}\varphi(d)\lfloor\dfrac{n}{d}\rfloor\lfloor\dfrac{m}{......
  • printf函数
    1.语法​printf("HappyNewYear");   printf("%d",a);printf的语法及其简单将要打印的内容用引号括起来即可;如果使用了格式说明符,引号后敲出列表名用","隔开;2.格式说明......
  • sqlserver 获取汉字拼音的首字母(大写)函数
    USE[test]GO/******对象:UserDefinedFunction[dbo].[GetFirstChar]脚本日期:02/22/201916:39:06******/SETANSI_NULLSONGOSETQUOTED_IDENTIFIERONG......
  • Python__19--函数调用的参数传递与变量的作用域
    1函数调用的参数传递形参(形式参数):在函数定义的时候用到的参数没有具体值,只是一个占位的符号,成为形参;实参(实际参数):在调用函数的时候输入的值。实际参数和形式参......
  • 一文了解 Go fmt 标准库输入函数的使用
    耐心和持久胜过激烈和狂热。哈喽大家好,我是陈明勇,今天分享的内容是Gofmt标准库输入函数的使用。如果本文对你有帮助,不妨点个赞,如果你是Go语言初学者,不妨点个关注,一起成......
  • PX01如何实现烧录函数调用--不进行Flicker调整,如烧全代码、烧固定vcom等
    在实际应用中我们经常会碰到烧某些寄存器值为固定值,此时我们并不想进行Flicker调整,按烧录键直接运行烧录函数并实现预期烧录动作,可以的,请参考如下说明。首先,勾选“烧录使......
  • 常见的函数姓名规范
    get获取/set设置,add增加/remove删除create创建/destory移除start启动/stop停止open打开/close关闭,read读取/write写入load载入/save保存,create创......