class TestScaleJitter:
    def test__get_params(self, mocker):
        """_get_params samples a size jittered within scale_range of target_size."""
        input_image_size = (24, 32)
        target_size = (16, 12)
        scale_range = (0.5, 1.5)

        transform = transforms.ScaleJitter(target_size=target_size, scale_range=scale_range)

        sample = mocker.MagicMock(spec=features.Image, num_channels=3, image_size=input_image_size)
        params = transform._get_params(sample)

        assert "size" in params
        size = params["size"]

        assert isinstance(size, tuple)
        assert len(size) == 2

        new_height, new_width = size
        lower, upper = scale_range
        assert int(target_size[0] * lower) <= new_height <= int(target_size[0] * upper)
        assert int(target_size[1] * lower) <= new_width <= int(target_size[1] * upper)

    def test__transform(self, mocker):
        """_transform forwards the input, sampled size and interpolation to F.resize."""
        interpolation_sentinel = mocker.MagicMock()

        transform = transforms.ScaleJitter(target_size=(16, 12), interpolation=interpolation_sentinel)
        # Allow the MagicMock input to pass the dispatch type check.
        transform._transformed_types = (mocker.MagicMock,)

        size_sentinel = mocker.MagicMock()
        mocker.patch(
            "torchvision.prototype.transforms._geometry.ScaleJitter._get_params",
            return_value=dict(size=size_sentinel),
        )

        inpt_sentinel = mocker.MagicMock()

        resize_mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        transform(inpt_sentinel)

        resize_mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)
class ScaleJitter(Transform):
    """Randomly resize the input by jittering a fixed target size.

    A single scale factor ``r`` is drawn uniformly from ``scale_range`` and the
    output size is ``(int(target_size[0] * r), int(target_size[1] * r))``.
    Note the jitter is applied to ``target_size`` and is independent of the
    input's original resolution.

    Args:
        target_size: base output size as ``(height, width)``.
        scale_range: half-open interval ``[low, high)`` the scale factor is
            drawn from (``torch.rand`` excludes the upper bound).
        interpolation: interpolation mode forwarded to ``F.resize``.
    """

    def __init__(
        self,
        target_size: Tuple[int, int],
        scale_range: Tuple[float, float] = (0.1, 2.0),
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
    ) -> None:
        super().__init__()
        self.target_size = target_size
        self.scale_range = scale_range
        self.interpolation = interpolation

    def _get_params(self, sample: Any) -> Dict[str, Any]:
        # Validate that the sample contains an image. The image dimensions
        # themselves do not influence the sampled size (the original code
        # unpacked them into unused locals), so they are not queried here.
        query_image(sample)

        low, high = self.scale_range
        # Uniform sample in [low, high).
        r = low + torch.rand(1) * (high - low)

        new_height = int(self.target_size[0] * r)
        new_width = int(self.target_size[1] * r)
        return dict(size=(new_height, new_width))

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        # Resize the input to the size sampled once per call in _get_params.
        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)