首页 > 其他分享 >numpy中比较两个数字的断言函数

numpy中比较两个数字的断言函数

时间:2024-03-13 11:47:54浏览次数:25  
标签:断言 testing equal assert print np array numpy 函数

比如在比较torch模型输出和onnxruntime输出,

import onnxruntime

ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])

def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

————————————————————————————————————

文章目录

 

import numpy as  np
 

断言函数

 

单元测试,单元测试是对一部分代码进行测试,可以提高代码质量,可重复性测试等.单元测试通常使用断言函数,在计算时,通常要考虑浮点数比较问题,numpy.testing包中包含很多实用的工具函数.

assert_almost_equal断言精度近似相等

#指定精度为小数点后七位
a = 0.123456789
b = 0.123456780
print(np.testing.assert_almost_equal(a,b,decimal=8))
#None 表示没有异常

#指定精度为小数点后九位
# print(np.testing.assert_almost_equal(a,b,decimal=9))

# AssertionError: 
# Arrays are not almost equal to 9 decimals
#  ACTUAL: 0.123456789
#  DESIRED: 0.12345678

c = 0.122
d = 0.121
print(np.testing.assert_almost_equal(c,d,decimal=3))
 
None
None
 

需要注意的是,如果在指定为数上数值相差1则仍然不会报错,如c和d所示.同样的道理,若指定a=0.123456789和b=0.123456788则指定decimal=9是不会出现异常的.

assert_approx_equal断言有效位近似相等

#指定有效位为8
a = 0.123456789
b = 0.123456788
print(np.testing.assert_approx_equal(a,b,significant=8))

#指定有效位为9
# print(np.testing.assert_approx_equal(a,b,significant=9))

# AssertionError: 
# Items are not equal to 9 significant digits:
#  ACTUAL: 0.123456789
#  DESIRED: 0.12345678
 
None
None
 

与assert_almost_equal类似,如果在指定为数上数值相差1则仍然不会报错.上面两个函数是精度和有效位的差别,但在实际使用中并没有差别。

assert_array_almost_equal数组近似比较

assert_array_almost_equal数组会首先比较维度,然后再比较数值。

# 精度为8
a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_array_almost_equal(a,b,decimal=8))

# 精度为9
# print(np.testing.assert_array_almost_equal(a,b,decimal=9))

# Arrays are not almost equal to 9 decimals

# Mismatched elements: 1 / 2 (50%)
# Max absolute difference: 9.e-09
# Max relative difference: 7.29000059e-08
#  x: array([0.         , 0.123456789])
#  y: array([0.        , 0.12345678])

c = np.array([0,0.123456780,0]) #三维
# print(np.testing.assert_array_almost_equal(a,c,decimal=8))

# AssertionError: 
# Arrays are not almost equal to 8 decimals

# (shapes (2,), (3,) mismatch)
#  x: array([0.        , 0.12345679])
#  y: array([0.        , 0.12345678, 0.        ])
 
None
 

assert_array_equal比较数组相等

严格比较数组的维度与元素值

a = np.array([0,0.123456789])
b = np.array([0,0.123456789])
print(np.testing.assert_array_equal(a,b))
 
None
 

assert_allclose比较数组相等

与assert_array_equal不同的是,该函数有atol(绝对容差限)、rtol参数(相对容差限)。比如对于数组a,b,则将测试是否满足
∣ a − b ∣ ≤ ( a t o l + r t o l ∗ ∣ b ∣ ) |a-b| \leq (atol+rtol*|b|) ∣a−b∣≤(atol+rtol∗∣b∣)

a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_allclose(a,b,rtol=1e-7,atol=0))
 
None
 

assert_array_less比较数组大小

assert_array_less(a,b)严格比较数组a是否小于b

a = np.array([0,0.1])
b = np.array([0.1,0.2])
print(np.testing.assert_array_less(a,b))
 
None
 

assert_equal比较对象相等

这里的对象可以是数组、列表、元组以及字典

# print(np.testing.assert_equal((1,2),(1,3)))  #出现异常

