File tree Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Expand file tree Collapse file tree 2 files changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -123,7 +123,7 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
123123
124124# Install JAX
125125{{ if eq .Accelerator "gpu" }}
126- RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_releases .html && \
126+ RUN pip install jax[cuda11_cudnn805] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases .html && \
127127 /tmp/clean-layer.sh
128128{{ else }}
129129RUN pip install jax[cpu] && \
Original file line number Diff line number Diff line change 33import os
44import time
55
6+ import jax
67import jax .numpy as np
78
89from common import gpu_test
@@ -21,4 +22,4 @@ def test_grad(self):
2122
2223 def test_backend (self ):
2324 expected_backend = 'cpu' if len (os .environ .get ('CUDA_VERSION' , '' )) == 0 else 'gpu'
24-
25+ self . assertEqual ( expected_backend , jax . default_backend ())
You can’t perform that action at this time.
0 commit comments