44from pathlib import Path
55
66import torch
7+ import torch .nn as nn
78import torchvision
89from torchvision .io import read_image
910from torchvision .models import resnet50 , ResNet50_Weights
@@ -26,6 +27,12 @@ def smoke_test_torchvision_read_decode() -> None:
2627 if img_png .ndim != 3 or img_png .numel () < 100 :
2728 raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
2829
30+ def smoke_test_compile () -> None :
31+ model = resnet50 ().cuda ()
32+ model = torch .compile (model )
33+ x = torch .randn (1 , 3 , 224 , 224 , device = "cuda" )
34+ out = model (x )
35+ print (f"torch.compile model output: { out .shape } " )
2936
3037def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
3138 img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" )).to (device )
@@ -54,14 +61,18 @@ def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
5461
5562def main () -> None :
5663 print (f"torchvision: { torchvision .__version__ } " )
64+ print (f"torch.cuda.is_available: { torch .cuda .is_available ()} " )
5765 smoke_test_torchvision ()
5866 smoke_test_torchvision_read_decode ()
5967 smoke_test_torchvision_resnet50_classify ()
6068 if torch .cuda .is_available ():
6169 smoke_test_torchvision_resnet50_classify ("cuda" )
70+ smoke_test_compile ()
71+
6272 if torch .backends .mps .is_available ():
6373 smoke_test_torchvision_resnet50_classify ("mps" )
6474
6575
76+
6677if __name__ == "__main__" :
6778 main ()
0 commit comments