Skip to content

Commit 88b6b93

Browse files
authored
Extend RandomShortestSize to support Video specific flavour of the augmentation (#6770)
* Extend RandomShortestSize to support Video specific flavour of the augmentation * Adding a test. * Apply changes from code review
1 parent e3238e5 commit 88b6b93

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

test/test_prototype_transforms.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1379,10 +1379,9 @@ def test__transform(self, mocker):
13791379

13801380

13811381
class TestRandomShortestSize:
1382-
def test__get_params(self, mocker):
1382+
@pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)])
1383+
def test__get_params(self, min_size, max_size, mocker):
13831384
spatial_size = (3, 10)
1384-
min_size = [5, 9]
1385-
max_size = 20
13861385

13871386
transform = transforms.RandomShortestSize(min_size=min_size, max_size=max_size)
13881387

@@ -1395,10 +1394,9 @@ def test__get_params(self, mocker):
13951394
assert isinstance(size, tuple) and len(size) == 2
13961395

13971396
longer = max(size)
1398-
assert longer <= max_size
1399-
14001397
shorter = min(size)
1401-
if longer == max_size:
1398+
if max_size is not None:
1399+
assert longer <= max_size
14021400
assert shorter <= max_size
14031401
else:
14041402
assert shorter in min_size

torchvision/prototype/transforms/_geometry.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ class RandomShortestSize(Transform):
730730
def __init__(
731731
self,
732732
min_size: Union[List[int], Tuple[int], int],
733-
max_size: int,
733+
max_size: Optional[int] = None,
734734
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
735735
antialias: Optional[bool] = None,
736736
):
@@ -744,7 +744,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
744744
orig_height, orig_width = query_spatial_size(flat_inputs)
745745

746746
min_size = self.min_size[int(torch.randint(len(self.min_size), ()))]
747-
r = min(min_size / min(orig_height, orig_width), self.max_size / max(orig_height, orig_width))
747+
r = min_size / min(orig_height, orig_width)
748+
if self.max_size is not None:
749+
r = min(r, self.max_size / max(orig_height, orig_width))
748750

749751
new_width = int(orig_width * r)
750752
new_height = int(orig_height * r)

0 commit comments

Comments
 (0)