相关:
https://jax.readthedocs.io/en/latest/pallas/tpu.html
向量计算设备,如:GPU、TPU等,都是通过向量计算来进行加速的,因此在这类设备中进行向量计算的计算单元是成百上千的,但是进行结构控制的电路单元比较少,可以基本认为在向量设备中进行流程控制是标量的,而不是向量的,也就是说在一个计算周期里面并不能对多个计算单元进行流程控制,而只能是对单个或几个(若干个)计算单元进行流程控制,因为设备中并不能为每一个计算单元配备一个流程控制单元,因为流程控制在整个计算过程的占比太小。但是,在设备的kernel函数中我们往往是无可避免的使用到流程控制语句,因此在编程时就会采用几种方法:
-
直接使用标量性质的流程控制语句,在流程控制语句执行时极大的降低设备的运算效率;
-
使用循环展开方式,也就是把循环结构改写成顺序执行结构,这也是jax框架中所使用的方法;
-
使用向量规约方法来实现原有的循环和判断流程,如jax的cond、fori_loop、for_loop操作。但是该种情况并不通用,很多标量的循环和判断结构无法转换为向量归约,但无法进行向量规约时该种方法不可用。
在jax框架中,进行循环、判断展开操作是自动的,是由编译器自动进行的,不同于CUDA这种底层原语操作,不需要手动编写展开代码,但是也正因为是编译器自动展开的,因此我们需要避免把多个计算操作写在训练结构下,而是应该把多个循环结构下的计算操作写在一起并做成跳转调用操作。
比如:
for i in range(3):
operation_1(i)
operation_2(i)
operation_3(i)
展开后为:
operation_1(0)
operation_2(0)
operation_3(0)
operation_1(1)
operation_2(1)
operation_3(1)
operation_1(2)
operation_2(2)
operation_3(2)
如果我们把展开编译前的代码写为:
def fun(i):
operation_1(i)
operation_2(i)
operation_3(i)
for i in range(3):
fun(i)
那么展开后代码为:
def fun(i):
operation_1(i)
operation_2(i)
operation_3(i)
fun(0)
fun(1)
fun(2)
资料:
标签:Google,jax,流程,循环,计算,fun,operation,向量 From: https://www.cnblogs.com/devilmaycry812839668/p/17973020