@@ -49,6 +49,17 @@ setup_cuda() {
4949
5050 # Now work out the CUDA settings
5151 case " $CU_VERSION " in
52+ cu112)
53+ if [[ " $OSTYPE " == " msys" ]]; then
54+ export CUDA_HOME=" C:\\ Program Files\\ NVIDIA GPU Computing Toolkit\\ CUDA\\ v11.2"
55+ else
56+ export CUDA_HOME=/usr/local/cuda-11.2/
57+ fi
58+ export FORCE_CUDA=1
59+ # Hard-coding gencode flags is temporary situation until
60+ # https://github.com/pytorch/pytorch/pull/23408 lands
61+ export NVCC_FLAGS=" -gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50"
62+ ;;
5263 cu110)
5364 if [[ " $OSTYPE " == " msys" ]]; then
5465 export CUDA_HOME=" C:\\ Program Files\\ NVIDIA GPU Computing Toolkit\\ CUDA\\ v11.0"
@@ -170,10 +181,13 @@ setup_wheel_python() {
170181 if [[ " $( uname) " == Darwin || " $OSTYPE " == " msys" ]]; then
171182 eval " $( conda shell.bash hook) "
172183 conda env remove -n " env$PYTHON_VERSION " || true
173- conda create -yn " env$PYTHON_VERSION " python=" $PYTHON_VERSION "
184+ if [[ " $PYTHON_VERSION " == 3.9 ]]; then
185+ export CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS} -c=conda-forge"
186+ fi
187+ conda create ${CONDA_CHANNEL_FLAGS} -yn " env$PYTHON_VERSION " python=" $PYTHON_VERSION "
174188 conda activate " env$PYTHON_VERSION "
175189 # Install libpng from Anaconda (defaults)
176- conda install libpng jpeg -y
190+ conda install ${CONDA_CHANNEL_FLAGS} -c conda-forge libpng " jpeg<=9b " -y
177191 else
178192 # Install native CentOS libJPEG, LAME, freetype and GnuTLS
179193 yum install -y libjpeg-turbo-devel lame freetype gnutls
@@ -189,6 +203,7 @@ setup_wheel_python() {
189203 3.6) python_abi=cp36-cp36m ;;
190204 3.7) python_abi=cp37-cp37m ;;
191205 3.8) python_abi=cp38-cp38 ;;
206+ 3.9) python_abi=cp39-cp39 ;;
192207 * )
193208 echo " Unrecognized PYTHON_VERSION=$PYTHON_VERSION "
194209 exit 1
@@ -263,6 +278,9 @@ setup_conda_pytorch_constraint() {
263278 if [[ " $OSTYPE " == msys && " $CU_VERSION " == cu92 ]]; then
264279 export CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS} -c defaults -c numba/label/dev"
265280 fi
281+ if [[ " $PYTHON_VERSION " == 3.9 ]]; then
282+ export CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS} -c=conda-forge"
283+ fi
266284}
267285
268286# Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT
@@ -272,6 +290,9 @@ setup_conda_cudatoolkit_constraint() {
272290 export CONDA_CUDATOOLKIT_CONSTRAINT=" "
273291 else
274292 case " $CU_VERSION " in
293+ cu112)
294+ export CONDA_CUDATOOLKIT_CONSTRAINT=" - cudatoolkit >=11.2,<11.3 # [not osx]"
295+ ;;
275296 cu110)
276297 export CONDA_CUDATOOLKIT_CONSTRAINT=" - cudatoolkit >=11.0,<11.1 # [not osx]"
277298 ;;
@@ -307,6 +328,9 @@ setup_conda_cudatoolkit_plain_constraint() {
307328 export CMAKE_USE_CUDA=0
308329 else
309330 case " $CU_VERSION " in
331+ cu112)
332+ export CONDA_CUDATOOLKIT_CONSTRAINT=" cudatoolkit=11.2"
333+ ;;
310334 cu102)
311335 export CONDA_CUDATOOLKIT_CONSTRAINT=" cudatoolkit=10.2"
312336 ;;
0 commit comments