比如在比较torch模型输出和onnxruntime输出,
import onnxruntime
ort_session = onnxruntime.InferenceSession("super_resolution.onnx", providers=["CPUExecutionProvider"])
def to_numpy(tensor):
return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")
————————————————————————————————————
文章目录
- 断言函数
- assert_almost_equal断言精度近似相等
- assert_approx_equal断言有效位近似相等
- assert_array_almost_equal数组近似比较
- assert_array_equal比较数组相等
- assert_allclose比较数组相等
- assert_array_less比较数组大小
- assert_equal比较对象相等
- assert_string_equal 比较字符串相等
- assert_array_almost_equal_nulp比较浮点数
- assert_array_max_ulp多ULP浮点数比较
- 单元测试
import numpy as np
断言函数
单元测试,单元测试是对一部分代码进行测试,可以提高代码质量,可重复性测试等.单元测试通常使用断言函数,在计算时,通常要考虑浮点数比较问题,numpy.testing包中包含很多实用的工具函数.
assert_almost_equal断言精度近似相等
#指定精度为小数点后七位
a = 0.123456789
b = 0.123456780
print(np.testing.assert_almost_equal(a,b,decimal=8))
#None 表示没有异常
#指定精度为小数点后九位
# print(np.testing.assert_almost_equal(a,b,decimal=9))
# AssertionError:
# Arrays are not almost equal to 9 decimals
# ACTUAL: 0.123456789
# DESIRED: 0.12345678
c = 0.122
d = 0.121
print(np.testing.assert_almost_equal(c,d,decimal=3))
None
None
需要注意的是,如果在指定为数上数值相差1则仍然不会报错,如c和d所示.同样的道理,若指定a=0.123456789和b=0.123456788则指定decimal=9是不会出现异常的.
assert_approx_equal断言有效位近似相等
#指定有效位为8
a = 0.123456789
b = 0.123456788
print(np.testing.assert_approx_equal(a,b,significant=8))
#指定有效位为9
# print(np.testing.assert_approx_equal(a,b,significant=9))
# AssertionError:
# Items are not equal to 9 significant digits:
# ACTUAL: 0.123456789
# DESIRED: 0.12345678
None
None
与assert_almost_equal类似,如果在指定为数上数值相差1则仍然不会报错.上面两个函数是精度和有效位的差别,但在实际使用中并没有差别。
assert_array_almost_equal数组近似比较
assert_array_almost_equal数组会首先比较维度,然后再比较数值。
# 精度为8
a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_array_almost_equal(a,b,decimal=8))
# 精度为9
# print(np.testing.assert_array_almost_equal(a,b,decimal=9))
# Arrays are not almost equal to 9 decimals
# Mismatched elements: 1 / 2 (50%)
# Max absolute difference: 9.e-09
# Max relative difference: 7.29000059e-08
# x: array([0. , 0.123456789])
# y: array([0. , 0.12345678])
c = np.array([0,0.123456780,0]) #三维
# print(np.testing.assert_array_almost_equal(a,c,decimal=8))
# AssertionError:
# Arrays are not almost equal to 8 decimals
# (shapes (2,), (3,) mismatch)
# x: array([0. , 0.12345679])
# y: array([0. , 0.12345678, 0. ])
None
assert_array_equal比较数组相等
严格比较数组的维度与元素值
a = np.array([0,0.123456789])
b = np.array([0,0.123456789])
print(np.testing.assert_array_equal(a,b))
None
assert_allclose比较数组相等
与assert_array_equal不同的是,该函数有atol(绝对容差限)、rtol参数(相对容差限)。比如对于数组a,b,则将测试是否满足
∣
a
−
b
∣
≤
(
a
t
o
l
+
r
t
o
l
∗
∣
b
∣
)
|a-b| \leq (atol+rtol*|b|)
∣a−b∣≤(atol+rtol∗∣b∣)
a = np.array([0,0.123456789])
b = np.array([0,0.123456780])
print(np.testing.assert_allclose(a,b,rtol=1e-7,atol=0))
None
assert_array_less比较数组大小
assert_array_less(a,b)严格比较数组a是否小于b
a = np.array([0,0.1])
b = np.array([0.1,0.2])
print(np.testing.assert_array_less(a,b))
None
assert_equal比较对象相等
这里的对象可以是数组、列表、元组以及字典
# print(np.testing.assert_equal((1,2),(1,3))) #出现异常
print(np.testing.assert_equal((1,2),(1,2)))
print(np.testing.assert_equal([1,2],[1,2]))
print(np.testing.assert_equal({'1':1,'2':2},{'1':1,'2':2}))
None
None
None
assert_string_equal 比较字符串相等
不仅比较字符,还比较大小写
print(np.testing.assert_string_equal('abc','abc'))
# print(np.testing.assert_string_equal('Abc','abc')) #出现异常
None
assert_array_almost_equal_nulp比较浮点数
机器精度(machine epsilon)是指浮点运算中的相对舍入误差上界。即机器允许在机器精度的范围下的误差。
#使用finfo函数确定机器精度
eps = np.finfo(float).eps
print(eps)
a = 1.0
b = a + eps #加上机器精度
c = a + 2*eps #加上2个机器精度 超出范围会出现异常
d = a + 1.4*eps
print(np.testing.assert_array_almost_equal_nulp(a,b))
# print(np.testing.assert_array_almost_equal_nulp(a,c))
# AssertionError: X and Y are not equal to 1 ULP (max is 2)
print(np.testing.assert_array_almost_equal_nulp(a,d))
2.220446049250313e-16
None
None
报错信息中的ULP(uint of Least Precision),指的是浮点数的最小精确度数。根据IEEE 754标准,四则运b算的标准必须保持在半个ULP内。在上面的c中,超过了的1个eps,实际测试中,不超过1.4倍的eps也是不会出现异常的。
assert_array_max_ulp多ULP浮点数比较
该函数可以通过maxulp参数设置多个ULP(默认为1)来增大浮点数比较的允许误差。如果两个浮点数的差距大于所设置或者默认的ULP,则函数assert_array_max_ulp会出现异常,若在误差范围内,则函数会返回两者所差的ULP个数(按第一个小数位四舍五入)。
a = 1.0
b = a + 2*eps
c = a + 1.499*eps
# print(np.testing.assert_array_max_ulp(a,b))
# AssertionError: Arrays are not almost equal up to 1 ULP
print(np.testing.assert_array_max_ulp(a,b,maxulp=2))
print(np.testing.assert_array_max_ulp(a,a))
print(np.testing.assert_array_max_ulp(a,c,maxulp=2))
2.0
0.0
1.0
单元测试
# python中的单元测试
#编写阶层函数
def function(n):
if n==0:
return 1
elif n<0:
raise ValueError("输入的值不合法")
else:
array = np.arange(1,n+1)
return np.cumprod(array)[-1]
print(function(0))
print(function(9))
# function(-1)
# ValueError: 输入的值不合法
1
362880
#利用unittest模块进行单元测试
import unittest
import numpy as np
def function(n):
if n==0:
return 1
elif n<0:
raise ValueError("输入的值不合法")
else:
array = np.arange(1,n+1)
return np.cumprod(array)[-1]
class FactoyiaTest(unittest.TestCase):
"""继承unittest.TestCase类"""
def test_factorial(self):
#计算3的阶层
self.assertEqual(6,function(3))
def test_zero(self):
#计算0的阶层
self.assertEqual(1,function(0))
def test_negative(self):
self.assertRaises(IndexError,function(-1))
if __name__ =='__main__':
unittest.main()
# .E.
# ======================================================================
# ERROR: test_negative (__main__.FactoyiaTest)
# ----------------------------------------------------------------------
# Traceback (most recent call last):
# File "C:/Users/zhj/Desktop/untitled3.py", line 32, in test_negative
# self.assertRaises(IndexError,function(-1))
# File "C:/Users/zhj/Desktop/untitled3.py", line 16, in function
# raise ValueError("输入的值不合法")
# ValueError: 输入的值不合法
# ----------------------------------------------------------------------
# Ran 3 tests in 0.010s
# FAILED (errors=1)
将上面的程序放在单独的.py文件,运行时提示错误提示,也就在使用function时出现了非法输入。
标签:断言,testing,equal,assert,print,np,array,numpy,函数 From: https://www.cnblogs.com/chentiao/p/18070283