1313# limitations under the License.
1414
1515ARG UBUNTU_VERSION=20.04
16- ARG CUDA_VERSION=11.3.1
16+ ARG CUDA_VERSION=11.6.1
17+
1718
1819FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu${UBUNTU_VERSION}
1920
2021ARG PYTHON_VERSION=3.9
21- ARG PYTORCH_VERSION=1.12
22+ ARG PYTORCH_VERSION=1.13
2223
2324SHELL ["/bin/bash" , "-c" ]
2425# https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/
3536RUN \
3637 # TODO: Remove the manual key installation once the base image is updated.
3738 # https://github.com/NVIDIA/nvidia-docker/issues/1631
38- apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
39+ # https://github.com/NVIDIA/nvidia-docker/issues/1631#issuecomment-1264715214
40+ apt-get update && apt-get install -y wget && \
41+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
42+ mkdir -p /etc/apt/keyrings/ && mv 3bf863cc.pub /etc/apt/keyrings/ && \
43+ echo "deb [signed-by=/etc/apt/keyrings/3bf863cc.pub] https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" /etc/apt/sources.list.d/cuda.list && \
44+ apt-get update && \
3945 apt-get update -qq --fix-missing && \
4046 NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s *$' ) && \
4147 CUDA_VERSION_MM="${CUDA_VERSION%.*}" && \
@@ -132,24 +138,32 @@ RUN \
132138
133139RUN \
134140 # install Bagua
135- CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))" ) && \
136- CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])" ) && \
137- pip install "bagua-cuda$CUDA_VERSION_BAGUA" && \
138- if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()" ; fi && \
139- python -c "import bagua; print(bagua.__version__)"
141+ if [[ $PYTORCH_VERSION != "1.13" ]]; then \
142+ CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))" ) ; \
143+ CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])" ) ; \
144+ pip install "bagua-cuda$CUDA_VERSION_BAGUA" ; \
145+ if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then \
146+ python -c "import bagua_core; bagua_core.install_deps()" ; \
147+ fi ; \
148+ python -c "import bagua; print(bagua.__version__)" ; \
149+ fi
140150
141151RUN \
142152 # install ColossalAI
143- PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])" ) ; \
144- CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))" ) ; \
145- CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])" ) ; \
146- pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
147- python -c "import colossalai; print(colossalai.__version__)" ; \
153+ # TODO: 1.13 wheels are not released, remove skip once they are
154+ if [[ $PYTORCH_VERSION != "1.13" ]]; then \
155+ PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])" ) ; \
156+ CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))" ) ; \
157+ CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])" ) ; \
158+ pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
159+ python -c "import colossalai; print(colossalai.__version__)" ; \
160+ fi
148161
149162RUN \
150163 # install rest of strategies
151164 # remove colossalai from requirements since they are installed separately
152165 python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
166+ python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)" ; \
153167 cat requirements/pytorch/strategies.txt && \
154168 pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
155169
@@ -163,5 +177,4 @@ RUN \
163177 python -c "import sys; ver = sys.version_info ; assert f'{ver.major}.{ver.minor}' == '$PYTHON_VERSION', ver" && \
164178 python -c "import torch; assert torch.__version__.startswith('$PYTORCH_VERSION'), torch.__version__" && \
165179 python requirements/pytorch/check-avail-extras.py && \
166- python requirements/pytorch/check-avail-strategies.py && \
167180 rm -rf requirements/
0 commit comments