一、32位浮点数
32位浮点数(Single Precision Floating Point)是一种用于表示实数的标准格式,由IEEE 754标准定义。
表示方法
32位浮点数由三部分组成:
- 符号位(S):1位,表示数值的正负。
- 指数位(E):8位,用于表示数值的范围。
- 尾数位(M):23位,表示有效数字。
其表示公式为:
( − 1 ) S × 1. M × 2 ( E − 127 ) ( − 1 ) S × 1. M × 2 ( E − 127 ) ( − 1 ) S × 1. M × 2 ( E − 127 ) (−1)S×1.M×2(E−127)(-1)^S \times 1.M \times 2^{(E-127)}(−1)S×1.M×2(E−127) (−1)S×1.M×2(E−127)(−1)S×1.M×2(E−127)(−1)S×1.M×2(E−127)
- 符号位 S 决定数的正负,0表示正数,1表示负数。
- 指数位 E 采用偏移量为127的表示方法,即实际指数为 E−127。
- 尾数位 M 代表小数部分,实际有效数字为 1.M。
优缺点
优点:
- 范围广:可以表示非常大的数和非常小的数。
- 精度高:对大多数应用场景下的计算精度需求都能满足。
缺点:
- 计算复杂:浮点运算相对耗时,硬件实现复杂。
- 存储空间大:占用32位存储空间。
应用场景
32位浮点数广泛用于科学计算、图形处理、机器学习等需要高精度和大范围数值表示的领域。
二、8位定点数
8位定点数(Fixed Point)是一种用于表示小范围数值的表示方法,适用于嵌入式系统和资源受限的环境。
表示方法
8位定点数的表示方法有多种,常见的是 Q7 格式,即:
- 符号位(S):1位,表示数值的正负。
- 整数位:0位。
- 小数位:7位,表示小数部分。
其表示公式为:
( − 1 ) S × ( M 27 ) ( − 1 ) S × ( M 2 7 ) ( − 1 ) S × ( 27 M ) (−1)S×(M27)(-1)^S \times \left(\frac{M}{2^7}\right)(−1)S×(27M) (−1)S×(M27)(−1)S×(27M)(−1)S×(27M)
- 符号位 S 决定数的正负,0表示正数,1表示负数。
- 尾数位 M 直接表示小数部分。
优缺点
优点:
- 计算简单:定点数运算简单,硬件实现高效。
- 存储空间小:仅占用8位存储空间,节省内存。
缺点:
- 范围有限:只能表示较小范围的数值。
- 精度有限:小数位越多,能表示的范围越小。
应用场景
8位定点数常用于嵌入式系统、DSP(数字信号处理)和物联网设备中,这些场景对计算资源和存储空间要求严格,且数值范围和精度需求较低。
三、比较与选择
精度与范围
- 32位浮点数:适用于需要高精度和大范围数值的应用场景,如科学计算和机器学习。
- 8位定点数:适用于资源受限且数值范围和精度要求较低的场景,如嵌入式系统和简单的信号处理。
计算复杂性
- 32位浮点数:计算复杂,硬件实现成本高。
- 8位定点数:计算简单,硬件实现成本低。
存储需求
- 32位浮点数:占用32位存储空间。
- 8位定点数:占用8位存储空间,更节省内存。
实际应用
- 32位浮点数:广泛用于需要高精度和广泛范围的领域,如科学计算、图形处理、机器学习等。
- 8位定点数:广泛用于嵌入式系统、物联网设备和DSP等资源受限的领域。
四、SML实践
接下来实践一个代码,来展示如何使用 SecretFlow 库和 SPU(Secure Processing Unit)设备来执行隐私保护的计算任务。代码涵盖了网络配置、数学运算的模拟、数据加载和处理、以及网络模拟操作等多个方面。模拟网络条件变化前后对计算任务性能的影响
import secretflow as sf
import spu
import os
import numpy as np
import jax.numpy as jnp
import jax
import jax.lax
import spu.utils.simulation as spsim
import spu.spu_pb2 as spu_pb2
from functools import partial
network_conf ={
"parties":{
"alice":{
"address":"alice:8000",
},
"bob":{
"address":"bob:8000",
},
},
}
party = os.getenv("SELF_PARTY","alice")
sf.shutdown()
sf.init(
address="127.0.0.1:6379",
cluster_config={**network_conf,"self_party": party},
log_to_driver=True,
)
!yum install -y iproute-tc
#we know that dk is wrong when |x| is very small
# Let us try it.(we only show part here.)
# define some test function and data used in simulation
#def test_square_and_sum_when_x_small(x):
# return jnp.sum(jnp.square(x))
def compute_dk_func(x, eps=1e-6, iterations=100):
result = x
for _ in range(iterations):
result = jnp.square(result)
result = jax.lax.rsqrt(jnp.sum(result) + eps)
return result
x = np.array([1e-5]* 10)
#First,we run SPU with simulator
#Indeed,simulation can be run within single node.
# a.run with CHEETAH
sim_che = spsim.Simulator.simple(2, spu_pb2.ProtocolKind.CHEETAH, spu_pb2.FieldType. FM64)
spsim.sim_jax(sim_che, test_square_and_sum_when_x_small)(x)
#b.run with ABY3
# this time,we alse print some profile stats.
config_aby = spu.RuntimeConfig(
protocol=spu_pb2.ProtocolKind.ABY3,
field=spu.FieldType.FM64,
fxp_fraction_bits=18,
enable_hal_profile=True,
enable_pphlo_profile=True,
)
sim_aby=spsim.Simulator(3,config_aby)
print(spsim.sim_jax(sim_aby, test_square_and_sum_when_x_small)(x))
!tc qdisc del dev eth0 root
!ping -c 4 bob
!ping -c 4 alice
# Emulation should be run from source in SPU, so we use Secretflow here to do the efficiency experiments.
# You can use the similar trick for emulation directly in SPU.
def compute_dk_func(x,eps=1e-6):
return jax.lax.rsqrt(jnp.sum(jnp.square(x))+ eps)
x = np.random.rand(1_000_000)
# SPU settings
cluster_def={
'nodes':[
{'party':'alice','id':'local:0','address': 'alice'+ ':12945'},
{'party':'bob','id':'local:1','address':'bob'+':12945'},
],
'runtime_config':{
#SEMI2K support 2/3 PC,ABY3 only support 3PC, CHEETAH only support 2PC.
# pls pay attention to size of nodes above, nodes size need match to Pc setting.
'protocol':spu.spu_pb2.SEMI2K,
'field':spu.spu_pb2.FM64
},
}
alice_device = sf.PYU("alice")
bob_device = sf.PYU("bob")
spu_device =sf.SPU(cluster_def)
#first, load data to PYU
alice_data = alice_device(lambda x: x)(x)
#SPU may need some init, so we run this twice...
ret = spu_device(compute_dk_func)(alice_data)
sf.reveal(ret);
#调整网络状况,限制带宽和延迟
!tc qdisc add dev eth0 root handle 1: tbf rate 100mbit burst 128kb limit 10000
!tc qdisc add dev eth0 parent 1:1 handle 10: netem delay 10msec limit 8000
!ping -c 4 bob
!ping -c 4 alice
#调整网络状况后再执行计算任务,此次任务执行时间应该变长了
ret = spu_device(compute_dk_func)(alice_data)
sf.reveal(ret);
!tc qdisc del dev eth0 root
接下来作一个jax.numpy.digitize 的MPC-friendly的实现,原课程的案例如下:
x=np.array([0.0, 0.2, 6.4, 3.0, 1.6, 12.0])
bins =np.array([0.0, 1.0, 2.5, 4.0, 10.0])
jnp.digitize(x,bins)
config_aby = spu.RuntimeConfig(
protocol=spu_pb2.ProtocolKind.ABY3,
field=spu.FieldType.FM64,
fxp_fraction_bits=18,
enable_hal_profile=True,
enable_pphlo_profile=True,
)
sim_aby=spsim.Simulator(3,config_aby)
print(spsim.sim_jax(sim_aby, jnp.digitize)(x,bins))
# MPC-friendly jnp.digitize example
# Note: here,we only deal the case of `right=False', other cases are similar.
def my_digitize(x, bins):
# vectorize
com=x.reshape(*x.shape, -1)>= bins
#count the number ofxthat exceeds bins
return jnp.sum(com, axis=1)
print(spsim.sim_jax(sim_aby,my_digitize)(x,bins))
在原案例中的实现在每次迭代中都对整个数组进行平方和开方操作,这可能导致不必要的计算负担。现在使用广播和累计求和直接计算看下效果
def optimized_digitize(x, bins):
# 使用广播和累积求和直接计算
return jnp.sum(x[:, None] >= bins, axis=-1)
print(spsim.sim_jax(sim_aby,optimized_digitize)(x,bins))
任务时间上快了一些,通信没什么变化,还是有一定效果的
标签:jnp,jax,09,alice,SPU,SML,127,spu,sim From: https://blog.csdn.net/sunxi1900/article/details/140088102