diff --git a/test/test_image.py b/test/test_image.py index f71e023c4ae..4c210ea7eef 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -368,10 +368,10 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): # Some difference expected between jpeg implementations assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 + @needs_cuda def test_decode_image_cuda_raises(): data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8) - exception_raised = True with pytest.raises(RuntimeError): decode_image(data)