From c4123fac992bd92de23c03145f777ba422f11f19 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Fri, 22 Apr 2022 22:54:12 +0200 Subject: [PATCH 1/6] feat: add functional pad on segmentation mask --- test/test_prototype_transforms_functional.py | 35 +++++++++++++++++++ .../transforms/functional/__init__.py | 1 + .../transforms/functional/_geometry.py | 6 ++++ 3 files changed, 42 insertions(+) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 36d1677ede5..6d73ad67ff6 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -369,6 +369,31 @@ def resized_crop_segmentation_mask(): ): yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size) +@register_kernel_info_from_sample_inputs_fn +def pad_segmentation_mask(): + for mask, padding, fill, padding_mode in itertools.product( + make_segmentation_masks(), + [[1], [1, 1], [1, 1, 2, 2]], # padding + [0, 1], # fill + ["constant", "symmetric", "edge"], # padding mode, + ): + if padding_mode == "symmetric" and mask.ndim not in [3, 4]: + continue + if padding_mode == "edge" and fill != 0: + continue + if ( + padding_mode == "edge" + and len(padding) == 2 + and mask.ndim not in [2, 3] + or len(padding) == 4 + and mask.ndim not in [4, 3] + or len(padding) == 1 + ): + continue + if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]: + continue + yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode) + @pytest.mark.parametrize( "kernel", @@ -1031,3 +1056,13 @@ def _compute_expected(mask, top_, left_, height_, width_, size_): expected_mask = _compute_expected(in_mask, top, left, height, width, size) output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size) torch.testing.assert_close(output_mask, expected_mask) + +def test_correctness_pad_segmentation_mask_on_fixed_input(device): + mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) + mask[:, 1, 1] = 0 + + out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1) + + expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device) + expected_mask[:, 2, 2] = 0 + torch.testing.assert_close(out_mask, expected_mask) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index dfbc81baea3..c13a94035ea 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -62,6 +62,7 @@ pad_bounding_box, pad_image_tensor, pad_image_pil, + pad_segmentation_mask, crop_bounding_box, crop_image_tensor, crop_image_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 5f9e77fdbf4..c372fc2f612 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -396,6 +396,12 @@ def rotate_segmentation_mask( pad_image_pil = _FP.pad +def pad_segmentation_mask( + segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant" +) -> torch.Tensor: + return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode) + + def pad_bounding_box( bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat ) -> torch.Tensor: From 0adce7a964921b2daa77901e28960197238d91b3 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Sun, 24 Apr 2022 13:56:40 +0200 Subject: [PATCH 2/6] test: add basic correctness test with random masks --- test/test_prototype_transforms_functional.py | 40 ++++++++++++++----- .../transforms/functional/_geometry.py | 4 +- 2 files changed, 33 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6d73ad67ff6..b696e47427d 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -371,16 +371,13 @@ def resized_crop_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn def pad_segmentation_mask(): - for mask, padding, fill, padding_mode in itertools.product( + for mask, padding, padding_mode in itertools.product( make_segmentation_masks(), [[1], [1, 1], [1, 1, 2, 2]], # padding - [0, 1], # fill ["constant", "symmetric", "edge"], # padding mode, ): if padding_mode == "symmetric" and mask.ndim not in [3, 4]: continue - if padding_mode == "edge" and fill != 0: - continue if ( padding_mode == "edge" and len(padding) == 2 @@ -392,7 +389,7 @@ def pad_segmentation_mask(): continue if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]: continue - yield SampleInput(mask, padding=padding, fill=fill, padding_mode=padding_mode) + yield SampleInput(mask, padding=padding, padding_mode=padding_mode) @pytest.mark.parametrize( @@ -1059,10 +1056,35 @@ def _compute_expected(mask, top_, left_, height_, width_, size_): def test_correctness_pad_segmentation_mask_on_fixed_input(device): mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) - mask[:, 1, 1] = 0 - out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1], fill=1) + out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1]) - expected_mask = torch.ones((1, 3 + 1 + 1, 3 + 1 + 1), dtype=torch.long, device=device) - expected_mask[:, 2, 2] = 0 + expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device) + expected_mask[:, 1:-1, 1:-1] = 1 torch.testing.assert_close(out_mask, expected_mask) + + +@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")]) +def test_correctness_pad_segmentation_mask(padding, padding_mode): + def compute_expected_mask(): + h, w = mask.shape[-2], mask.shape[-1] + + pad_left = padding[0] + pad_up = padding[1] + pad_right = padding[2] + pad_down = padding[3] + + new_h = h + pad_up + pad_down + new_w = w + pad_left + pad_right + + new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w) + expected_mask = torch.zeros(new_shape, dtype=torch.long) + expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask + + return expected_mask + + for mask in make_segmentation_masks(): + out_mask = F.pad_segmentation_mask(mask, padding, padding_mode) + + expected_mask = compute_expected_mask() + torch.testing.assert_close(out_mask, expected_mask) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c372fc2f612..2d611c78fcc 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -397,9 +397,9 @@ def rotate_segmentation_mask( def pad_segmentation_mask( - segmentation_mask: torch.Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant" + segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant" ) -> torch.Tensor: - return pad_image_tensor(img=segmentation_mask, padding=padding, fill=fill, padding_mode=padding_mode) + return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode) def pad_bounding_box( From df4f02e8505d0a912f3e042fe8ab6a84a319bfba Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Mon, 25 Apr 2022 00:15:51 +0200 Subject: [PATCH 3/6] test: add all padding options --- test/test_prototype_transforms_functional.py | 33 +++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index b696e47427d..213b81dd62f 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -374,7 +374,7 @@ def pad_segmentation_mask(): for mask, padding, padding_mode in itertools.product( make_segmentation_masks(), [[1], [1, 1], [1, 1, 2, 2]], # padding - ["constant", "symmetric", "edge"], # padding mode, + ["constant", "symmetric", "edge", "reflect"], # padding mode, ): if padding_mode == "symmetric" and mask.ndim not in [3, 4]: continue @@ -1064,15 +1064,24 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("padding,padding_mode", [([1, 2, 3, 4], "constant")]) -def test_correctness_pad_segmentation_mask(padding, padding_mode): - def compute_expected_mask(): - h, w = mask.shape[-2], mask.shape[-1] +@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]]) +def test_correctness_pad_segmentation_mask(padding): + def _parse_padding(): + if isinstance(padding, int): + return [padding] * 4 + if isinstance(padding, float): + return [int(padding)] * 4 + if isinstance(padding, list): + if len(padding) == 1: + return padding * 4 + if len(padding) == 2: + return padding * 2 # [left, up, right, down] + + return padding - pad_left = padding[0] - pad_up = padding[1] - pad_right = padding[2] - pad_down = padding[3] + def _compute_expected_mask(padding): + h, w = mask.shape[-2], mask.shape[-1] + pad_left, pad_up, pad_right, pad_down = padding new_h = h + pad_up + pad_down new_w = w + pad_left + pad_right @@ -1083,8 +1092,10 @@ def compute_expected_mask(): return expected_mask + padding = _parse_padding() + for mask in make_segmentation_masks(): - out_mask = F.pad_segmentation_mask(mask, padding, padding_mode) + out_mask = F.pad_segmentation_mask(mask, padding, "constant") - expected_mask = compute_expected_mask() + expected_mask = _compute_expected_mask(padding) torch.testing.assert_close(out_mask, expected_mask) From fe11da00f88dad1c97b2bf7a894790ba9c34052b Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Mon, 25 Apr 2022 19:44:53 +0200 Subject: [PATCH 4/6] fix: pr comments --- test/test_prototype_transforms_functional.py | 51 ++++++++------------ 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 213b81dd62f..7ca43ce9ad2 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -369,6 +369,7 @@ def resized_crop_segmentation_mask(): ): yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size) + @register_kernel_info_from_sample_inputs_fn def pad_segmentation_mask(): for mask, padding, padding_mode in itertools.product( @@ -376,19 +377,12 @@ def pad_segmentation_mask(): [[1], [1, 1], [1, 1, 2, 2]], # padding ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - if padding_mode == "symmetric" and mask.ndim not in [3, 4]: - continue - if ( - padding_mode == "edge" - and len(padding) == 2 - and mask.ndim not in [2, 3] - or len(padding) == 4 - and mask.ndim not in [4, 3] - or len(padding) == 1 - ): + if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]: continue - if padding_mode == "edge" and mask.ndim not in [2, 3, 4, 5]: + + if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]: continue + yield SampleInput(mask, padding=padding, padding_mode=padding_mode) @@ -1054,6 +1048,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_): output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size) torch.testing.assert_close(output_mask, expected_mask) + def test_correctness_pad_segmentation_mask_on_fixed_input(device): mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) @@ -1064,24 +1059,22 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device): torch.testing.assert_close(out_mask, expected_mask) -@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, 1.0, [1, 2]]) +@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]]) def test_correctness_pad_segmentation_mask(padding): - def _parse_padding(): - if isinstance(padding, int): - return [padding] * 4 - if isinstance(padding, float): - return [int(padding)] * 4 - if isinstance(padding, list): - if len(padding) == 1: - return padding * 4 - if len(padding) == 2: - return padding * 2 # [left, up, right, down] - - return padding - - def _compute_expected_mask(padding): + def _compute_expected_mask(): + def parse_padding(): + if isinstance(padding, int): + return [padding] * 4 + if isinstance(padding, list): + if len(padding) == 1: + return padding * 4 + if len(padding) == 2: + return padding * 2 # [left, up, right, down] + + return padding + h, w = mask.shape[-2], mask.shape[-1] - pad_left, pad_up, pad_right, pad_down = padding + pad_left, pad_up, pad_right, pad_down = parse_padding() new_h = h + pad_up + pad_down new_w = w + pad_left + pad_right @@ -1092,10 +1085,8 @@ def _compute_expected_mask(padding): return expected_mask - padding = _parse_padding() - for mask in make_segmentation_masks(): out_mask = F.pad_segmentation_mask(mask, padding, "constant") - expected_mask = _compute_expected_mask(padding) + expected_mask = _compute_expected_mask() torch.testing.assert_close(out_mask, expected_mask) From 846ba4baf616969e8f4c4b186b1aa9adb0a064e5 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Wed, 27 Apr 2022 00:12:30 +0200 Subject: [PATCH 5/6] fix: tests --- test/test_prototype_transforms_functional.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 7ca43ce9ad2..f5fc31660e6 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -377,10 +377,10 @@ def pad_segmentation_mask(): [[1], [1, 1], [1, 1, 2, 2]], # padding ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - if padding_mode == "symmetric" and mask.ndim not in [2, 3, 4]: + if padding_mode == "symmetric" and mask.ndim not in [3, 4]: continue - if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [2, 3, 4]: + if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [3, 4]: continue yield SampleInput(mask, padding=padding, padding_mode=padding_mode) @@ -1049,6 +1049,7 @@ def _compute_expected(mask, top_, left_, height_, width_, size_): torch.testing.assert_close(output_mask, expected_mask) +@pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_pad_segmentation_mask_on_fixed_input(device): mask = torch.ones((1, 3, 3), dtype=torch.long, device=device) From 5b0d597960f4e605fdddcd57a1988bdf7446cdb4 Mon Sep 17 00:00:00 2001 From: Federico Pozzi Date: Thu, 28 Apr 2022 20:34:34 +0200 Subject: [PATCH 6/6] refactor: reshape tensor in 4d, then pad --- test/test_prototype_transforms_functional.py | 6 ------ .../prototype/transforms/functional/_geometry.py | 10 +++++++++- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index f5fc31660e6..dac43717d30 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -377,12 +377,6 @@ def pad_segmentation_mask(): [[1], [1, 1], [1, 1, 2, 2]], # padding ["constant", "symmetric", "edge", "reflect"], # padding mode, ): - if padding_mode == "symmetric" and mask.ndim not in [3, 4]: - continue - - if (padding_mode == "edge" or padding_mode == "reflect") and mask.ndim not in [3, 4]: - continue - yield SampleInput(mask, padding=padding, padding_mode=padding_mode) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 2d611c78fcc..602f865f724 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -399,7 +399,15 @@ def rotate_segmentation_mask( def pad_segmentation_mask( segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant" ) -> torch.Tensor: - return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode) + num_masks, height, width = segmentation_mask.shape[-3:] + extra_dims = segmentation_mask.shape[:-3] + + padded_mask = pad_image_tensor( + img=segmentation_mask.view(-1, num_masks, height, width), padding=padding, fill=0, padding_mode=padding_mode + ) + + new_height, new_width = padded_mask.shape[-2:] + return padded_mask.view(extra_dims + (num_masks, new_height, new_width)) def pad_bounding_box(