print(np.testing.assert_equal((1,2),(1,2)))

print(np.testing.assert_equal([1,2],[1,2]))

print(np.testing.assert_equal({'1':1,'2':2},{'1':1,'2':2}))
None
None
None
 

assert_string_equal 比较字符串相等

不仅比较字符,还比较大小写

print(np.testing.assert_string_equal('abc','abc'))

# print(np.testing.assert_string_equal('Abc','abc')) #出现异常
 

 

None

assert_array_almost_equal_nulp比较浮点数

机器精度(machine epsilon)是指浮点运算中的相对舍入误差上界。即机器允许在机器精度的范围下的误差。

#使用finfo函数确定机器精度
eps = np.finfo(float).eps
print(eps)

a = 1.0
b = a + eps #加上机器精度
c = a + 2*eps #加上2个机器精度 超出范围会出现异常
d = a + 1.4*eps
print(np.testing.assert_array_almost_equal_nulp(a,b))

# print(np.testing.assert_array_almost_equal_nulp(a,c))
# AssertionError: X and Y are not equal to 1 ULP (max is 2)

print(np.testing.assert_array_almost_equal_nulp(a,d))
 
2.220446049250313e-16
None
None
 

 

报错信息中的ULP(uint of Least Precision),指的是浮点数的最小精确度数。根据IEEE 754标准,四则运b算的标准必须保持在半个ULP内。在上面的c中,超过了的1个eps,实际测试中,不超过1.4倍的eps也是不会出现异常的。

assert_array_max_ulp多ULP浮点数比较

该函数可以通过maxulp参数设置多个ULP(默认为1)来增大浮点数比较的允许误差。如果两个浮点数的差距大于所设置或者默认的ULP,则函数assert_array_max_ulp会出现异常,若在误差范围内,则函数会返回两者所差的ULP个数(按第一个小数位四舍五入)。

a = 1.0
b = a + 2*eps
c = a + 1.499*eps
# print(np.testing.assert_array_max_ulp(a,b))
# AssertionError: Arrays are not almost equal up to 1 ULP

print(np.testing.assert_array_max_ulp(a,b,maxulp=2))

print(np.testing.assert_array_max_ulp(a,a))

print(np.testing.assert_array_max_ulp(a,c,maxulp=2))
 
2.0
0.0
1.0
 

单元测试

# python中的单元测试
#编写阶层函数
def function(n):
    if n==0:
        return 1
    elif n<0:
        raise ValueError("输入的值不合法")
    else:
        array = np.arange(1,n+1)
        return np.cumprod(array)[-1]
    

print(function(0))
print(function(9))
# function(-1)
# ValueError: 输入的值不合法
 
1
362880
#利用unittest模块进行单元测试
import unittest
import numpy as np
def function(n):
    if n==0:
        return 1
    elif n<0:
        raise ValueError("输入的值不合法")
    else:
        array = np.arange(1,n+1)
        return np.cumprod(array)[-1]
    

class FactoyiaTest(unittest.TestCase):
    """继承unittest.TestCase类"""
    def test_factorial(self):
        #计算3的阶层
        self.assertEqual(6,function(3))
        
    def test_zero(self):
        #计算0的阶层
        self.assertEqual(1,function(0))
    def test_negative(self):
        self.assertRaises(IndexError,function(-1))

if  __name__ =='__main__':
    unittest.main()
    
# .E.
# ======================================================================
# ERROR: test_negative (__main__.FactoyiaTest)
# ----------------------------------------------------------------------
# Traceback (most recent call last):
#   File "C:/Users/zhj/Desktop/untitled3.py", line 32, in test_negative
#     self.assertRaises(IndexError,function(-1))
#   File "C:/Users/zhj/Desktop/untitled3.py", line 16, in function
#     raise ValueError("输入的值不合法")
# ValueError: 输入的值不合法

# ----------------------------------------------------------------------
# Ran 3 tests in 0.010s

# FAILED (errors=1)
 

 

将上面的程序放在单独的.py文件,运行时提示错误提示,也就在使用function时出现了非法输入。

