@@ -146,7 +146,7 @@ def test_cuda_runtime_errors_captured() -> None:
146
146
raise RuntimeError ("Expected CUDA RuntimeError but have not received!" )
147
147
148
148
149
- def smoke_test_cuda (package : str , runtime_error_check : str ) -> None :
149
+ def smoke_test_cuda (package : str , runtime_error_check : str , torch_compile_check : str ) -> None :
150
150
if not torch .cuda .is_available () and is_cuda_system :
151
151
raise RuntimeError (f"Expected CUDA { gpu_arch_ver } . However CUDA is not loaded." )
152
152
@@ -163,7 +163,7 @@ def smoke_test_cuda(package: str, runtime_error_check: str) -> None:
163
163
print (f"{ module ['name' ]} CUDA: { version } " )
164
164
165
165
# torch.compile is available on macos-arm64 and Linux for python 3.8-3.11
166
- if sys .version_info < (3 , 12 , 0 ) and (
166
+ if torch_compile_check == "enabled" and sys .version_info < (3 , 12 , 0 ) and (
167
167
(target_os == "linux" and torch .cuda .is_available ()) or
168
168
target_os == "macos-arm64" ):
169
169
smoke_test_compile ()
@@ -310,6 +310,13 @@ def main() -> None:
310
310
choices = ["enabled" , "disabled" ],
311
311
default = "enabled" ,
312
312
)
313
+ parser .add_argument (
314
+ "--torch-compile-check" ,
315
+ help = "Check torch compile" ,
316
+ type = str ,
317
+ choices = ["enabled" , "disabled" ],
318
+ default = "enabled" ,
319
+ )
313
320
options = parser .parse_args ()
314
321
print (f"torch: { torch .__version__ } " )
315
322
@@ -323,7 +330,7 @@ def main() -> None:
323
330
if options .package == "all" :
324
331
smoke_test_modules ()
325
332
326
- smoke_test_cuda (options .package , options .runtime_error_check )
333
+ smoke_test_cuda (options .package , options .runtime_error_check , options . torch_compile_check )
327
334
328
335
329
336
if __name__ == "__main__" :
0 commit comments