diff --git a/test/test_transforms.py b/test/test_transforms.py
index a65d848ec92..7346e2c5094 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -298,6 +298,11 @@ def test_random_crop(self):
         self.assertEqual(result.size(1), height + 1)
         self.assertEqual(result.size(2), width + 1)
 
+        t = transforms.RandomCrop(48)
+        img = torch.ones(3, 32, 32)
+        with self.assertRaisesRegex(ValueError, r"Required crop size .+ is larger than input image size .+"):
+            t(img)
+
     def test_pad(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index 07043b65a56..6a7c6515790 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -532,6 +532,12 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int
         """
         w, h = F._get_image_size(img)
         th, tw = output_size
+
+        # Reject crops that cannot fit even after the +1 padding tolerance
+        # used by the height/width comparisons below.
+        if h + 1 < th or w + 1 < tw:
+            raise ValueError(
+                "Required crop size {} is larger than input image size {}".format((th, tw), (h, w))
+            )
+
         if w == tw and h == th:
             return 0, 0, h, w