From cd6e9d6f62d2d60056038dc7e035c8f1aa07637e Mon Sep 17 00:00:00 2001 From: Vincent Roseberry Date: Fri, 9 Apr 2021 20:55:52 +0000 Subject: [PATCH] Add JAX to the CPU/TPU image. Fixes #918 Before, we were installing JAX only on the GPU image. Included tests to prevent regression. http://b/177334844 --- Dockerfile | 2 ++ gpu.Dockerfile | 2 +- tests/test_jax.py | 12 ++++-------- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/Dockerfile b/Dockerfile index f7c15b99..db5b7d03 100644 --- a/Dockerfile +++ b/Dockerfile @@ -420,6 +420,8 @@ RUN pip install flashtext && \ # pycrypto is used by competitions team. pip install pycrypto && \ pip install easyocr && \ + # Keep JAX version in sync with GPU image. + pip install jax==0.2.12 jaxlib==0.1.64 && \ /tmp/clean-layer.sh # Download base easyocr models. diff --git a/gpu.Dockerfile b/gpu.Dockerfile index 01b448ce..b40d0807 100644 --- a/gpu.Dockerfile +++ b/gpu.Dockerfile @@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \ echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \ /tmp/clean-layer.sh -# Install JAX +# Install JAX (Keep JAX version in sync with CPU image) RUN pip install jax==0.2.12 jaxlib==0.1.64+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \ /tmp/clean-layer.sh diff --git a/tests/test_jax.py b/tests/test_jax.py index 0a3c24ae..1847ceec 100644 --- a/tests/test_jax.py +++ b/tests/test_jax.py @@ -1,22 +1,18 @@ import unittest - import time +import jax.numpy as np + from common import gpu_test +from jax import grad, jit class TestJAX(unittest.TestCase): def tanh(self, x): - import jax.numpy as np y = np.exp(-2.0 * x) return (1.0 - y) / (1.0 + y) - @gpu_test - def test_JAX(self): - # importing inside the gpu-only test because these packages can't be - # imported on the CPU image since they are not present there. - from jax import grad, jit - + def test_grad(self): grad_tanh = grad(self.tanh) ag = grad_tanh(1.0) self.assertEqual(0.4199743, ag)