首页 > 其他分享 >分位数回归损失函数代码实现解析

分位数回归损失函数代码实现解析

时间:2022-11-26 13:00:30浏览次数:66  
标签:tau 代码 dy 位数 delta 解析 回归 gamma

目录

1. 绪论

对于分位数回归损失函数,最近看到了两种不同的实现。这种实现和 Bing 上检索到的任何一种分位数损失函数表达形式都不一样。

import keras.backend as K

def QR_error(y_true, y_pred, tau):
    dy = y_pred - y_true
    return K.mean((1.0 - tau) * K.relu(dy) + tau * K.relu(-dy), axis=-1)
def quantile_loss(q, y, y_p):
        e = y-y_p
        return K.mean(K.maximum(q*e, (q-1)*e))

下面,对这两种形式和检索的损失函数对应分析。


2. 分位数回归

分位数回归是统计学和计量经济学中使用的一种回归分析。最小二乘方法估计的是预测变量的条件平均值,而分位数回归估计的是响应变量的条件中位数(或其他分位数)。分位数回归是线性回归的一种扩展,当线性回归的条件不满足时可以使用分位数回归。

关于什么是分位数回归以及分位数回归的推导,下面这些博客或多或少有介绍,但鲜有对损失函数的实现解析。

在此,本文不再赘述分位数回归理论,重点讲解分位数回归损失函数的代码实现,以及它们的不同形式。


3. 分位数回归损失函数

分位数回归时,所使用的损失函数有这么几种表达形式:

  1. \(L_\gamma=\sum_{i=y_i<y_i^p}(\gamma-1) \cdot\left|y_i-y_i^p\right|+\sum_{i=y_i \geq y_i^p}(\gamma) \cdot\left|y_i-y_i^p\right|\)
    来源于回归问题中5种常用损失函数
  2. \(L_{\text {quantile }}=\frac{1}{N} \sum_{i=1}^N Ⅱ_{y>f(x)}(1-\gamma)|y-f(x)|+Ⅱ_{y<f(x)} \gamma|y-f(x)|\)
    来源于回归损失函数2 : HUber loss,Log Cosh Loss,以及 Quantile Loss损失函数 Loss Function 之 分位数损失 Quantile Loss

式中,\(y_i\)是真实值,\(y^p\)或者\(f(x)\)为预测值。两式子表示\(y\)全部取值的分位数损失。

但是,离散随机变量而言,我们经过推导,通过将\({\displaystyle y-y^p}\)相对于\({ y^p}\)的预期损失最小化,可以找到特定的分位数,来源于《QUANTILE REGRESSION》5-6页。可以发现,损失函数是如下形式的:

\[L_\gamma=\frac{(\gamma-1)}{N} \sum_{y_i<y^p}\left(y_i-y^p\right)+\frac{\gamma}{N} \sum_{y_i \geq y^p}\left(y_i-y^p\right) \]

不是上述两式子中的任何一个,因为,他们并不包含绝对值。但是可以将\((\gamma - 1)\)放入在求和号内。


4. \((\gamma - 1)\)的放入

需要注意的是,上式,两个求和号是满足条件的求和

对于单个值,其分位数损失计算方式,按照上式执行。令\(y_i-y_i^p=\delta y\),当\(\delta y<0\)

\[ \rho _{\gamma}(\delta y)=(\gamma-1)\cdot \delta y=(1-\gamma) (-\delta y)=(1-\gamma )|\delta y| \]

当\(\delta y>0\):

\[ \rho_{\gamma}(\delta y)= \gamma \cdot\delta y \]

因此,对于\(N\)个随机变量,他们的总共的分位数损失为:

\[L_\gamma = \frac{1}{N} \left[\sum_{i=y_i < y_i^p} (1-\gamma) \cdot\left|y_i-y_i^p\right|+\sum_{i=y_i > y_i^p}\gamma \cdot\left|y_i-y_i^p\right|\right] \]

5. 程序代码表达