标签:断言,testing,equal,assert,print,np,array,numpy,函数
From: https://www.cnblogs.com/chentiao/p/18070283

相关文章

  • JavaScript学习--splice()函数入门与精通
    一、splice入门splice方法:通过删除(两个参数)或替换现有元素(三个参数)或者原地添加新的元素(三个参数)来修改数组,并以数组形式返回被修改的内容。此方法会改变原数组。参数:index——必需。整数,规定添加/删除项目的位置,使用负数可从数组结尾处规定位置(从1开始)。howmany——必需......
  • 【习题】随机变量与分布函数
    [T0301]设随机变量\(\xi\)取值于\([0,1]\),若\(P\{x\le\xi<y\}\)只与长度\(y-x\)有关(对一切\(0\lex\ley\le1\)).试证\(\xi\simU[0,1]\).证不妨设\(P\{x\le\xi<y\}=f(y-x)\).令\(x=0\),则有\(P\{0\le\xi<y\}=f(y)\).注意到对\(\for......
  • SqlServer函数大全三十八:DATEPART函数
    在SQLServer中,DATEPART 函数用于返回日期/时间值的指定部分的整数。与 DATENAME 函数不同,DATEPART 返回的是一个数字,而不是一个字符串。这对于需要进行数学计算或比较的场合特别有用。函数的语法如下:sql复制代码DATEPART(datepart,date)其中:datepart 是你想......
  • SqlServer函数大全三十九:CONVERT函数
    在SQLServer中,CONVERT 函数用于将一种数据类型转换为另一种数据类型。这在处理日期、时间、数字和其他数据类型时非常有用,尤其是当你需要确保数据以特定的格式或类型进行存储或显示时。函数的语法如下:sql复制代码CONVERT(data_type[(length)],expression[,style])......
  • SqlServer函数大全三十五:DATEDIFF(返回日期和时间的边界数)函数
    在SQLServer中,DATEDIFF 函数用于返回两个日期之间的边界数差异。这个函数可以计算两个日期之间的年、月、日、小时、分钟、秒或周数差异。DATEDIFF 函数的语法如下:sql复制代码DATEDIFF(datepart,startdate,enddate)datepart 是指定要返回日期部分的参数,比如......
  • linux Shell 命令行-07-func 函数
    拓展阅读linuxShell命令行-00-intro入门介绍linuxShell命令行-02-var变量linuxShell命令行-03-array数组linuxShell命令行-04-operator操作符linuxShell命令行-05-test验证是否符合条件linuxShell命令行-06-flowcontrol流程控制linuxShell命令行-07-f......
  • Vue3 组合函数 element-plus table数据滚动播放
    Vue滚动播放组合函数import{onMounted,onUnmounted}from"vue";exportfunctioncreateScroll(tableRef){lettimer=null;functionstartScroll(){consttable=tableRef.value.layout.table.refs;consttableWrapper=table.bodyWrapper.f......
  • 为什么defineProps宏函数不需要从vue中import导入?
    前言我们每天写vue代码时都在用defineProps,但是你有没有思考过下面这些问题。为什么defineProps不需要import导入?为什么不能在非setup顶层使用defineProps?defineProps是如何将声明的props自动暴露给模板?举几个例子我们来看几个例子,分别对应上面的几个问题。先来看一个正常的......
  • C语言字符函数和字符串函数
    前言今天这篇博客咱们一起来认识一些特殊的函数,在编程的过程中,我们经常要处理字符和字符串,为了方便字符和字符串,C语言提供了一些库函数,让我们一起看看这些函数都有什么功能吧!!!个人主页:小张同学zkf若有问题评论区见感兴趣就关注一下吧目录 1.字符分类函数2.字符......
  • 函数
    一、字符串函数常用的几个如下:函数功能CONCAT(s1,s2,…,sn)字符串拼接,将s1,s2,…,sn拼接成一个字符串LOWER(str)将字符串全部转为小写UPPER(str)将字符串全部转为大写LPAD(str,n,pad)左填充,用字符串pad对str的左边进行填充,达到n个字符串长度RPAD(......