Skip to content

Commit a4e8c30

Browse files
author
Peter Y. Yeh
committed
lint
1 parent 15974c7 commit a4e8c30

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

setup.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,27 +52,36 @@ def read_version(file_path="version.txt"):
5252
import torch
5353
from torch.utils.cpp_extension import (
5454
CUDA_HOME,
55-
ROCM_HOME,
5655
IS_WINDOWS,
56+
ROCM_HOME,
5757
BuildExtension,
5858
CppExtension,
5959
CUDAExtension,
6060
)
6161

6262
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
6363

64+
6465
def get_extensions():
6566
debug_mode = os.getenv("DEBUG", "0") == "1"
6667
if debug_mode:
6768
print("Compiling in debug mode")
6869

6970
if not torch.cuda.is_available():
70-
print("PyTorch GPU support is not available. Skipping compilation of CUDA extensions")
71+
print(
72+
"PyTorch GPU support is not available. Skipping compilation of CUDA extensions"
73+
)
7174
if (CUDA_HOME is None and ROCM_HOME is None) and torch.cuda.is_available():
72-
print("CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions")
73-
print("If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit")
75+
print(
76+
"CUDA toolkit or ROCm is not available. Skipping compilation of CUDA extensions"
77+
)
78+
print(
79+
"If you'd like to compile CUDA extensions locally please install the cudatoolkit from https://anaconda.org/nvidia/cuda-toolkit"
80+
)
7481

75-
use_cuda = torch.cuda.is_available() and (CUDA_HOME is not None or ROCM_HOME is not None)
82+
use_cuda = torch.cuda.is_available() and (
83+
CUDA_HOME is not None or ROCM_HOME is not None
84+
)
7685
extension = CUDAExtension if use_cuda else CppExtension
7786

7887
extra_link_args = []
@@ -125,8 +134,12 @@ def get_extensions():
125134
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
126135
)
127136

128-
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin")
129-
hip_sources = list(glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True))
137+
extensions_hip_dir = os.path.join(
138+
extensions_dir, "cuda", "tensor_core_tiled_layout", "sparse_marlin"
139+
)
140+
hip_sources = list(
141+
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
142+
)
130143

131144
if not IS_ROCM and use_cuda:
132145
sources += cuda_sources

0 commit comments

Comments
 (0)