Related:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html
Code:
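import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
    # Tile x 10 times along the last axis: (1000, 1000) -> (1000, 10000)
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    # A second, independent tile of x
    return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
# JAX dispatches work asynchronously; wait for the result before taking the snapshot
z.block_until_ready()
# Write a pprof-format snapshot of the arrays currently allocated on the device
jax.profiler.save_device_memory_profile("memory.prof")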
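As a rough sanity check of what the profile should contain, the sketch below sums the sizes of the arrays still alive after func2 returns. It assumes a JAX version that provides jax.live_arrays() and the default float32 dtype (x64 disabled); jax.live_arrays and the live_bytes variable are additions for illustration and are not used in the original post.

```python
import jax
import jax.numpy as jnp

def func1(x):
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()

# Shapes: tiling the last axis of a (1000, 1000) array 10 times gives (1000, 10000).
print(x.shape, y.shape, z.shape)  # (1000, 1000) (1000, 10000) (1000, 10000)

# Assumes float32 (jax_enable_x64 off): 4 bytes per element, so 4 MB + 40 MB + 40 MB.
expected_bytes = x.nbytes + y.nbytes + z.nbytes
print(expected_bytes)  # 84000000

# jax.live_arrays() lists every jax.Array the runtime still holds (x, y, z, the PRNG key, ...).
live_bytes = sum(a.nbytes for a in jax.live_arrays())
print(live_bytes)
```

Because y and z are both returned from func2, their 40 MB buffers should dominate the snapshot, while the temporary jnp.tile result inside func1 is typically released as soon as the multiplication has produced y.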
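Illustrations of the device memory profile:
The jax.random.normal operation, which goes through jit compilation:
The jnp.tile operation, which does not go through jit compilation: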
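The captions above contrast a jit-compiled operation with eager ones. To see the same contrast for the post's own functions, one option is to wrap func2 in jax.jit and save a second snapshot. This is only a sketch: the filename memory_jit.prof and the variable names are hypothetical, and exactly which buffers XLA keeps (for example, whether it computes the shared jnp.tile(x, 10) only once) depends on the compiler, so compare the two profiles in pprof rather than assuming a particular result.

```python
import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
    return jnp.tile(x, 10) * 0.5

def func2(x):
    y = func1(x)
    return y, jnp.tile(x, 10) + 1

# Trace and compile func2 into a single XLA program.
func2_jit = jax.jit(func2)

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))

# Eager reference results, used only to check the jitted version.
y_eager, z_eager = func2(x)
y_jit, z_jit = func2_jit(x)
z_jit.block_until_ready()

matches = bool(jnp.allclose(y_eager, y_jit)) and bool(jnp.allclose(z_eager, z_jit))
print("jit matches eager:", matches)

# Drop the eager outputs so the snapshot holds only x, the key, and the jitted outputs.
del y_eager, z_eager

# Hypothetical filename for the second snapshot.
jax.profiler.save_device_memory_profile("memory_jit.prof")
```

Deleting the eager outputs before saving keeps the two profiles comparable: each one should then contain a single pair of (1000, 10000) result buffers, attributed to different call stacks.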
Tags: device memory, jnp, Jax, jax, compilation, jit, memory
From: https://www.cnblogs.com/devilmaycry812839668/p/17974752