首页 > 编程语言 >Wasserstein距离的python代码实现scipy.stats.wasserstein_distance解释

Wasserstein距离的python代码实现scipy.stats.wasserstein_distance解释

时间:2023-12-18 11:44:31浏览次数:32  
标签:distance stats 箱子 python 距离 wd1 wasserstein 10

在官方文档scipy.stats.wasserstein_distance — SciPy v1.8.0.dev0+1869.838cfbe Manual (osgeo.cn)页面中

scipy.stats.wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None)

对参数u_values,v_value,u_weights,v_weights解释不清晰。

通过看文章Wasserstein距离的直观解释_em距离-CSDN博客对Wasserstein距离的理解和对样例的测试。对搜索引擎多次指向的一篇文章博客EMD距离wasserstein_distance的使用_推土机距离-CSDN博客进行勘误。

库的应用

我们对文章中使用到的样例进行计算

  • 1,正确调用方法为:

    from scipy.stats import wasserstein_distance
    # 对权重未归一化
    wd1 = wasserstein_distance([0,1,2,3],[0,1,2,3], [4,2,1,3],[3,1,2,4])
    # wd1 = 0.4
    
    # 前面两个u_values,v_value值不同
    wd2 = wasserstein_distance([1,2,3,4],[1,2,3,4], [4,2,1,3],[3,1,2,4])
    # wd2 = 0.4
    # u_values,v_value 值变化
    wd3 = wasserstein_distance([0,1,3,4],[0,1,3,4], [4,2,1,3],[3,1,2,4])
    # wd3 = 0.6
    # 对示例中的进行归一化处理
    wd4 = wasserstein_distance([0,1,2,3],[0,1,2,3], [4/10,2/10,1/10,3/10],[3/10,1/10,2/10,4/10])
    # wd4 = 0.4
    

    前面的u_values,v_value可以理解为槽位,后面的权重u_weights,v_weights理解为箱子Wasserstein距离 就是,我们现在需要将上方的箱子,挪成下面的箱子的堆放形状,移动箱子的距离最小。

    wd1和wd2: wd1和wd2中槽位的相距都是1,虽然位置在坐标中不同,但计算出来的距离是一样的。可以这样理解,槽位的位置不管在小明家还是小丽家,只要他们箱子摆放形状和槽位的方位和距离保持一致,搬动箱子的距离是固定的。

    wd1和wd3: wd1和wd3中槽位在坐标x的方位不变,但是wd3中,槽位1,2和槽位3,4距离中间变大,最终距离计算为0.6比wd1=0.4大。可以这样理解,虽然都是搬箱子,且方位前后一致,但是,由于箱子中间距离变大,搬箱子搬动的距离增大。

    wd1和wd4: wd1和wd3中,后面权重不一样,但是wd1会默认转换为wd4.

  • 错误调用示范

    文章EMD距离wasserstein_distance的使用_推土机距离-CSDN博客 中按照其写的方式,直接调用为:

    x0 = wasserstein_distance([0, 1, 3], [0, 1, 3])		#相同的分,分布的差异为0
    x1 = wasserstein_distance([3, 1, 0], [0, 1, 3])		#相同分布,但元素顺序打乱,分布的差异值仍然为
    

    为什么x0和x1计算出来的值一样?因为x0和x1的槽位是一样的,但是权重系统固定分配一样,在槽位上默认摆放的箱子数量是相同的。和下面代码调用效果一样。x1可以理解为,在坐标3,1,0处各有一个箱子,现在要将箱子搬到0,1,3处各有一个箱子,那显然我们不用搬动任何箱子即可实现。故距离为0。

    x0 = wasserstein_distance([0, 1, 3], [0, 1, 3],[1,1,1],[1,1,1])		#相同的分,分布的差异为0
    x1 = wasserstein_distance([3, 1, 0], [0, 1, 3],[1,1,1],[1,1,1])		#相同分布,但元素顺序打乱,分布的差异值仍然为 
    

总结

