@@ -49,6 +49,17 @@ setup_cuda() {
4949
5050  #  Now work out the CUDA settings
5151  case  " $CU_VERSION " in 
52+     cu112)
53+       if  [[ " $OSTYPE " ==  " msys" ;  then 
54+         export  CUDA_HOME=" C:\\ Program Files\\ NVIDIA GPU Computing Toolkit\\ CUDA\\ v11.2" 
55+       else 
56+         export  CUDA_HOME=/usr/local/cuda-11.2/
57+       fi 
58+       export  FORCE_CUDA=1
59+       #  Hard-coding gencode flags is temporary situation until
60+       #  https://github.com/pytorch/pytorch/pull/23408 lands
61+       export  NVCC_FLAGS=" -gencode=arch=compute_35,code=sm_35 -gencode=arch=compute_50,code=sm_50 -gencode=arch=compute_60,code=sm_60 -gencode=arch=compute_70,code=sm_70 -gencode=arch=compute_75,code=sm_75 -gencode=arch=compute_80,code=sm_80 -gencode=arch=compute_86,code=sm_86 -gencode=arch=compute_50,code=compute_50" 
62+       ;;
5263    cu110)
5364      if  [[ " $OSTYPE " ==  " msys" ;  then 
5465        export  CUDA_HOME=" C:\\ Program Files\\ NVIDIA GPU Computing Toolkit\\ CUDA\\ v11.0" 
@@ -170,10 +181,13 @@ setup_wheel_python() {
170181  if  [[ " $( uname) " ==  Darwin ||  " $OSTYPE " ==  " msys" ;  then 
171182    eval  " $( conda shell.bash hook) " 
172183    conda env remove -n " env$PYTHON_VERSION " ||  true 
173-     conda create -yn " env$PYTHON_VERSION " " $PYTHON_VERSION " 
184+     if  [[ " $PYTHON_VERSION " ==  3.9 ]];  then 
185+       export  CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS}  -c=conda-forge" 
186+     fi 
187+     conda create ${CONDA_CHANNEL_FLAGS}  -yn " env$PYTHON_VERSION " " $PYTHON_VERSION " 
174188    conda activate " env$PYTHON_VERSION " 
175189    #  Install libpng from Anaconda (defaults)
176-     conda install libpng jpeg -y
190+     conda install ${CONDA_CHANNEL_FLAGS}  -c conda-forge  libpng " jpeg<=9b " 
177191  else 
178192    #  Install native CentOS libJPEG, LAME, freetype and GnuTLS
179193    yum install -y libjpeg-turbo-devel lame freetype gnutls
@@ -189,6 +203,7 @@ setup_wheel_python() {
189203      3.6) python_abi=cp36-cp36m ;;
190204      3.7) python_abi=cp37-cp37m ;;
191205      3.8) python_abi=cp38-cp38 ;;
206+       3.9) python_abi=cp39-cp39 ;;
192207      * )
193208        echo  " Unrecognized PYTHON_VERSION=$PYTHON_VERSION " 
194209        exit  1
@@ -263,6 +278,9 @@ setup_conda_pytorch_constraint() {
263278  if  [[ " $OSTYPE " ==  msys &&  " $CU_VERSION " ==  cu92 ]];  then 
264279    export  CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS}  -c defaults -c numba/label/dev" 
265280  fi 
281+   if  [[ " $PYTHON_VERSION " ==  3.9 ]];  then 
282+     export  CONDA_CHANNEL_FLAGS=" ${CONDA_CHANNEL_FLAGS}  -c=conda-forge" 
283+   fi 
266284}
267285
268286#  Translate CUDA_VERSION into CUDA_CUDATOOLKIT_CONSTRAINT
@@ -272,6 +290,9 @@ setup_conda_cudatoolkit_constraint() {
272290    export  CONDA_CUDATOOLKIT_CONSTRAINT=" " 
273291  else 
274292    case  " $CU_VERSION " in 
293+       cu112)
294+         export  CONDA_CUDATOOLKIT_CONSTRAINT=" - cudatoolkit >=11.2,<11.3 # [not osx]" 
295+         ;;
275296      cu110)
276297        export  CONDA_CUDATOOLKIT_CONSTRAINT=" - cudatoolkit >=11.0,<11.1 # [not osx]" 
277298        ;;
@@ -307,6 +328,9 @@ setup_conda_cudatoolkit_plain_constraint() {
307328    export  CMAKE_USE_CUDA=0
308329  else 
309330    case  " $CU_VERSION " in 
331+       cu112)
332+         export  CONDA_CUDATOOLKIT_CONSTRAINT=" cudatoolkit=11.2" 
333+         ;;
310334      cu102)
311335        export  CONDA_CUDATOOLKIT_CONSTRAINT=" cudatoolkit=10.2" 
312336        ;;
0 commit comments