问题描述:
WARNING - No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
解决方案:
- 测试如下代码,如果为True 则执行第二步。
import torch print(torch.cuda.is_available())
- 测试如下代码
import jax print(jax.devices())
- 安装指定版本的cudatoolkit 和 cudnn, 其中 cudatoolkit的版本使用nvidia-smi确定,cudnn的版本使用conda search cudnn确定
conda install cudatoolkit=12.0 cudatoolkit=8.2
- 通过以下命令安装GPU版本的jaxlib
pip install -U jaxlib==0.4.16+cuda12.cudnn88 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- 卸载并重新安装scvi-tools
pip uninstall scvi-tools pip install scvi-tools