File tree Expand file tree Collapse file tree 3 files changed +7
-9
lines changed Expand file tree Collapse file tree 3 files changed +7
-9
lines changed Original file line number Diff line number Diff line change @@ -420,6 +420,8 @@ RUN pip install flashtext && \
420420 # pycrypto is used by competitions team.
421421 pip install pycrypto && \
422422 pip install easyocr && \
423+ # Keep JAX version in sync with GPU image.
424+ pip install jax==0.2.12 jaxlib==0.1.64 && \
423425 /tmp/clean-layer.sh
424426
425427# Download base easyocr models.
Original file line number Diff line number Diff line change @@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \
7777 echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
7878 /tmp/clean-layer.sh
7979
80- # Install JAX
80+ # Install JAX (Keep JAX version in sync with CPU image)
8181RUN pip install jax==0.2.12 jaxlib==0.1.64+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
8282 /tmp/clean-layer.sh
8383
Original file line number Diff line number Diff line change 11import unittest
2-
32import time
43
4+ import jax .numpy as np
5+
56from common import gpu_test
7+ from jax import grad , jit
68
79
810class TestJAX (unittest .TestCase ):
911 def tanh (self , x ):
10- import jax .numpy as np
1112 y = np .exp (- 2.0 * x )
1213 return (1.0 - y ) / (1.0 + y )
1314
14- @gpu_test
15- def test_JAX (self ):
16- # importing inside the gpu-only test because these packages can't be
17- # imported on the CPU image since they are not present there.
18- from jax import grad , jit
19-
15+ def test_grad (self ):
2016 grad_tanh = grad (self .tanh )
2117 ag = grad_tanh (1.0 )
2218 self .assertEqual (0.4199743 , ag )
You can’t perform that action at this time.
0 commit comments