diff --git a/test/test_image.py b/test/test_image.py index 7fcd54c9c8f..f71e023c4ae 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -368,6 +368,13 @@ def test_decode_jpeg_cuda(mode, img_path, scripted): # Some difference expected between jpeg implementations assert (img.float() - img_nvjpeg.cpu().float()).abs().mean() < 2 +@needs_cuda +def test_decode_image_cuda_raises(): + data = torch.randint(0, 127, size=(255,), device="cuda", dtype=torch.uint8) + exception_raised = True + with pytest.raises(RuntimeError): + decode_image(data) + @needs_cuda @pytest.mark.parametrize("cuda_device", ("cuda", "cuda:0", torch.device("cuda"))) diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp index 1cc05dc76ca..da4dc5833de 100644 --- a/torchvision/csrc/io/image/cpu/decode_image.cpp +++ b/torchvision/csrc/io/image/cpu/decode_image.cpp @@ -7,6 +7,8 @@ namespace vision { namespace image { torch::Tensor decode_image(const torch::Tensor& data, ImageReadMode mode) { + // Check that tensor is a CPU tensor + TORCH_CHECK(data.device() == torch::kCPU, "Expected a CPU tensor"); // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); // Check that the input tensor is 1-dimensional