diff --git a/Dockerfile.tmpl b/Dockerfile.tmpl index cc94af52..0210890a 100644 --- a/Dockerfile.tmpl +++ b/Dockerfile.tmpl @@ -123,7 +123,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \ # Install JAX {{ if eq .Accelerator "gpu" }} -RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases.html && \ +RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \ /tmp/clean-layer.sh {{ else }} RUN pip install jax[cpu] && \ diff --git a/tests/test_jax.py b/tests/test_jax.py index da95c537..b5e0898e 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -3,6 +3,7 @@ import os import time +import jax import jax.numpy as np from common import gpu_test @@ -21,4 +22,4 @@ def test_grad(self): def test_backend(self): expected_backend = 'cpu' if len(os.environ.get('CUDA_VERSION', '')) == 0 else 'gpu' - + self.assertEqual(expected_backend, jax.default_backend())