Commit aea9d81 (parent: b96196b)
Author: Peter Y. Yeh

Commit message:
    lint
    refactor for better readability

File tree: 1 file changed (+29, -28 lines)

setup.py

Lines changed: 29 additions & 28 deletions
@@ -74,7 +74,6 @@ def use_debug_mode():
     CUDAExtension,
 )

-
 IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)

 # Constant known variables used throughout this file
@@ -258,38 +257,41 @@ def get_extensions():
         ]
     )

+    # Get base directory and source paths
     this_dir = os.path.dirname(os.path.curdir)
     extensions_dir = os.path.join(this_dir, "torchao", "csrc")
-    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
-    cuda_sources = list(
-        glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
-    )
-
-    extensions_hip_dir = os.path.join(
-        extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
-    )
-    hip_sources = list(
-        glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
-    )
+    # Collect C++ source files
+    sources = list(glob.glob(os.path.join(extensions_dir, "**/*.cpp"), recursive=True))

-    if not IS_ROCM and use_cuda:
-        sources += cuda_sources
-
-    # TOOD: Remove this and use what CUDA has once we fix all the builds.
-    if IS_ROCM and use_cuda:
-        # Add ROCm GPU architecture check
-        gpu_arch = torch.cuda.get_device_properties(0).name
-        if gpu_arch != "gfx942":
-            print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
-            print(
-                "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+    # Collect CUDA source files if needed
+    if use_cuda:
+        if not IS_ROCM:
+            # Regular CUDA sources
+            extensions_cuda_dir = os.path.join(extensions_dir, "cuda")
+            cuda_sources = list(
+                glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
+            )
+            sources += cuda_sources
+        else:
+            # ROCm sources
+            extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
+            hip_sources = list(
+                glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
             )
-            return None
-        sources += hip_sources

-    if len(sources) == 0:
+            # Check ROCm GPU architecture compatibility
+            gpu_arch = torch.cuda.get_device_properties(0).name
+            if gpu_arch != "gfx942":
+                print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
+                print(
+                    "Currently only gfx942 is supported. Skipping compilation of ROCm extensions"
+                )
+                return None
+            sources += hip_sources
+
+    # Return None if no sources found
+    if not sources:
         return None

     ext_modules = []
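
Aside on the hunk above: the refactor moves the ROCm architecture gate inside the use_cuda branch, so it only runs when ROCm sources are actually being collected. Below is a minimal standalone sketch (not part of setup.py, and assuming a ROCm build of PyTorch is installed) that mirrors that gfx942 check, so it can be run on its own to see what the build script would decide:

# Standalone sketch mirroring the gfx942 gate added in this commit; it only
# reports what the build script would decide and does not build anything.
import torch

if torch.cuda.is_available() and torch.version.hip is not None:
    gpu_arch = torch.cuda.get_device_properties(0).name
    if gpu_arch != "gfx942":
        print(f"Warning: Unsupported ROCm GPU architecture: {gpu_arch}")
        print("Only gfx942 is supported; the ROCm sparse_marlin sources would be skipped.")
    else:
        print("gfx942 detected; the ROCm sparse_marlin sources would be compiled.")
else:
    print("No ROCm device available; the architecture check is never reached.")
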
@@ -304,7 +306,6 @@ def get_extensions():
         )
     )

-
     if build_torchao_experimental:
         ext_modules.append(
             CMakeExtension(
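
For context on where the collected sources end up, here is a hedged sketch of the standard torch.utils.cpp_extension wiring that a non-None source list from get_extensions() typically feeds into. The extension name, placeholder source path, and compile flags below are illustrative assumptions, not the exact values used in this setup.py:

# Illustrative sketch only: standard CUDAExtension/BuildExtension wiring.
# "torchao._C_example", the placeholder .cpp path, and the -O3 flags are
# assumptions made for this sketch.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = ["torchao/csrc/example.cpp"]  # stand-in for the globbed source list

setup(
    name="torchao-example",
    ext_modules=[
        CUDAExtension(
            name="torchao._C_example",
            sources=sources,
            extra_compile_args={"cxx": ["-O3"], "nvcc": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)

BuildExtension then takes care of routing the .cu files to the CUDA (or ROCm) toolchain and the .cpp files to the host compiler.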
