@@ -79,6 +79,7 @@ def use_debug_mode():
7979
8080IS_ROCM = (torch .version .hip is not None ) and (ROCM_HOME is not None )
8181
82+
8283class BuildOptions :
8384 def __init__ (self ):
8485 # TORCHAO_BUILD_CPU_AARCH64 is enabled by default on Arm-based Apple machines
@@ -90,9 +91,9 @@ def __init__(self):
9091 default = (self ._is_arm64 () and self ._is_macos ()),
9192 )
9293 if self .build_cpu_aarch64 :
93- assert (
94- self . _is_arm64 ()
95- ), "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
94+ assert self . _is_arm64 (), (
95+ "TORCHAO_BUILD_CPU_AARCH64 requires an arm64 machine"
96+ )
9697
9798 # TORCHAO_BUILD_KLEIDIAI is disabled by default for now because
9899 # 1) It increases the build time
@@ -101,9 +102,9 @@ def __init__(self):
101102 "TORCHAO_BUILD_KLEIDIAI" , default = False
102103 )
103104 if self .build_kleidi_ai :
104- assert (
105- self . build_cpu_aarch64
106- ), "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
105+ assert self . build_cpu_aarch64 , (
106+ "TORCHAO_BUILD_KLEIDIAI requires TORCHAO_BUILD_CPU_AARCH64 be set"
107+ )
107108
108109 # TORCHAO_BUILD_EXPERIMENTAL_MPS is disabled by default.
109110 self .build_experimental_mps = self ._os_bool_var (
@@ -112,9 +113,9 @@ def __init__(self):
112113 if self .build_experimental_mps :
113114 assert self ._is_macos (), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MacOS"
114115 assert self ._is_arm64 (), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires arm64"
115- assert (
116- torch . mps . is_available ()
117- ), "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
116+ assert torch . mps . is_available (), (
117+ "TORCHAO_BUILD_EXPERIMENTAL_MPS requires MPS be available"
118+ )
118119
119120 def _is_arm64 (self ) -> bool :
120121 return platform .machine ().startswith ("arm64" )
@@ -341,7 +342,9 @@ def get_extensions():
341342 sources += cuda_sources
342343 else :
343344 # ROCm sources
344- extensions_hip_dir = os .path .join (extensions_dir , "cuda" , "sparse_marlin" , "tensor_core_tiled_layout" )
345+ extensions_hip_dir = os .path .join (
346+ extensions_dir , "cuda" , "sparse_marlin" , "tensor_core_tiled_layout"
347+ )
345348 hip_sources = list (
346349 glob .glob (os .path .join (extensions_hip_dir , "*.cu" ), recursive = True )
347350 )
0 commit comments