@@ -7,10 +7,12 @@ ARG TORCH_VERSION
77ARG TORCHAUDIO_VERSION
88ARG TORCHTEXT_VERSION
99ARG TORCHVISION_VERSION
10+ ARG JAX_VERSION
1011
1112{{ if eq .Accelerator "gpu" }}
1213FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
1314FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
15+ FROM gcr.io/kaggle-images/python-jaxlib-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${JAX_VERSION} AS jaxlib_whl
1416FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
1517{{ else }}
1618FROM ${BASE_IMAGE_REPO}/${CPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
@@ -36,9 +38,9 @@ RUN ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/lib
3638{{ end }}
3739
3840# Keep these variables in sync if base image is updated.
39- ENV TENSORFLOW_VERSION=2.13 .0
41+ ENV TENSORFLOW_VERSION=2.15 .0
4042# See https://github.com/tensorflow/io#tensorflow-version-compatibility
41- ENV TENSORFLOW_IO_VERSION=0.34 .0
43+ ENV TENSORFLOW_IO_VERSION=0.35 .0
4244
4345# We need to redefine the ARG here to get the ARG value defined above the FROM instruction.
4446# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
@@ -47,6 +49,7 @@ ARG TORCH_VERSION
4749ARG TORCHAUDIO_VERSION
4850ARG TORCHTEXT_VERSION
4951ARG TORCHVISION_VERSION
52+ ARG JAX_VERSION
5053
5154# Disable pesky logs like: KMP_AFFINITY: pid 6121 tid 6121 thread 0 bound to OS proc set 0
5255# See: https://stackoverflow.com/questions/57385766/disable-tensorflow-log-information
@@ -158,7 +161,9 @@ RUN pip install lightgbm==$LIGHTGBM_VERSION && \
158161
159162# Install JAX
160163{{ if eq .Accelerator "gpu" }}
161- RUN pip install "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
164+ COPY --from=jaxlib_whl /tmp/whl/*.whl /tmp/jax/
165+ # b/319722433#comment9: Use pip wheels once versions matches our CUDA version.
166+ RUN pip install /tmp/jax/*.whl jax==$JAX_VERSION && \
162167 /tmp/clean-layer.sh
163168{{ else }}
164169RUN pip install jax[cpu] && \
@@ -169,7 +174,7 @@ RUN pip install jax[cpu] && \
169174# Install GPU specific packages
170175{{ if eq .Accelerator "gpu" }}
171176# Install GPU-only packages
172- # No specific package for nnabla-ext-cuda 11 .x minor versions.
177+ # No specific package for nnabla-ext-cuda 12 .x minor versions.
173178RUN export PATH=/usr/local/cuda/bin:$PATH && \
174179 export CUDA_ROOT=/usr/local/cuda && \
175180 pip install pycuda \
@@ -199,10 +204,17 @@ RUN pip install -f http://h2o-release.s3.amazonaws.com/h2o/latest_stable_Py.html
199204
200205RUN pip install \
201206 "tensorflow==${TENSORFLOW_VERSION}" \
202- "tensorflow-io==${TENSORFLOW_IO_VERSION}"\
207+ "tensorflow-io==${TENSORFLOW_IO_VERSION}" \
203208 tensorflow-addons \
204209 tensorflow_decision_forests \
205- tensorflow_text && \
210+ tensorflow_text \
211+ tensorflowjs \
212+ tensorflow_hub && \
213+ /tmp/clean-layer.sh
214+
215+ # TODO(b/318672158): Upgrade to Keras 3 once compatible with other TF libries.
216+ # See blockers here: https://b.corp.google.com/issues/319722433#comment8
217+ RUN pip install keras keras-cv keras-nlp && \
206218 /tmp/clean-layer.sh
207219
208220RUN pip install pysal
@@ -268,12 +280,6 @@ RUN pip install scipy \
268280 apt-get install -y pandoc && \
269281 pip install essentia
270282
271- {{ if eq .Accelerator "gpu" }}
272- # #1281 Install numba MVC support:
273- RUN pip install ptxcompiler-cu11 cubinlinker-cu11 --extra-index-url=https://pypi.nvidia.com
274- ENV NUMBA_CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY=1
275- {{ end }}
276-
277283RUN apt-get install -y git-lfs && \
278284 /tmp/clean-layer.sh
279285
@@ -316,8 +322,7 @@ RUN pip install mpld3 \
316322 s2sphere \
317323 bayesian-optimization \
318324 matplotlib-venn \
319- # b/184083722 pyldavis >= 3.3 requires numpy >= 1.20.0 but TensorFlow 2.4.1 / 2.5.0 requires 1.19.2
320- pyldavis==3.2.2 \
325+ pyldavis \
321326 mlxtend \
322327 altair \
323328 ImageHash \
@@ -527,8 +532,6 @@ RUN pip install flashtext \
527532 gym \
528533 pyarabic \
529534 pandasql \
530- tensorflow_hub \
531- tensorflowjs \
532535 jieba \
533536 # ggplot is broken and main repo does not merge and release https://github.com/yhat/ggpy/pull/668
534537 https://github.com/hbasria/ggpy/archive/0.11.5.zip \
@@ -543,13 +546,7 @@ RUN pip install flashtext \
543546 # b/290207097 switch back to the pip catalyst package when bug fixed
544547 # https://github.com/catalyst-team/catalyst/issues/1440
545548 git+https://github.com/Philmod/catalyst.git@fix-fp16#egg=catalyst \
546- # b/206990323 osmx 1.1.2 requires numpy >= 1.21 which we don't want.
547- osmnx==1.1.1 \
548- # Remove once `keras-core` is released as Keras
549- keras-core \
550- # TODO(b/315833744) unpin when the alpha versions are merged to the main version.
551- keras-cv \
552- keras-nlp && \
549+ osmnx && \
553550 apt-get -y install libspatialindex-dev
554551
555552RUN pip install pytorch-ignite \
0 commit comments