使用Pallas为jax编写kernel扩展,需要使用JAX-Triton扩展包。由于Google的深度学习框架Jax主要是面向自己的TPU进行开发的,虽然也同时支持NVIDIA的GPU,但是支持力度有限,目前JAX-Triton只能在TPU设备上正常运行,无法保证在GPU上正常运行。
该结果使用kaggle上的TPU和GPU进行测试获得。
测试时间:
2024-01-18 21:12:09 星期四
标签:Triton,kernel,Pallas,Jax,JAX,TPU,GPU From: https://www.cnblogs.com/devilmaycry812839668/p/17973422