调用该库计算距离时,应该将要计算的分布放在权重位置,即u_weights,v_weights。而在u_values,v_value中按照使用习惯。只需要传递分布的下标即可

好文推荐

Wasserstein距离的直观解释_em距离-CSDN博客

标签:distance,stats,箱子,python,距离,wd1,wasserstein,10
From: https://www.cnblogs.com/jlxa162hhf/p/17910784.html

相关文章

  • maturin 方便发布基于rust 的python 包工具
    maturin是PyO3团队开发的,方便我们开发基于rust的python包,比如PyO3的使用文档中就使用了此工具安装&使用安装(可选,可以基于venv安装)可以基于pip以及pipx pipxinstallmaturin创建一个简单项目python-mvenv.venvsource.venv......
  • 【python】浏览器自动化Selenium安装WebDriver最新Chrome驱动
    selenium 是浏览器自动化测试框架,原本被用于网页测试。但到了爬虫领域,它又成为了爬虫的好帮手。selenium 可以控制你的浏览器,模仿人浏览网页,从而获取数据,自动操作等。首先打开 Chrome浏览器,依次点击浏览器右上角的 三个点 - 帮助 - 关于GoogleChrome查看浏览器版本信......
  • Python定位错误:段错误 (核心已转储)
    技术背景在各种编程语言中都有可能会遇到这样一个报错:“段错误(核心已转储)”。显然是编写代码的过程中有哪里出现了问题,但是这个报错除了这几个字以外没有任何的信息,我们甚至不知道是哪一行的代码出现了这个问题。解决方案在python中可以引用一个faulthandler的函数,就可以显......
  • python的orjson
    简介首先我们先来了解下orjson的优缺点:可以将datetime、date和time实例序列化为RFC3339格式,例如:"2022-06-12T00:00:00+00:00"序列化numpy.ndarray实例的速度比其他库快4-12倍,但使用的内存更少,约为其他库的1/3左右输出速度是标准库的10到20倍序列化的结果是bytes类型,而不是......
  • python之DataClass
    Python在版本3.7(PEP557)中引入了dataclass。dataclass允许你用更少的代码和更多的开箱即用功能来定义类。下面定义了一个具有两个实例属性name和age的常规Person类: classPerson:def__init__(self,name,age):self.name=nameself......
  • python迭代器理解
    目录什么是迭代器?为什么要有迭代?迭代器的优缺点什么是可迭代对象?什么是迭代器对象呢?什么是迭代器?在学习for循环的时候,听到了一个词叫可迭代对象。那什么是可迭代对象?了解后又知道了迭代,可迭代,迭代器这些名词,那这些到底是什么意思呢?我们先知道为什么for循环不像whlie循环一样,使......
  • Machine Learning in Python
    MetricFormulaInterpretationAccuracy$\frac{TP+TN}{TP+TN+FP+FN}$OverallperformanceofmodelPrecision$\frac{TP}{TP+FN}$HowaccuratethepositivepredictionsareRecallSensitivity$\frac{TP}{TP+FP}$Coverageofactualpositivesampl......
  • python网站创建011:函数、序列化
    经典函数:JS经典函数定义function函数名(a1,a2){vara3=a1+a2returna3}varres=函数名(11,22)console.log(res) 那如何在浏览器中查看打印结果呢?  匿名函数:没名字的函数,如果只有一个地方用到,则直接使用......
  • 系统聚类思想及其Python实现
    系统聚类是一种将对象或数据按照其相似性进行分组的方法。与传统聚类方法不同,系统聚类通过构建一颗层次树(或称为树状图或谱系图)来展现数据的层次结构。在系统聚类中,数据点最初被视为单独的簇,然后通过逐步合并或分裂,形成一个层次结构的聚类结果。这种层次结构可用于分析数据的不同......
  • Python——作图
    Python——作图百分比堆积柱状图importmatplotlib.pyplotaspltimportnumpyasnp#定义数据categories=['CatA','CatB','CatC']values1=[20,30,15]values2=[10,25,20]values3=[5,10,15]#将数据转化为相对百分比total=np.array(v......