From 80ad73f3ed09b716800435dcae5387b85c1d137b Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Wed, 1 Sep 2021 20:14:07 +0000 Subject: [PATCH] Upgrade JAX - Use new pip install syntax. - The proper version of jaxlib is installed based on the version of jax. --- Dockerfile | 2 +- gpu.Dockerfile | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 188a4728..010b48d9 100644 --- a/Dockerfile +++ b/Dockerfile @@ -400,7 +400,7 @@ RUN pip install flashtext && \ pip install pycrypto && \ pip install easyocr && \ # Keep JAX version in sync with GPU image. - pip install jax==0.2.16 jaxlib==0.1.68 && \ + pip install jax[cpu]==0.2.19 && \ # ipympl adds interactive widget support for matplotlib pip install ipympl==0.7.0 && \ pip install pandarallel && \ diff --git a/gpu.Dockerfile b/gpu.Dockerfile index 4db3b9b5..8f6903e1 100644 --- a/gpu.Dockerfile +++ b/gpu.Dockerfile @@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \ /tmp/clean-layer.sh # Install JAX (Keep JAX version in sync with CPU image) -RUN pip install jax==0.2.16 jaxlib==0.1.68+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \ +RUN pip install jax[cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION]==0.2.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html && \ /tmp/clean-layer.sh # Reinstall packages with a separate version for GPU support.