@@ -17,7 +17,6 @@ def smoke_test_torchvision() -> None:
1717 all (x is not None for x in [torch .ops .image .decode_png , torch .ops .torchvision .roi_align ]),
1818 )
1919
20-
2120def smoke_test_torchvision_read_decode () -> None :
2221 img_jpg = read_image (str (SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg" ))
2322 if img_jpg .ndim != 3 or img_jpg .numel () < 100 :
@@ -26,13 +25,12 @@ def smoke_test_torchvision_read_decode() -> None:
2625 if img_png .ndim != 3 or img_png .numel () < 100 :
2726 raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
2827
29-
30- def smoke_test_torchvision_resnet50_classify () -> None :
31- img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" ))
28+ def smoke_test_torchvision_resnet50_classify (device : str = "cpu" ) -> None :
29+ img = read_image (str (SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg" )).to (device )
3230
3331 # Step 1: Initialize model with the best available weights
3432 weights = ResNet50_Weights .DEFAULT
35- model = resnet50 (weights = weights )
33+ model = resnet50 (weights = weights ). to ( device )
3634 model .eval ()
3735
3836 # Step 2: Initialize the inference transforms
@@ -47,17 +45,19 @@ def smoke_test_torchvision_resnet50_classify() -> None:
4745 score = prediction [class_id ].item ()
4846 category_name = weights .meta ["categories" ][class_id ]
4947 expected_category = "German shepherd"
50- print (f"{ category_name } : { 100 * score :.1f} %" )
48+ print (f"{ category_name } ( { device } ) : { 100 * score :.1f} %" )
5149 if category_name != expected_category :
52- raise RuntimeError (f"Failed ResNet50 classify { category_name } Expected: { expected_category } " )
53-
50+ raise RuntimeError (
51+ f"Failed ResNet50 classify { category_name } Expected: { expected_category } "
52+ )
5453
5554def main () -> None :
5655 print (f"torchvision: { torchvision .__version__ } " )
5756 smoke_test_torchvision ()
5857 smoke_test_torchvision_read_decode ()
5958 smoke_test_torchvision_resnet50_classify ()
60-
59+ if torch .cuda .is_available ():
60+ smoke_test_torchvision_resnet50_classify ("cuda" )
6161
6262if __name__ == "__main__" :
6363 main ()
0 commit comments