Skip to content

Commit 09030c0

Browse files
authored
Add cuda resnet50 test to smoke test (#7020)
* Add cuda resnet50 test * Fix path * Tune vision smoke test
1 parent 23d3f78 commit 09030c0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

test/smoke_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def smoke_test_torchvision() -> None:
1717
all(x is not None for x in [torch.ops.image.decode_png, torch.ops.torchvision.roi_align]),
1818
)
1919

20-
2120
def smoke_test_torchvision_read_decode() -> None:
2221
img_jpg = read_image(str(SCRIPT_DIR / "assets" / "encode_jpeg" / "grace_hopper_517x606.jpg"))
2322
if img_jpg.ndim != 3 or img_jpg.numel() < 100:
@@ -26,13 +25,12 @@ def smoke_test_torchvision_read_decode() -> None:
2625
if img_png.ndim != 3 or img_png.numel() < 100:
2726
raise RuntimeError(f"Unexpected shape of img_png: {img_png.shape}")
2827

29-
30-
def smoke_test_torchvision_resnet50_classify() -> None:
31-
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg"))
28+
def smoke_test_torchvision_resnet50_classify(device: str = "cpu") -> None:
29+
img = read_image(str(SCRIPT_DIR / ".." / "gallery" / "assets" / "dog2.jpg")).to(device)
3230

3331
# Step 1: Initialize model with the best available weights
3432
weights = ResNet50_Weights.DEFAULT
35-
model = resnet50(weights=weights)
33+
model = resnet50(weights=weights).to(device)
3634
model.eval()
3735

3836
# Step 2: Initialize the inference transforms
@@ -47,17 +45,19 @@ def smoke_test_torchvision_resnet50_classify() -> None:
4745
score = prediction[class_id].item()
4846
category_name = weights.meta["categories"][class_id]
4947
expected_category = "German shepherd"
50-
print(f"{category_name}: {100 * score:.1f}%")
48+
print(f"{category_name} ({device}): {100 * score:.1f}%")
5149
if category_name != expected_category:
52-
raise RuntimeError(f"Failed ResNet50 classify {category_name} Expected: {expected_category}")
53-
50+
raise RuntimeError(
51+
f"Failed ResNet50 classify {category_name} Expected: {expected_category}"
52+
)
5453

5554
def main() -> None:
5655
print(f"torchvision: {torchvision.__version__}")
5756
smoke_test_torchvision()
5857
smoke_test_torchvision_read_decode()
5958
smoke_test_torchvision_resnet50_classify()
60-
59+
if torch.cuda.is_available():
60+
smoke_test_torchvision_resnet50_classify("cuda")
6161

6262
if __name__ == "__main__":
6363
main()

0 commit comments

Comments
 (0)