diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 11a51f7b533..5928e6718c1 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1379,10 +1379,9 @@ def test__transform(self, mocker): class TestRandomShortestSize: - def test__get_params(self, mocker): + @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) + def test__get_params(self, min_size, max_size, mocker): spatial_size = (3, 10) - min_size = [5, 9] - max_size = 20 transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) @@ -1395,10 +1394,9 @@ def test__get_params(self, mocker): assert isinstance(size, tuple) and len(size) == 2 longer = max(size) - assert longer <= max_size - shorter = min(size) - if longer == max_size: + if max_size is not None: + assert longer <= max_size assert shorter <= max_size else: assert shorter in min_size diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4987256ce8e..5c67bf0ec78 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -730,7 +730,7 @@ class RandomShortestSize(Transform): def __init__( self, min_size: Union[List[int], Tuple[int], int], - max_size: int, + max_size: Optional[int] = None, interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, ): @@ -744,7 +744,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_height, orig_width = query_spatial_size(flat_inputs) min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] - r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) + r = min_size / min(orig_height, orig_width) + if self.max_size is not None: + r = min(r, self.max_size / max(orig_height, orig_width)) new_width = int(orig_width * r) new_height = int(orig_height * r)