问题背景
在不同的框架中对于Device ID的配置方法都略有不同,这里提两种Jax中配置Device ID的方法。
配置环境变量
这个方法是比较流行的,直接在环境变量里面配置:
export CUDA_VISIBLE_DEVICES=1
这样就使得当前shell下运行的程序只能识别到1
号显卡,一般就是第二张显卡了。如果需要配置多张显卡,类似的可以指定:
export CUDA_VISIBLE_DEVICES=0,1
当然,如果是在Python程序中运行的话,也可以直接在Python脚本中配置环境变量:
import os
os.environ['CUDA_VISIBLE_DEVICES']='1'
当然,该语句最好在Jax初始化之前执行。
Jit-device参数配置
这种配置方法会更加具体一点,可以直接指定某个即时编译的函数所使用的device id,如下是一个使用的案例:
import os
# 禁用显存预分配
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'
import time
import numpy as np
np.random.seed(0)
import jax
from jax import numpy as jnp
# 创建CPU上的张量
N = 5000000
crd = np.random.random((N, 3))
# 生成显卡对应的对象
gpus = jax.devices()
# 分配对象到不同的显卡上
crd0 = jax.jit(jnp.array, device=gpus[0])(crd[:3000000])
crd1 = jax.jit(jnp.array, device=gpus[1])(crd[3000000:])
time.sleep(5)
在这个案例中,我们在CPU上初始化一个crd
张量,然后通过jax.numpy.array
函数将该张量拷贝到指定的GPU环境中。其中向显卡第一张显卡拷贝了3000000组数据,向第二张显卡拷贝了2000000组数据,这样的话如果在nvidia-smi
中就可以看到两个不同的显存占用了。
总结概要
本文主要介绍了2个在Jax框架中配置显卡Device ID的方法。第一种方法可以使用环境变量进行配置,对于众多的深度学习框架都是可以兼容的。而第二种方案是在Jax即时编译的过程中通过Jax生成的Device对象来控制数据的传输和函数执行的Device ID。
版权声明
本文首发链接为:https://www.cnblogs.com/dechinphy/p/jax-device-id.html
作者ID:DechinPhy
更多原著文章:https://www.cnblogs.com/dechinphy/
请博主喝咖啡:https://www.cnblogs.com/dechinphy/gallery/image/379634.html