首页 > 其他分享 >学习笔记:什么是Wasserstein distance

学习笔记:什么是Wasserstein distance

时间:2023-08-23 20:44:51浏览次数:49  
标签:distance 概率分布 笔记 Wasserstein array eq displaystyle gamma

简单地说,就是衡量两个概率分布之间的差异。也可以说是将一个概率分布转换成另一个概率分布要花费多少代价。

图1:在一维空间中的三个概率分布

比如,上图中有三个概率分布f, g, h,我们可以说f与g之间的距离比f与h之间的距离更小。

上述只是感性上的认知,那么如何计算出准确的数值呢?如果我们想求f与g之间差距,Wasserstein distance要求找到一种从f转移到g的方案,使得转移代价最小,用式子表示为:

\[\mathcal{W}[f, g]=\inf _{\gamma \in \Pi[f, g]} \iint \gamma({x}, {y}) d({x}, {y}) \mathrm d {x} \mathrm d {y} \]

这里\(\inf\)表示选择数值最小的方案,\(\Pi[f,g]\)表示所有\(f,g\)的转移方案的集合,\(\gamma\)是一种转移方案,\(d\)是自定义的距离计算方法。合起来的意思就是:从所有\(f\)到\(g\)的转移方案中,选择一个转移代价最小的方案,这个代价就是Wasserstein distance。

转移方案 Transport plan

为了方便讲解什么是转移方案,我们假设两个概率分布分别如下所示

图2:在二维空间中的两个概率分布

我们的目的是让x分布变成y分布,于是可以这样转移:

图3:从x到y的一种转移方案

例如\(y_2\)的0.4,是由\(x_1\)的0.1、\(x_2\)的0.1以及\(x_3\)的0.2组成的。总之,我们将这些转移关系列成一个表格,就是:

图4:从x到y的一种转移方案的表格

容易看出,这个表格实际上就是\(x\)和\(y\)的一个联合分布。也就是全部的值都是>=0的,每一列相加为对应的x值,每一行相加为对应的y值。

也就是说,从所有\(f\)与\(g\)的联合分布中,找到一个联合分布,使得以下式子的值最小:

