From 8304f2da75d52dbb03754cd1d38bce64e68a89e9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 11:45:24 +0100 Subject: [PATCH 1/3] Extend RandomShortestSize to support Video specific flavour of the augmentation --- torchvision/prototype/transforms/_geometry.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index b09533273e4..8bea7843e25 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -733,7 +733,7 @@ class RandomShortestSize(Transform): def __init__( self, min_size: Union[List[int], Tuple[int], int], - max_size: int, + max_size: Optional[int] = None, interpolation: InterpolationMode = InterpolationMode.BILINEAR, antialias: Optional[bool] = None, ): @@ -747,7 +747,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: orig_height, orig_width = query_spatial_size(sample) min_size = self.min_size[int(torch.randint(len(self.min_size), ()))] - r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width)) + r = min_size / min(orig_height, orig_width) + if self.max_size is not None: + r = min(r, self.max_size / max(orig_height, orig_width)) new_width = int(orig_width * r) new_height = int(orig_height * r) From 8ca29dd202e3200fa0f862505a9e482b7796aabc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 12:23:17 +0100 Subject: [PATCH 2/3] Adding a test. --- test/test_prototype_transforms.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 11a51f7b533..7e85d94f824 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1379,10 +1379,9 @@ def test__transform(self, mocker): class TestRandomShortestSize: - def test__get_params(self, mocker): + @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) + def test__get_params(self, min_size, max_size, mocker): spatial_size = (3, 10) - min_size = [5, 9] - max_size = 20 transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size) @@ -1395,10 +1394,9 @@ def test__get_params(self, mocker): assert isinstance(size, tuple) and len(size) == 2 longer = max(size) - assert longer <= max_size - shorter = min(size) - if longer == max_size: + if max_size is not None and longer == max_size: + assert longer <= max_size assert shorter <= max_size else: assert shorter in min_size From 431dd762e03fb75d13f233905c0f9ff69d1f461a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Oct 2022 12:28:03 +0100 Subject: [PATCH 3/3] Apply changes from code review --- test/test_prototype_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 7e85d94f824..5928e6718c1 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1395,7 +1395,7 @@ def test__get_params(self, min_size, max_size, mocker): longer = max(size) shorter = min(size) - if max_size is not None and longer == max_size: + if max_size is not None: assert longer <= max_size assert shorter <= max_size else: