@@ -52,27 +52,36 @@ def read_version(file_path="version.txt"):
5252import torch
5353from torch .utils .cpp_extension import (
5454 CUDA_HOME ,
55- ROCM_HOME ,
5655 IS_WINDOWS ,
56+ ROCM_HOME ,
5757 BuildExtension ,
5858 CppExtension ,
5959 CUDAExtension ,
6060)
6161
6262IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
6363
64+
6465def get_extensions ():
6566 debug_mode = os .getenv ("DEBUG" , "0" ) == "1"
6667 if debug_mode :
6768 print ("Compiling in debug mode" )
6869
6970 if not torch .cuda .is_available ():
70- print ("PyTorch GPU support is not available. Skipping compilation of CUDA extensions" )
71+ print (
72+ "PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
73+ )
7174 if (CUDA_HOME is None and ROCM_HOME is None ) and torch .cuda .is_available ():
72- print ("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions" )
73- print ("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit" )
75+ print (
76+ "CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
77+ )
78+ print (
79+ "If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
80+ )
7481
75- use_cuda = torch .cuda .is_available () and (CUDA_HOME is not None or ROCM_HOME is not None )
82+ use_cuda = torch .cuda .is_available () and (
83+ CUDA_HOME is not None or ROCM_HOME is not None
84+ )
7685 extension = CUDAExtension if use_cuda else CppExtension
7786
7887 extra_link_args = []
@@ -125,8 +134,12 @@ def get_extensions():
125134 glob .glob (os .path .join (extensions_cuda_dir , "**/*.cu" ), recursive = True )
126135 )
127136
128- extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "tensor_core_tiled_layout" , "sparse_marlin" )
129- hip_sources = list (glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True ))
137+ extensions_hip_dir = os .path .join (
138+ extensions_dir , "cuda" , "tensor_core_tiled_layout" , "sparse_marlin"
139+ )
140+ hip_sources = list (
141+ glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
142+ )
130143
131144 if not IS_ROCM and use_cuda :
132145 sources += cuda_sources
0 commit comments