|
| 1 | +ARG BASE_IMAGE_REPO |
| 2 | +ARG BASE_IMAGE_TAG |
| 3 | +ARG CPU_BASE_IMAGE_NAME |
| 4 | +ARG GPU_BASE_IMAGE_NAME |
| 5 | +ARG LIGHTGBM_VERSION |
| 6 | +ARG TORCH_VERSION |
| 7 | +ARG TORCHVISION_VERSION |
| 8 | + |
1 | 9 | {{ if eq .Accelerator "gpu" }} |
2 | | -FROM gcr.io/deeplearning-platform-release/tf2-gpu.2-6:m80 |
| 10 | +FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl |
| 11 | +FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl |
| 12 | +FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG} |
3 | 13 | ENV CUDA_MAJOR_VERSION=11 |
4 | 14 | ENV CUDA_MINOR_VERSION=0 |
5 | 15 | {{ else }} |
6 | | -FROM gcr.io/deeplearning-platform-release/tf2-cpu.2-6:m80 |
| 16 | +FROM ${BASE_IMAGE_REPO}/${CPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG} |
7 | 17 | {{ end }} |
8 | 18 | # Keep these variables in sync if base image is updated. |
9 | 19 | ENV TENSORFLOW_VERSION=2.6.0 |
| 20 | + |
| 21 | +# We need to redefine the ARG here to get the ARG value defined above the FROM instruction. |
| 22 | +# See: https://docs.docker.com/engine/reference/builder/#understand-how-arg-and-from-interact |
| 23 | +ARG LIGHTGBM_VERSION |
| 24 | +ARG TORCH_VERSION |
| 25 | +ARG TORCHVISION_VERSION |
| 26 | + |
10 | 27 | # Disable pesky logs like: KMP_AFFINITY: pid 6121 tid 6121 thread 0 bound to OS proc set 0 |
11 | 28 | # See: https://stackoverflow.com/questions/57385766/disable-tensorflow-log-information |
12 | 29 | ENV KMP_WARNINGS=0 |
@@ -52,29 +69,24 @@ RUN conda install cudf=21.08 cuml=21.08 cudatoolkit=$CUDA_MAJOR_VERSION.$CUDA_MI |
52 | 69 |
|
53 | 70 | # Install PyTorch |
54 | 71 | {{ if eq .Accelerator "gpu" }} |
55 | | -RUN pip install torch==1.7.1+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchvision==0.8.2+cu$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \ |
| 72 | +COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/ |
| 73 | +RUN pip install /tmp/torch/*.whl torchaudio==0.9.1 torchtext==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html && \ |
| 74 | + rm -rf /tmp/torch && \ |
56 | 75 | /tmp/clean-layer.sh |
57 | 76 | {{ else }} |
58 | | -RUN pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 torchtext==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html && \ |
| 77 | +RUN pip install torch==$TORCH_VERSION+cpu torchvision==$TORCHVISION_VERSION+cpu torchaudio==0.9.1 torchtext==0.10.1 -f https://download.pytorch.org/whl/torch_stable.html && \ |
59 | 78 | /tmp/clean-layer.sh |
60 | 79 | {{ end }} |
61 | 80 |
|
62 | 81 | # Install LightGBM |
63 | | -ENV LIGHTGBM_VERSION=3.2.1 |
64 | 82 | {{ if eq .Accelerator "gpu" }} |
| 83 | +COPY --from=lightgbm_whl /tmp/whl/*.whl /tmp/lightgbm/ |
65 | 84 | # Install OpenCL (required by LightGBM GPU version) |
66 | 85 | RUN apt-get install -y ocl-icd-libopencl1 clinfo && \ |
67 | 86 | mkdir -p /etc/OpenCL/vendors && \ |
68 | 87 | echo "libnvidia-opencl.so.1" > /etc/OpenCL/vendors/nvidia.icd && \ |
69 | | - cd /usr/local/src && \ |
70 | | - git clone --recursive https://github.com/microsoft/LightGBM && \ |
71 | | - cd LightGBM && \ |
72 | | - git checkout tags/v$LIGHTGBM_VERSION && \ |
73 | | - mkdir build && cd build && \ |
74 | | - cmake -DUSE_GPU=1 -DOpenCL_LIBRARY=/usr/local/cuda/lib64/libOpenCL.so -DOpenCL_INCLUDE_DIR=/usr/local/cuda/include/ .. && \ |
75 | | - make -j$(nproc) && \ |
76 | | - cd /usr/local/src/LightGBM/python-package && \ |
77 | | - python setup.py install --precompile && \ |
| 88 | + pip install /tmp/lightgbm/*.whl && \ |
| 89 | + rm -rf /tmp/lightgbm && \ |
78 | 90 | /tmp/clean-layer.sh |
79 | 91 | {{ else }} |
80 | 92 | RUN pip install lightgbm==$LIGHTGBM_VERSION && \ |
@@ -386,8 +398,7 @@ RUN pip install bleach && \ |
386 | 398 | pip install widgetsnbextension && \ |
387 | 399 | pip install pyarrow && \ |
388 | 400 | pip install feather-format && \ |
389 | | - # fastai >= 2.3.1 upgrades pytorch/torchvision. upgrade of pytorch will be handled in b/181966788 |
390 | | - pip install fastai==2.2.7 && \ |
| 401 | + pip install fastai && \ |
391 | 402 | pip install allennlp && \ |
392 | 403 | # https://b.corp.google.com/issues/184685619#comment9: 3.9.0 is causing a major performance degradation with spacy 2.3.5 |
393 | 404 | pip install importlib-metadata==3.4.0 && \ |
|
0 commit comments