diff --git a/test/test_transforms.py b/test/test_transforms.py index 01688022b58..6de4be5943e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1670,7 +1670,7 @@ def test_random_crop(): assert result.size(1) == height + 1 assert result.size(2) == width + 1 - t = transforms.RandomCrop(48) + t = transforms.RandomCrop(33) img = torch.ones(3, 32, 32) with pytest.raises(ValueError, match=r"Required crop size .+ is larger than input image size .+"): t(img) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 79704099d98..301d52df1d7 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -443,7 +443,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if height < output_height: height += 2 * (output_height - height) - if height + 1 < output_height or width + 1 < output_width: + if height < output_height or width < output_width: raise ValueError( f"Required crop size {(output_height, output_width)} is larger then input image size {(height, width)}" ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 095460675cc..4076b65dbd6 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -628,7 +628,7 @@ def get_params(img: Tensor, output_size: Tuple[int, int]) -> Tuple[int, int, int _, h, w = F.get_dimensions(img) th, tw = output_size - if h + 1 < th or w + 1 < tw: + if h < th or w < tw: raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}") if w == tw and h == th: