11"""Run smoke tests"""
22
3- import os
43import sys
54from pathlib import Path
65
76import torch
8- import torch .nn as nn
97import torchvision
10- from torchvision .io import read_image
8+ from torchvision .io import decode_jpeg , read_file , read_image
119from torchvision .models import resnet50 , ResNet50_Weights
1210
1311SCRIPT_DIR = Path (__file__ ).parent
@@ -22,13 +20,20 @@ def smoke_test_torchvision() -> None:
2220
2321def smoke_test_torchvision_read_decode () -> None :
2422 img_jpg = read_image (str (SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg" ))
25- if img_jpg .ndim != 3 or img_jpg . numel () < 100 :
23+ if img_jpg .shape != ( 3 , 606 , 517 ) :
2624 raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
2725 img_png = read_image (str (SCRIPT_DIR / "assets" / "interlaced_png" / "wizard_low.png" ))
28- if img_png .ndim != 3 or img_png . numel () < 100 :
26+ if img_png .shape != ( 4 , 471 , 354 ) :
2927 raise RuntimeError (f"Unexpected shape of img_png: { img_png .shape } " )
3028
3129
30+ def smoke_test_torchvision_decode_jpeg_cuda ():
31+ img_jpg_data = read_file (str (SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg" ))
32+ img_jpg = decode_jpeg (img_jpg_data , device = "cuda" )
33+ if img_jpg .shape != (3 , 606 , 517 ):
34+ raise RuntimeError (f"Unexpected shape of img_jpg: { img_jpg .shape } " )
35+
36+
3237def smoke_test_compile () -> None :
3338 try :
3439 model = resnet50 ().cuda ()
@@ -77,6 +82,7 @@ def main() -> None:
7782 smoke_test_torchvision_read_decode ()
7883 smoke_test_torchvision_resnet50_classify ()
7984 if torch .cuda .is_available ():
85+ smoke_test_torchvision_decode_jpeg_cuda ()
8086 smoke_test_torchvision_resnet50_classify ("cuda" )
8187 smoke_test_compile ()
8288
0 commit comments