# Build stages. BASE_TAG selects the tag of the final base image.
# The `nvidia` stage is the source for `COPY --from=nvidia` later in the file.
# Image references must be a single token: a space inside the tag
# (e.g. "cuda:10.1 -cudnn7-...") makes the FROM line invalid.
11ARG BASE_TAG=staging
22
3- FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 AS nvidia
4- FROM gcr.io/kaggle-images/python-tensorflow-whl:2.1.0-py36 as tensorflow_whl
3+ FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu16.04 AS nvidia
4+ FROM gcr.io/kaggle-images/python-tensorflow-whl:2.1.0-py36-2 AS tensorflow_whl
55FROM gcr.io/kaggle-images/python:${BASE_TAG}
66
# COPY is sufficient for a plain local file; ADD's extra behaviours are not needed here.
77COPY clean-layer.sh /tmp/clean-layer.sh
@@ -13,8 +13,11 @@ COPY --from=nvidia /etc/apt/trusted.gpg /etc/apt/trusted.gpg.d/cuda.gpg
1313
1414# Ensure the cuda libraries are compatible with the custom Tensorflow wheels.
1515# TODO(b/120050292): Use templating to keep in sync or COPY installed binaries from it.
# The "+" lines split the CUDA version into major/minor/patch components and
# derive the other values from them:
#   CUDA_VERSION     -> 10.1.243
#   CUDA_PKG_VERSION -> 10-1=10.1.243-1 (appended to the cuda-* apt package names below)
# Each derived value is set in its own ENV instruction, after the ones it references.
16- ENV CUDA_VERSION=10.0.130
17- ENV CUDA_PKG_VERSION=10-0=$CUDA_VERSION-1
16+ ENV CUDA_MAJOR_VERSION=10
17+ ENV CUDA_MINOR_VERSION=1
18+ ENV CUDA_PATCH_VERSION=243
19+ ENV CUDA_VERSION=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION.$CUDA_PATCH_VERSION
20+ ENV CUDA_PKG_VERSION=$CUDA_MAJOR_VERSION-$CUDA_MINOR_VERSION=$CUDA_VERSION-1
1821LABEL com.nvidia.volumes.needed="nvidia_driver"
1922LABEL com.nvidia.cuda.version="${CUDA_VERSION}"
# Put the NVIDIA driver and CUDA binary directories ahead of the existing PATH.
2023ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
@@ -26,7 +29,7 @@ ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:${PATH}
# Library search path and NVIDIA_* settings for the GPU image.
# NVIDIA_REQUIRE_CUDA is built from the CUDA major/minor components so it
# follows the toolkit version (cuda>=10.1 with MAJOR=10, MINOR=1). The value
# carries no trailing whitespace inside the quotes.
2629ENV LD_LIBRARY_PATH="/usr/local/nvidia/lib64:/usr/local/cuda/lib64:/usr/local/cuda/lib64/stubs"
2730ENV NVIDIA_VISIBLE_DEVICES=all
2831ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
29- ENV NVIDIA_REQUIRE_CUDA="cuda>=10.0"
32+ ENV NVIDIA_REQUIRE_CUDA="cuda>=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION"
# Install the cuda-* packages pinned through CUDA_PKG_VERSION, plus cuDNN and
# NCCL builds whose version suffix matches the CUDA major.minor, then link
# /usr/local/cuda to the versioned install directory and add a libcuda.so.1
# alias for the stub library. /tmp/clean-layer.sh runs last in the same layer.
# Each apt pin must be one token (e.g. libcudnn7=7.6.5.32-1+cuda10.1); a space
# before "-1+cuda..." would be passed to apt-get as a separate argument.
3033RUN apt-get update && apt-get install -y --no-install-recommends \
3134 cuda-cupti-$CUDA_PKG_VERSION \
3235 cuda-cudart-$CUDA_PKG_VERSION \
@@ -36,11 +39,11 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
3639 cuda-nvml-dev-$CUDA_PKG_VERSION \
3740 cuda-minimal-build-$CUDA_PKG_VERSION \
3841 cuda-command-line-tools-$CUDA_PKG_VERSION \
39- libcudnn7=7.5.0.56-1+cuda10.0 \
40- libcudnn7-dev=7.5.0.56-1+cuda10.0 \
41- libnccl2=2.4.2-1+cuda10.0 \
42- libnccl-dev=2.4.2-1+cuda10.0 && \
43- ln -s /usr/local/cuda-10.0 /usr/local/cuda && \
42+ libcudnn7=7.6.5.32-1+cuda$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION \
43+ libcudnn7-dev=7.6.5.32-1+cuda$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION \
44+ libnccl2=2.5.6-1+cuda$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION \
45+ libnccl-dev=2.5.6-1+cuda$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION && \
46+ ln -s /usr/local/cuda-$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION /usr/local/cuda && \
4447 ln -s /usr/local/cuda/lib64/stubs/libcuda.so /usr/local/cuda/lib64/stubs/libcuda.so.1 && \
4548 /tmp/clean-layer.sh
4649
@@ -67,7 +70,7 @@ RUN pip uninstall -y lightgbm && \
6770
6871# Install JAX
# JAX_* settings: Python tag, CUDA variant, platform and the release bucket URL.
# With CUDA_MAJOR_VERSION=10 and CUDA_MINOR_VERSION=1 the "+" line yields
# JAX_CUDA_VERSION=cuda101 (the two components are concatenated with no separator).
# NOTE(review): presumably combined into a wheel URL by the install step that
# follows; that step is not visible in this excerpt, so confirm there.
6972ENV JAX_PYTHON_VERSION=cp36
70- ENV JAX_CUDA_VERSION=cuda100
73+ ENV JAX_CUDA_VERSION=cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION
7174ENV JAX_PLATFORM=linux_x86_64
7275ENV JAX_BASE_URL="https://storage.googleapis.com/jax-releases"
7376
@@ -80,15 +83,15 @@ RUN pip uninstall -y tensorflow && \
8083 pip install /tmp/tensorflow_gpu/tensorflow*.whl && \
8184 rm -rf /tmp/tensorflow_gpu && \
8285 conda remove --force -y pytorch torchvision torchaudio cpuonly && \
83- conda install -y pytorch torchvision torchaudio cudatoolkit=10.0 -c pytorch && \
# cudatoolkit needs the expanded CUDA version (10.1 here). Without the "$"
# prefixes the shell passes the literal names CUDA_MAJOR_VERSION.CUDA_MINOR_VERSION
# to conda instead of the version number.
86+ conda install -y pytorch torchvision torchaudio cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MINOR_VERSION -c pytorch && \
8487 pip uninstall -y mxnet && \
8588 # b/126259508 --no-deps prevents numpy from being downgraded.
86- pip install --no-deps mxnet-cu100 && \
89+ pip install --no-deps mxnet-cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION && \
8790 /tmp/clean-layer.sh
8891
8992# Install GPU-only packages
# Installs pycuda, the CuPy build named after the CUDA version, and pynvrtc,
# then runs /tmp/clean-layer.sh in the same RUN (same layer).
# With CUDA_MAJOR_VERSION=10 and CUDA_MINOR_VERSION=1 the "+" line resolves
# to `pip install cupy-cuda101`.
9093RUN pip install pycuda && \
91- pip install cupy-cuda100 && \
94+ pip install cupy-cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION && \
9295 pip install pynvrtc && \
9396 /tmp/clean-layer.sh
9497
0 commit comments