Skip to content

Commit 9dd75c8

Browse files
committed
port tests for transforms.RandomZoomOut
1 parent d7c20dd commit 9dd75c8

File tree

3 files changed

+63
-30
lines changed

3 files changed

+63
-30
lines changed

test/test_transforms_v2.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,6 @@ class TestSmoke:
138138
(transforms.RandomRotation(degrees=30), None),
139139
(transforms.RandomShortestSize(min_size=10, antialias=True), None),
140140
(transforms.RandomVerticalFlip(p=1.0), None),
141-
(transforms.RandomZoomOut(p=1.0), None),
142141
(transforms.Resize([16, 16], antialias=True), None),
143142
(transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2), antialias=True), None),
144143
(transforms.ClampBoundingBoxes(), None),
@@ -390,34 +389,6 @@ def was_applied(output, inpt):
390389
assert transform.was_applied(output, input)
391390

392391

393-
class TestRandomZoomOut:
394-
def test_assertions(self):
395-
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
396-
transforms.RandomZoomOut(fill="abc")
397-
398-
with pytest.raises(TypeError, match="should be a sequence of length"):
399-
transforms.RandomZoomOut(0, side_range=0)
400-
401-
with pytest.raises(ValueError, match="Invalid canvas side range"):
402-
transforms.RandomZoomOut(0, side_range=[4.0, 1.0])
403-
404-
@pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
405-
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
406-
def test__get_params(self, fill, side_range):
407-
transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
408-
409-
h, w = size = (24, 32)
410-
image = make_image(size)
411-
412-
params = transform._get_params([image])
413-
414-
assert len(params["padding"]) == 4
415-
assert 0 <= params["padding"][0] <= (side_range[1] - 1) * w
416-
assert 0 <= params["padding"][1] <= (side_range[1] - 1) * h
417-
assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
418-
assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h
419-
420-
421392
class TestElasticTransform:
422393
def test_assertions(self):
423394

test/test_transforms_v2_refactored.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3945,3 +3945,65 @@ def test_transform_correctness(self, brightness, contrast, saturation, hue):
39453945

39463946
mae = (actual.float() - expected.float()).abs().mean()
39473947
assert mae < 2
3948+
3949+
3950+
class TestRandomZoomOut:
3951+
@pytest.mark.parametrize(
3952+
"make_input",
3953+
[
3954+
make_image_tensor,
3955+
make_image_pil,
3956+
make_image,
3957+
make_bounding_boxes,
3958+
make_segmentation_mask,
3959+
make_detection_mask,
3960+
make_video,
3961+
],
3962+
)
3963+
def test_transform(self, make_input):
3964+
check_transform(transforms.RandomZoomOut(p=1), make_input())
3965+
3966+
def test_transform_error(self):
3967+
for side_range in [None, 1, [1, 2, 3]]:
3968+
with pytest.raises(
3969+
ValueError if isinstance(side_range, list) else TypeError, match="should be a sequence of length 2"
3970+
):
3971+
transforms.RandomZoomOut(side_range=side_range)
3972+
3973+
for side_range in [[0.5, 1.5], [2.0, 1.0]]:
3974+
with pytest.raises(ValueError, match="Invalid side range"):
3975+
transforms.RandomZoomOut(side_range=side_range)
3976+
3977+
@pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
3978+
@pytest.mark.parametrize(
3979+
"make_input",
3980+
[
3981+
make_image_tensor,
3982+
make_image_pil,
3983+
make_image,
3984+
make_bounding_boxes,
3985+
make_segmentation_mask,
3986+
make_detection_mask,
3987+
make_video,
3988+
],
3989+
)
3990+
@pytest.mark.parametrize("device", cpu_and_cuda())
3991+
def test_transform_params_correctness(self, side_range, make_input, device):
3992+
if make_input is make_image_pil and device != "cpu":
3993+
pytest.skip("PIL image tests with parametrization device!='cpu' will degenerate to that anyway.")
3994+
3995+
transform = transforms.RandomZoomOut(side_range=side_range)
3996+
3997+
input = make_input()
3998+
height, width = F.get_size(input)
3999+
4000+
params = transform._get_params([input])
4001+
assert "padding" in params
4002+
4003+
padding = params["padding"]
4004+
assert len(padding) == 4
4005+
4006+
assert 0 <= padding[0] <= (side_range[1] - 1) * width
4007+
assert 0 <= padding[1] <= (side_range[1] - 1) * height
4008+
assert 0 <= padding[2] <= (side_range[1] - 1) * width
4009+
assert 0 <= padding[3] <= (side_range[1] - 1) * height

torchvision/transforms/v2/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def __init__(
546546

547547
self.side_range = side_range
548548
if side_range[0] < 1.0 or side_range[0] > side_range[1]:
549-
raise ValueError(f"Invalid canvas side range provided {side_range}.")
549+
raise ValueError(f"Invalid side range provided {side_range}.")
550550

551551
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
552552
orig_h, orig_w = query_size(flat_inputs)

0 commit comments

Comments
 (0)