@@ -74,7 +74,6 @@ def use_debug_mode():
7474 CUDAExtension ,
7575)
7676
77-
7877IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
7978
8079# Constant known variables used throughout this file
@@ -258,38 +257,41 @@ def get_extensions():
258257 ]
259258 )
260259
260+ # Get base directory and source paths
261261 this_dir = os .path .dirname (os .path .curdir )
262262 extensions_dir = os .path .join (this_dir , "torchao" , "csrc" )
263- sources = list (glob .glob (os .path .join (extensions_dir , "**/*.cpp" ), recursive = True ))
264263
265- extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
266- cuda_sources = list (
267- glob .glob (os .path .join (extensions_cuda_dir , "**/*.cu" ), recursive = True )
268- )
269-
270- extensions_hip_dir = os .path .join (
271- extensions_dir , "cuda" , "tensor_core_tiled_layout" , "sparse_marlin"
272- )
273- hip_sources = list (
274- glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
275- )
264+ # Collect C++ source files
265+ sources = list (glob .glob (os .path .join (extensions_dir , "**/*.cpp" ), recursive = True ))
276266
277- if not IS_ROCM and use_cuda :
278- sources += cuda_sources
279-
280- # TOOD: Remove this and use what CUDA has once we fix all the builds.
281- if IS_ROCM and use_cuda :
282- # Add ROCm GPU architecture check
283- gpu_arch = torch .cuda .get_device_properties (0 ).name
284- if gpu_arch != "gfx942" :
285- print (f"Warning: Unsupported ROCm GPU architecture: { gpu_arch } " )
286- print (
287- "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
267+ # Collect CUDA source files if needed
268+ if use_cuda :
269+ if not IS_ROCM :
270+ # Regular CUDA sources
271+ extensions_cuda_dir = os .path .join (extensions_dir , "cuda" )
272+ cuda_sources = list (
273+ glob .glob (os .path .join (extensions_cuda_dir , "**/*.cu" ), recursive = True )
274+ )
275+ sources += cuda_sources
276+ else :
277+ # ROCm sources
278+ extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" )
279+ hip_sources = list (
280+ glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
288281 )
289- return None
290- sources += hip_sources
291282
292- if len (sources ) == 0 :
283+ # Check ROCm GPU architecture compatibility
284+ gpu_arch = torch .cuda .get_device_properties (0 ).name
285+ if gpu_arch != "gfx942" :
286+ print (f"Warning: Unsupported ROCm GPU architecture: { gpu_arch } " )
287+ print (
288+ "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
289+ )
290+ return None
291+ sources += hip_sources
292+
293+ # Return None if no sources found
294+ if not sources :
293295 return None
294296
295297 ext_modules = []
@@ -304,7 +306,6 @@ def get_extensions():
304306 )
305307 )
306308
307-
308309 if build_torchao_experimental :
309310 ext_modules .append (
310311 CMakeExtension (
0 commit comments