# Build arguments declared before the first FROM. The base-image and
# wheel-version values are interpolated into the FROM lines below; the others
# are redeclared after FROM for use in later instructions. None has a default,
# so each value is expected to be supplied with --build-arg.
1+ ARG BASE_IMAGE_REPO
2+ ARG BASE_IMAGE_TAG
3+ ARG CPU_BASE_IMAGE_NAME
4+ ARG GPU_BASE_IMAGE_NAME
5+ ARG LIGHTGBM_VERSION
6+ ARG TORCH_VERSION
7+ ARG TORCHAUDIO_VERSION
8+ ARG TORCHTEXT_VERSION
9+ ARG TORCHVISION_VERSION
10+
# Selects the base image for the requested accelerator.
# GPU builds first declare two helper stages (lightgbm_whl, torch_whl) whose
# prebuilt wheels are pulled in later with COPY --from; the last FROM in each
# branch is the stage that the rest of this file builds on.
111{{ if eq .Accelerator "gpu" }}
2- FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m80
12+ FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
13+ FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
14+ FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
# CUDA version components referenced by later install steps (e.g. the
# cudatoolkit pin). NOTE(review): presumably these must match the CUDA version
# shipped in the GPU base image - confirm when the base image changes.
315ENV CUDA_MAJOR_VERSION=11
416ENV CUDA_MINOR_VERSION=0
517{{ else }}
6- FROM gcr.io/deeplearning-platform-release/tf2-cpu.2-6:m80
18+ FROM ${BASE_IMAGE_REPO}/${CPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
719{{ end }}
820# Keep these variables in sync if base image is updated.
921ENV TENSORFLOW_VERSION=2.6.0
22+
23+ # We need to redefine the ARG here to get the ARG value defined above the FROM instruction.
24+ # See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact
# These values are consumed by the pip install commands further down
# (the CPU PyTorch install and the CPU LightGBM install).
25+ ARG LIGHTGBM_VERSION
26+ ARG TORCH_VERSION
27+ ARG TORCHAUDIO_VERSION
28+ ARG TORCHTEXT_VERSION
29+ ARG TORCHVISION_VERSION
30+
1031# Disable pesky logs like: KMP_AFFINITY: pid 6121 tid 6121 thread 0 bound to OS proc set 0
1132# See: https://stackoverflow.com/questions/57385766/disable-tensorflow-log-information
1233ENV KMP_WARNINGS=0
@@ -15,6 +36,9 @@ ADD clean-layer.sh /tmp/clean-layer.sh
1536ADD patches/nbconvert-extensions.tpl /opt/kaggle/nbconvert-extensions.tpl
1637ADD patches/template_conf.json /opt/kaggle/conf.json
1738
39+ # Adds the directory containing libcuda.so to LD_LIBRARY_PATH, which is necessary for the GPU mxnet package.
40+ ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/compat
41+
1842{{ if eq .Accelerator "gpu" }}
1943# b/200968891 Keeps horovod once torch is upgraded.
2044RUN pip uninstall -y horovod && \
@@ -52,29 +76,24 @@ RUN conda install cudf=21.08 cuml=21.08 cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MI
5276
5377# Install PyTorch
# GPU builds install the prebuilt wheels copied out of the torch_whl stage.
# CPU builds install the pinned versions from the PyTorch wheel index. The
# "+cpu" local version tag must be attached directly to the version number:
# with whitespace in between, pip treats "+cpu" as a separate, invalid requirement.
# NOTE(review): the wheels added by COPY remain in that layer even though the
# following RUN deletes /tmp/torch; a bind mount from the torch_whl stage would
# avoid this if the builder supports BuildKit - confirm before changing.
5478{{ if eq .Accelerator "gpu" }}
55- RUN pip install torch==1.7.1+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchvision==0.8.2+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \
79+ COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
80+ RUN pip install /tmp/torch/*.whl && \
81+ rm -rf /tmp/torch && \
5682 /tmp/clean-layer.sh
5783{{ else }}
58- RUN pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \
84+ RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==$TORCHAUDIO_VERSION torchtext==$TORCHTEXT_VERSION -f https://download.pytorch.org/whl/torch_stable.html && \
5985 /tmp/clean-layer.sh
6086{{ end }}
6187
6288# Install LightGBM
63- ENV LIGHTGBM_VERSION=3.2.1
6489{{ if eq .Accelerator "gpu" }}
# GPU build: installs the prebuilt LightGBM wheel copied from the lightgbm_whl
# stage; the removed lines below built LightGBM from source with -DUSE_GPU=1.
# NOTE(review): the wheel added by COPY remains in that layer even though the
# RUN below deletes /tmp/lightgbm - confirm whether that size cost is acceptable.
# NOTE(review): apt-get install runs without apt-get update in this layer, so it
# presumably relies on package lists fetched by an earlier layer - confirm.
90+ COPY --from=lightgbm_whl /tmp/whl/*.whl /tmp/lightgbm/
6591# Install OpenCL (required by LightGBM GPU version)
6692RUN apt-get install -y ocl-icd-libopencl1 clinfo && \
6793 mkdir -p /etc/OpenCL/vendors && \
6894 echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \
69- cd /usr/local/src && \
70- git clone --recursive https://github.com/microsoft/LightGBM && \
71- cd LightGBM && \
72- git checkout tags/v$LIGHTGBM_VERSION && \
73- mkdir build && cd build && \
74- cmake -DUSE_GPU=1 -DOpenCL_LIBRARY=/usr/local/cuda/lib64/libOpenCL.so -DOpenCL_INCLUDE_DIR=/usr/local/cuda/include/ .. && \
75- make -j$(nproc) && \
76- cd /usr/local/src/LightGBM/python-package && \
77- python setup.py install --precompile && \
95+ pip install /tmp/lightgbm/*.whl && \
96+ rm -rf /tmp/lightgbm && \
7897 /tmp/clean-layer.sh
7998{{ else }}
8099RUN pip install lightgbm==$LIGHTGBM_VERSION && \
@@ -386,8 +405,7 @@ RUN pip install bleach && \
386405 pip install widgetsnbextension && \
387406 pip install pyarrow && \
388407 pip install feather-format && \
389- # fastai >= 2.3.1 upgrades pytorch/torchvision. upgrade of pytorch will be handled in b/181966788
390- pip install fastai==2.2.7 && \
408+ pip install fastai && \
391409 pip install allennlp && \
392410 # https://b.corp.google.com/issues/184685619#comment9: 3.9.0 is causing a major performance degradation with spacy 2.3.5
393411 pip install importlib-metadata==3.4.0 && \
0 commit comments