From fa6e8dda4ed7220ee279f5707d054c23d6b22350 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 9 Mar 2023 16:44:08 +0000 Subject: [PATCH 1/2] Add device check to `io.decode_image` It only works for CPU tensors, so raise an error if called with non-CPU tensor --- test/test_image.py | 10 ++++++++++ torchvision/csrc/io/image/cpu/decode_image.cpp | 2 ++ 2 files changed, 12 insertions(+) diff --git a/test/test_image.py b/test/test_image.py index 7fcd54c9c8f..06c2453cbbe 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -368,6 +368,16 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): # Some difference expected between jpeg implementations assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 +@needs_cuda +def test_decode_image_cuda_raises(): + data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8) + exception_raised = True + try: + decode_image(data) + except RuntimeError as e: + exception_raised = True + assert exception_raised + @needs_cuda @pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index 1cc05dc76ca..da4dc5833de 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -7,6 +7,8 @@ namespace vision { namespace image { torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { + // Check that tensor is a CPU tensor + TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional From 740443c918d454759fbc39a3b5e408e9c988b7b9 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 9 Mar 2023 09:08:41 -0800 Subject: [PATCH 2/2] Update test/test_image.py Co-authored-by: Nicolas Hug --- test/test_image.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/test_image.py b/test/test_image.py index 06c2453cbbe..f71e023c4ae 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -372,11 +372,8 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): def test_decode_image_cuda_raises(): data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8) exception_raised = True - try: + with pytest.raises(RuntimeError): decode_image(data) - except RuntimeError as e: - exception_raised = True - assert exception_raised @needs_cuda