Skip to content

Commit ee468d6

Browse files
committed
Add smoke test Using a simple RN50 with torch.compile
1 parent 924d373 commit ee468d6

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

test/smoke_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ def smoke_test_torchvision_read_decode() -> None:
2626
if img_png.ndim != 3 or img_png.numel() < 100:
2727
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2828

29+
def smoke_test_compile() -> None:
30+
import torch.nn as nn
31+
model = resnet50().cuda()
32+
model = torch.compile(model)
33+
x = torch.randn(1, 3, 224, 224).cuda()
34+
out = model(x)
35+
print(out.shape)
2936

3037
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
3138
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
@@ -59,8 +66,12 @@ def main() -> None:
5966
smoke_test_torchvision_resnet50_classify()
6067
if torch.cuda.is_available():
6168
smoke_test_torchvision_resnet50_classify("cuda")
69+
<<<<<<< HEAD
6270
if torch.backends.mps.is_available():
6371
smoke_test_torchvision_resnet50_classify("mps")
72+
=======
73+
smoke_test_compile()
74+
>>>>>>> 2b8667d9a4 (Add smoke test Using a simple RN50 with torch.compile)
6475

6576

6677
if __name__ == "__main__":

0 commit comments

Comments
 (0)