diff --git a/test/smoke_test.py b/test/smoke_test.py index f965c6f6aa4..1f1364512ee 100644 --- a/test/smoke_test.py +++ b/test/smoke_test.py @@ -1,8 +1,8 @@ """Run smoke tests""" import os +import sys from pathlib import Path -from sys import platform import torch import torch.nn as nn @@ -37,7 +37,7 @@ def smoke_test_compile() -> None: out = model(x) print(f"torch.compile model output: {out.shape}") except RuntimeError: - if platform == "win32": + if sys.platform == "win32": print("Successfully caught torch.compile RuntimeError on win") elif sys.version_info >= (3, 11, 0): print("Successfully caught torch.compile RuntimeError on Python 3.11")