为了最大化的利用 Python 中 Torch 或者 Keras 的函数库,方便自动求导,我们可以将条件求和变为取最大值函数
同样,我们令\(y_i-y_i^p=\delta y\),上式用伪代码可以写为:

mean(max{tau*dy, (tau-1)*dy})

当\(\delta y<0\)时,是取得 (tau-1)*dy ,也就是\(\rho_{\gamma}(\delta y)=(\gamma-1)\cdot \delta y\) ;
当\(\delta y>0\)时,是取得 tau*dy ,也就是\(\rho_{\gamma}(\delta y)= \gamma \cdot\delta y\) ;
最后求和取均值,得到最终的\(L_\gamma\)。因此,利用程序,即可表达为:

def quantile_loss(q, y, y_p):
        e = y-y_p
        return tf.keras.backend.mean(tf.keras.backend.maximum(q*e, (q-1)*e))

同理,下面的代码也是等效的:

import keras.backend as K

def QR_error(y_true, y_pred, tau):
    dy = y_pred - y_true
    return K.mean((1.0 - tau) * K.relu(dy) + tau * K.relu(-dy), axis=-1)

不过,一定注意,dy的定义,dy = y_pred - y_true和e = y-y_p刚好相反,因此,代码中tau-1也要反过来。

标签:tau,代码,dy,位数,delta,解析,回归,gamma
From: https://www.cnblogs.com/AidanLee/p/16927260.html

相关文章

  • .NET 7 看图桌面应用程序源代码下载
    .net7.0刚发布不久,就拿了练手了,制作了一个看图的桌面应用程序,可以看图片信息。请访问以下页面:https://hovertree.com/h/bjag/3osrx05l.htm这里有源代码可以下载.效果......
  • Java下载文件的四种方式详细代码
    原文链接:https://www.jb51.net/article/232182.htm1.以流的方式下载publicHttpServletResponsedownload(Stringpath,HttpServletResponseresponse){try......
  • 让你的Python代码更干净只需简单一步
    你可以将这两个文件拷贝到自己的项目根目录中,然后执行一次pre-commitinstall,这样每次提交代码的时候,都是干净的代码,是不是很方便?说起来容易做起来难,我们都知道代码可读性......
  • 用YAPF让Python代码瞬间从丑陋变漂亮
    要把Python代码写漂亮,必须遵循PEP8Python编码规范:《​​PEP8--StyleGuideforPythonCode​​​》。但记住PEP8规范,是一件非常痛苦的事情,还好Google发布了一个自动整......
  • 用Python代码画世界杯吉祥物拉伊卜(附代码)
    用Python代码画世界杯吉祥物拉伊卜(附代码)世界杯正在火热进行中,世界杯的吉祥物拉伊卜也非常火。本文用Python代码画世界杯吉祥物。不废话,可以直接先看视频效果。视频效果用P......
  • Grasp Detection论文、代码汇总
    文章目录​​2022​​​​End-to-endTrainableDeepNeuralNetworkforRoboticGraspDetectionandSemanticSegmentationfromRGB​​​​2019​​​​AntipodalRob......
  • Knowledge Transfer论文、代码汇总
    文章目录​​2018​​​​LearningDeepRepresentationswithProbabilisticKnowledgeTransfer(ECCV)​​​​2019​​​​LIKEWHATYOULIKE:KNOWLEDGEDISTILLVIAN......
  • java简单解析wsdl文件
    1packagecom.example.demo.api.soap.client.userInterface.controller;234importorg.w3c.dom.Document;5importorg.w3c.dom.NamedNodeMap;6importor......
  • 基于Sklearn机器学习代码实战
    LinearRegression线性回归入门数据生成为了直观地看到算法的思路,我们先生成一些二维数据来直观展现importnumpyasnpimportmatplotlib.pyplotaspltdeftrue_f......
  • 小程序的后端代码语法
    先启动云开发,在数据库中添加json文件来导入数据。将数据权限建议改为:所用用户可读,仅创建者可读写小程序的数据库类似于mongodb后端也是基于js的,你需要了解箭头函数......