File tree Expand file tree Collapse file tree 5 files changed +21
-3
lines changed Expand file tree Collapse file tree 5 files changed +21
-3
lines changed Original file line number Diff line number Diff line change @@ -12,8 +12,10 @@ ARG TORCHVISION_VERSION
1212FROM gcr.io/kaggle-images/python-lightgbm-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${LIGHTGBM_VERSION} AS lightgbm_whl
1313FROM gcr.io/kaggle-images/python-torch-whl:${GPU_BASE_IMAGE_NAME}-${BASE_IMAGE_TAG}-${TORCH_VERSION} AS torch_whl
1414FROM ${BASE_IMAGE_REPO}/${GPU_BASE_IMAGE_NAME}:${BASE_IMAGE_TAG}
15- ENV CUDA_MAJOR_VERSION=11
16- ENV CUDA_MINOR_VERSION=0
15+ ARG CUDA_MAJOR_VERSION
16+ ARG CUDA_MINOR_VERSION
17+ ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION}
18+ ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION}
1719# NVIDIA binaries from the host are mounted to /opt/bin.
1820ENV PATH=/opt/bin:${PATH}
1921# Add CUDA stubs to LD_LIBRARY_PATH to support building the GPU image on a CPU machine.
@@ -99,7 +101,8 @@ RUN conda install implicit && \
99101# Install PyTorch
100102{{ if eq .Accelerator "gpu" }}
101103COPY --from=torch_whl /tmp/whl/*.whl /tmp/torch/
102- RUN pip install /tmp/torch/*.whl && \
104+ RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION} && \
105+ pip install /tmp/torch/*.whl && \
103106 rm -rf /tmp/torch && \
104107 /tmp/clean-layer.sh
105108{{ else }}
Original file line number Diff line number Diff line change @@ -37,6 +37,8 @@ pipeline {
3737 --build-arg TORCHAUDIO_VERSION=$TORCHAUDIO_VERSION \
3838 --build-arg TORCHTEXT_VERSION=$TORCHTEXT_VERSION \
3939 --build-arg TORCHVISION_VERSION=$TORCHVISION_VERSION \
40+ --build-arg CUDA_MAJOR_VERSION=$CUDA_MAJOR_VERSION \
41+ --build-arg CUDA_MINOR_VERSION=$CUDA_MINOR_VERSION \
4042 --push
4143 '''
4244 }
Original file line number Diff line number Diff line change @@ -7,3 +7,5 @@ TORCH_VERSION=1.9.1
77TORCHAUDIO_VERSION=0.9.1
88TORCHTEXT_VERSION=0.10.1
99TORCHVISION_VERSION=0.10.1
10+ CUDA_MAJOR_VERSION=11
11+ CUDA_MINOR_VERSION=0
Original file line number Diff line number Diff line change @@ -6,12 +6,15 @@ ARG PACKAGE_VERSION
66ARG TORCHAUDIO_VERSION
77ARG TORCHTEXT_VERSION
88ARG TORCHVISION_VERSION
9+ ARG CUDA_MAJOR_VERSION
10+ ARG CUDA_MINOR_VERSION
911
1012# TORCHVISION_VERSION is mandatory
1113RUN test -n "$TORCHVISION_VERSION"
1214
1315# Build instructions: https://github.com/pytorch/pytorch#from-source
1416RUN conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools==59.5.0 cmake cffi typing_extensions future six requests dataclasses
17+ RUN conda install -c pytorch magma-cuda${CUDA_MAJOR_VERSION}${CUDA_MINOR_VERSION}
1518
1619# By default, it uses the version from version.txt which includes the `a0` (alpha zero) suffix and part of the git hash.
1720# This causes dependency conflicts like these: https://paste.googleplex.com/4786486378496000
Original file line number Diff line number Diff line change @@ -15,6 +15,14 @@ def test_nn(self):
1515 data_torch = autograd .Variable (torch .randn (2 , 5 ))
1616 linear_torch (data_torch )
1717
18+ @gpu_test
19+ def test_linalg (self ):
20+ A = torch .randn (3 , 3 ).t ().to ('cuda' )
21+ B = torch .randn (3 ).t ().to ('cuda' )
22+
23+ result = torch .linalg .solve (A , B )
24+ self .assertEqual (3 , result .shape [0 ])
25+
1826 @gpu_test
1927 def test_gpu_computation (self ):
2028 cuda = torch .device ('cuda' )
You can’t perform that action at this time.
0 commit comments