\[\begin{array}{l} \displaystyle\inf_{\gamma \in \Pi[f, g]} \iint \gamma (x, y) d(x, y) \mathrm d x \mathrm d y \\ \text { s.t. }\left\{\begin{array}{l} \displaystyle\int \gamma(x, y)\mathrm d y=f(x) \\ \displaystyle\int \gamma(x, y)\mathrm d x=g(y) \\ \gamma (x, y) \geqslant 0 \end{array}\right. \end{array} \]

如何计算

我只看了两个离散的概率分布的距离如何计算。

据说python有一个库POT(Python Optimal Transport)可以用,具体的我并不了解。不过,这里我用一个例子,展示如何用线性规划计算两个离散概率分布的Wasserstein distance。

(可能用的符号不严谨,理解即可)
已知两个概率分布分别为\(X=\{x_1,\cdots,x_n\}\)和\(Y=\{y_1,\cdots,y_m\}\),每两个元素之间的距离为\(D=\{d_{1,1},\cdots,d_{n,m}\}\),其中\(d_{i,j}\)表示\(d(x_i,y_j)\)。
我们假设它们的联合概率分布为\(W=\{w_{1,1},\cdots,w_{n,m}\}\),那么问题就是让\(\sum_{i,j}w_{i,j}d_{i,j}\)最小,即解下面这个线性规划问题:

\[ \begin{array}{l} \displaystyle\min\quad w_{1,1}d_{1,1}+w_{1,2}d_{1,2}+\cdots+w_{n,m}d_{n,m} \\ \text { s.t. }\left\{\begin{array}{l} \displaystyle\sum_j w_{1,j}=x_1 \\ \qquad\vdots \\ \displaystyle\sum_j w_{n,j}=x_n \\ \displaystyle\sum_i w_{i,1}=y_1 \\ \qquad\vdots \\ \displaystyle\sum_i w_{i,m}=y_m \\ w_{x,y} \geqslant 0 \end{array}\right. \end{array} \]

在论文DSTAGNN中,相应的代码是这样写的:

from scipy.optimize import linprog

def wasserstein_distance(p, q, D):
    A_eq = []
    for i in range(len(p)):
        A = np.zeros_like(D)
        A[i, :] = 1
        A_eq.append(A.reshape(-1))
    for i in range(len(q)):
        A = np.zeros_like(D)
        A[:, i] = 1
        A_eq.append(A.reshape(-1))
    A_eq = np.array(A_eq)
    b_eq = np.concatenate([p, q])
    D = np.array(D)
    D = D.reshape(-1)
    result = linprog(D, A_eq=A_eq[:-1], b_eq=b_eq[:-1])
    myresult = result.fun
    return myresult

代码的做法大致与上面相同,其中p,q分别是两个概率分布,D和我假设的意义相同,A_eq代表s.t.中等号左边的内容,b_eq代表s.t.中等号右边的数字。他没有写\(w_{x,y}\ge 0\)这个条件,可能默认就是这样的吧。
一个有意思的地方是,代码中A_eq和b_eq都舍掉了最后一个元素,也就是s.t.中的最后一个等号被忽略了。想一下也可以知道,只要前面的那些等式都保证了,最后一个等式是一定能保证的,所以不写也没问题。

参考资料

  1. b站视频 Introduction to the Wasserstein distance
  2. 知乎文章 Wasserstein距离

标签:distance,概率分布,笔记,Wasserstein,array,eq,displaystyle,gamma
From: https://www.cnblogs.com/white514/p/17652748.html

相关文章

  • C++面向对象笔记(转载自黑马程序员)
    C++核心编程本阶段主要针对C++面向对象编程技术做详细讲解,探讨C++中的核心和精髓。1内存分区模型C++程序在执行时,将内存大方向划分为4个区域代码区:存放函数体的二进制代码,由操作系统进行管理的全局区:存放全局变量和静态变量以及常量栈区:由编译器自动分配释放,存放函数的......
  • openGauss学习笔记-48 openGauss 高级数据管理-函数
    openGauss学习笔记-48openGauss高级数据管理-函数openGauss常用的函数如下:48.1数学函数abs(x)描述:绝对值。返回值类型:和输入相同。示例:openGauss=#SELECTabs(-17.4);abs------17.4(1row)cbrt(dp)描述:立方根。返回值类型:doubleprecision示例:openGauss......
  • 《408操作系统 》复习笔记 ③ 第二章 调度与调度算法
    调度当有一堆任务要处理,由于资源有限,没办法同时处理。需要某种规则来决定处理这些任务的顺序作业作业:一个具体的任务用户向系统提交一个作业=用户让操作系统启动一个程序(来处理一个具体的任务)调度的三个层次高级调度(作业调度)按照某种策略从外存的作业后备队列中挑选......
  • Asp.net mvc 笔记
    捕捉处理全局异常自定义一个Attribute继承默认的HandleErrorAttributenamespaceEmpowerApiService.Filter{publicclassCustomerErrorAttribute:HandleErrorAttribute{privatestaticNLog.Loggerlogger=NLog.LogManager.GetCurrentClassLogger();......
  • K8S-安装笔记
    准备:主机环境的前期准备工作个人环境使用3台CentOSLinuxrelease8.5.2111,搭建需要联网,配置yum的k8s仓库等。IP地址:172.17.136.28/29/32/33,主机名对应为:gip28、gip29、gip32、gip33期中k8smaster主节点为gip28注意:以下操作如果没有特殊说明,则默认在所有的节点均执行。一、安装dock......
  • 线性基学习笔记
    \(#definglllonglong\)线性基用处:快速查询一个数是否可以被一堆数异或出来快速查询一堆数可以异或出来的最大\(/\)最小值快速查询一堆数可以异或出来的第\(k\)大值线性基空间复杂度:设有一个序列,其值域为\([1,N]\),我们可以构造一个长度为\(⌈\log_2N⌉\)......
  • [刷题笔记] Luogu P2679 [NOIP2015 提高组] 子串
    ProblemDescription我们可以换个思路。从字符串\(A\)中拿出\(k\)个字串使其变成\(B\)。求有几种不同的方案?Analysis我们发现\(A\)中的一个字符取或者不取影响后面的决策,这并不代表它一定有后效性,我们可以记录这一层状态。和最长公共子序列同理,定义\(f_{i,j,k,l}(\fo......
  • Python基础入门学习笔记 077 GUI的终极选择:Tkinter14
    Tkinter提供了三种标准对话框模块,分别是:messagebox、filedialog、colorchoosermessagebox(消息对话框)实例1:askokcancel函数1fromtkinterimport*23print(messagebox.askokcancel("FishCDemo","发射核弹?"))45mainloop() 实例2:askquestion函数 实例3:asire......
  • Python基础入门学习笔记 074 GUI的终极选择:Tkinter11
    事件绑定对于每个组件来说,可以通过bind()方法将函数或方法绑定到具体的事件上。当被触发的事件满足该组件绑定的事件时,Tkinter就会带着事件描述去调用handler()方法实例1:捕获单击鼠标位置1fromtkinterimport*23root=Tk()45defcallback(event):6prin......
  • Python基础入门学习笔记 075 GUI的终极选择:Tkinter12
    Message组件Message(消息)组件是Label组件的变体,用于显示多行文本信息。Message组件能够自动换行,并调整文本的尺寸使其适应给定得尺寸。实例1:1fromtkinterimport*23root=Tk()4w1=Message(root,text="这是一则消息",width=100)5w1.pack()6w2=Message(root,......