diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index aa11982a2f3..de37f8570de 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -136,7 +136,6 @@ class TestSmoke: (transforms.RandomRotation(degrees=30), None), (transforms.RandomShortestSize(min_size=10, antialias=True), None), (transforms.RandomVerticalFlip(p=1.0), None), - (transforms.RandomZoomOut(p=1.0), None), (transforms.Resize([16, 16], antialias=True), None), (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None), (transforms.ClampBoundingBoxes(), None), @@ -388,34 +387,6 @@ def was_applied(output, inpt): assert transform.was_applied(output, input) -class TestRandomZoomOut: - def test_assertions(self): - with pytest.raises(TypeError, match="Got inappropriate fill arg"): - transforms.RandomZoomOut(fill="abc") - - with pytest.raises(TypeError, match="should be a sequence of length"): - transforms.RandomZoomOut(0, side_range=0) - - with pytest.raises(ValueError, match="Invalid canvas side range"): - transforms.RandomZoomOut(0, side_range=[4.0, 1.0]) - - @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)]) - @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) - def test__get_params(self, fill, side_range): - transform = transforms.RandomZoomOut(fill=fill, side_range=side_range) - - h, w = size = (24, 32) - image = make_image(size) - - params = transform._get_params([image]) - - assert len(params["padding"]) == 4 - assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w - assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h - assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w - assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h - - class TestElasticTransform: def test_assertions(self): diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index 55423d359dd..0748bdd343a 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -4000,3 +4000,67 @@ def test_random_transform_correctness(self, num_input_channels): expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_input_channels)) assert_equal(actual, expected, rtol=0, atol=1) + + +class TestRandomZoomOut: + # Tests are light because this largely relies on the already tested `pad` kernels. + + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_detection_mask, + make_video, + ], + ) + def test_transform(self, make_input): + check_transform(transforms.RandomZoomOut(p=1), make_input()) + + def test_transform_error(self): + for side_range in [None, 1, [1, 2, 3]]: + with pytest.raises( + ValueError if isinstance(side_range, list) else TypeError, match="should be a sequence of length 2" + ): + transforms.RandomZoomOut(side_range=side_range) + + for side_range in [[0.5, 1.5], [2.0, 1.0]]: + with pytest.raises(ValueError, match="Invalid side range"): + transforms.RandomZoomOut(side_range=side_range) + + @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]]) + @pytest.mark.parametrize( + "make_input", + [ + make_image_tensor, + make_image_pil, + make_image, + make_bounding_boxes, + make_segmentation_mask, + make_detection_mask, + make_video, + ], + ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_transform_params_correctness(self, side_range, make_input, device): + if make_input is make_image_pil and device != "cpu": + pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.") + + transform = transforms.RandomZoomOut(side_range=side_range) + + input = make_input() + height, width = F.get_size(input) + + params = transform._get_params([input]) + assert "padding" in params + + padding = params["padding"] + assert len(padding) == 4 + + assert 0 <= padding[0] <= (side_range[1] - 1) * width + assert 0 <= padding[1] <= (side_range[1] - 1) * height + assert 0 <= padding[2] <= (side_range[1] - 1) * width + assert 0 <= padding[3] <= (side_range[1] - 1) * height diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index ce98a1ee091..e184f8085e4 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -546,7 +546,7 @@ def __init__( self.side_range = side_range if side_range[0] < 1.0 or side_range[0] > side_range[1]: - raise ValueError(f"Invalid canvas side range provided {side_range}.") + raise ValueError(f"Invalid side range provided {side_range}.") def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: orig_h, orig_w = query_size(flat_inputs)