Skip to content

Commit 8124a58

Browse files
author
Peter Yeh
committed
lint
1 parent 75b6816 commit 8124a58

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

setup.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ def use_debug_mode():
7979

8080
IS_ROCM = (torch.version.hip is not None) and (ROCM_HOME is not None)
8181

82+
8283
class BuildOptions:
8384
def __init__(self):
8485
# TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines
@@ -90,9 +91,9 @@ def __init__(self):
9091
default=(self._is_arm64() and self._is_macos()),
9192
)
9293
if self.build_cpu_aarch64:
93-
assert (
94-
self._is_arm64()
95-
), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
94+
assert self._is_arm64(), (
95+
"TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
96+
)
9697

9798
# TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
9899
# 1) It increases the build time
@@ -101,9 +102,9 @@ def __init__(self):
101102
"TORCHAO_BUILD_KLEIDIAI", default=False
102103
)
103104
if self.build_kleidi_ai:
104-
assert (
105-
self.build_cpu_aarch64
106-
), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
105+
assert self.build_cpu_aarch64, (
106+
"TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
107+
)
107108

108109
# TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
109110
self.build_experimental_mps = self._os_bool_var(
@@ -112,9 +113,9 @@ def __init__(self):
112113
if self.build_experimental_mps:
113114
assert self._is_macos(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
114115
assert self._is_arm64(), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
115-
assert (
116-
torch.mps.is_available()
117-
), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
116+
assert torch.mps.is_available(), (
117+
"TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
118+
)
118119

119120
def _is_arm64(self) -> bool:
120121
return platform.machine().startswith("arm64")
@@ -341,7 +342,9 @@ def get_extensions():
341342
sources += cuda_sources
342343
else:
343344
# ROCm sources
344-
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin", "tensor_core_tiled_layout")
345+
extensions_hip_dir = os.path.join(
346+
extensions_dir, "cuda", "sparse_marlin", "tensor_core_tiled_layout"
347+
)
345348
hip_sources = list(
346349
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
347350
)

0 commit comments

Comments
 (0)