From a54fc9610e582ebde8d3dd3164db6ce85f9fa8ab Mon Sep 17 00:00:00 2001 From: Dustin Herbison Date: Thu, 26 May 2022 14:16:45 +0000 Subject: [PATCH 1/2] Load correct libtpu for pytorch & jax. It turns out jax & pytorch are incompatible (require different libtpu versions). In order to support importing EITHER of them (but not both) we will swap in the correct libtpu during import (by monkey-patching the import code for both). http://b/213335159 --- tpu/Dockerfile | 15 +++++++++++---- tpu/config.txt | 3 +-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tpu/Dockerfile b/tpu/Dockerfile index ff6a8a0b..99f7f49e 100644 --- a/tpu/Dockerfile +++ b/tpu/Dockerfile @@ -1,8 +1,6 @@ ARG BASE_IMAGE_TAG -ARG LIBTPU_IMAGE_TAG ARG TENSORFLOW_VERSION -FROM gcr.io/cloud-tpu-v2-images/libtpu:${LIBTPU_IMAGE_TAG} as libtpu FROM gcr.io/kaggle-images/python-tpu-tensorflow-whl:python-${BASE_IMAGE_TAG}-${TENSORFLOW_VERSION} AS tensorflow_whl FROM gcr.io/kaggle-images/python:${BASE_IMAGE_TAG} @@ -12,20 +10,29 @@ ARG TORCH_VERSION ENV ISTPUVM=1 -COPY --from=libtpu /libtpu.so /lib - COPY --from=tensorflow_whl /tmp/tensorflow_pkg/tensorflow*.whl /tmp/tensorflow_pkg/ RUN pip install /tmp/tensorflow_pkg/tensorflow*.whl && \ rm -rf /tmp/tensorflow_pkg && \ /tmp/clean-layer.sh +# LIBTPU installed here: +ENV DEFAULT_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/libtpu.so +ENV PYTORCH_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/torch-libtpu.so +ENV JAX_LIBTPU=/opt/conda/lib/python3.7/site-packages/libtpu/jax-libtpu.so + # https://cloud.google.com/tpu/docs/pytorch-xla-ug-tpu-vm#changing_pytorch_version RUN pip uninstall -y torch && \ pip install torch==${TORCH_VERSION} && \ # The URL doesn't include patch version. i.e. must use 1.11 instead of 1.11.0 pip install torch_xla[tpuvm] -f https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-${TORCH_VERSION%.*}-cp37-cp37m-linux_x86_64.whl && \ + cp $DEFAULT_LIBTPU $PYTORCH_LIBTPU && \ /tmp/clean-layer.sh # https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm#install_jax_on_your_cloud_tpu_vm RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && \ + cp $DEFAULT_LIBTPU $JAX_LIBTPU && \ /tmp/clean-layer.sh + +# Monkey-patch JAX & PYTORCH to load the correct libtpu.so when they are imported: +RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \ + sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py diff --git a/tpu/config.txt b/tpu/config.txt index 8ab81c9c..569f305c 100644 --- a/tpu/config.txt +++ b/tpu/config.txt @@ -1,5 +1,4 @@ # TODO(b/213335159): Use ci-pretest for BASE_IMAGE_TAG once stable. -BASE_IMAGE_TAG=v108 -LIBTPU_IMAGE_TAG=libtpu_1.1.0_RC00 +BASE_IMAGE_TAG=v115 TENSORFLOW_VERSION=2.8.0 TORCH_VERSION=1.11.0 \ No newline at end of file From 05c135e90ab78e7d993374450b029509034999d6 Mon Sep 17 00:00:00 2001 From: Dustin Herbison Date: Thu, 26 May 2022 15:07:00 +0000 Subject: [PATCH 2/2] monkeypatch tf and add env vars to suppress warns --- tpu/Dockerfile | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tpu/Dockerfile b/tpu/Dockerfile index 99f7f49e..797288d6 100644 --- a/tpu/Dockerfile +++ b/tpu/Dockerfile @@ -33,6 +33,11 @@ RUN pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-release cp $DEFAULT_LIBTPU $JAX_LIBTPU && \ /tmp/clean-layer.sh -# Monkey-patch JAX & PYTORCH to load the correct libtpu.so when they are imported: +# Monkey-patch TF, JAX & PYTORCH to load the correct libtpu.so when they are imported: RUN sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${PYTORCH_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/torch_xla/__init__.py && \ - sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py + sed -i "s|^\(\(.*\)libtpu.configure_library_path.*\)|\1\n\2os.environ['TPU_LIBRARY_PATH'] = '${JAX_LIBTPU}'|" /opt/conda/lib/python3.7/site-packages/jax/_src/cloud_tpu_init.py && \ + sed -i "1s/^/from jax._src.cloud_tpu_init import cloud_tpu_init\ncloud_tpu_init()\n/" /opt/conda/lib/python3.7/site-packages/tensorflow/__init__.py + +# Set these env vars so that they don't produce errs calling the metadata server to load them: +ENV TPU_ACCELERATOR_TYPE=v3-8 +ENV TPU_PROCESS_ADDRESSES=local \ No newline at end of file