From 7d460af91be6c0bc065bd70b81caeb9ced0ca489 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 9 May 2022 09:46:15 +0000
Subject: [PATCH 1/4] [proto] Added `center_crop_bounding_box` functional op

---
 test/test_prototype_transforms_functional.py | 64 ++++++++++++++++++-
 .../transforms/functional/__init__.py        |  1 +
 .../transforms/functional/_geometry.py       | 11 ++++
 3 files changed, 75 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index dac43717d30..aeca7bb9295 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -93,7 +93,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         cx = torch.randint(1, width - 1, ())
         cy = torch.randint(1, height - 1, ())
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-        h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
         parts = (cx, cy, w, h)
     else:
         raise pytest.UsageError()
@@ -380,6 +380,14 @@ def pad_segmentation_mask():
             yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_bounding_box():
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), (16, 18)]):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1085,3 +1093,57 @@ def parse_padding():
 
         expected_mask = _compute_expected_mask()
         torch.testing.assert_close(out_mask, expected_mask)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "output_size",
+    [(18, 18), (18, 15), (16, 19)],
+)
+def test_correctness_center_crop_bounding_box(device, output_size):
+    def _compute_expected_bbox(bbox, output_size_):
+        format_ = bbox.format
+        image_size_ = bbox.image_size
+        bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
+
+        cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
+        cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
+        out_bbox = [
+            bbox[0].item() - cx,
+            bbox[1].item() - cy,
+            bbox[2].item(),
+            bbox[3].item(),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYWH,
+            image_size=output_size_,
+            dtype=bbox.dtype,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[(32, 32), (24, 33), (32, 25)],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        expected_bboxes = expected_bboxes.to(device=device)
+        torch.testing.assert_close(output_boxes, expected_bboxes)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index c13a94035ea..4d400bdc2c7 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -45,6 +45,7 @@
     resize_image_tensor,
     resize_image_pil,
     resize_segmentation_mask,
+    center_crop_bounding_box,
     center_crop_image_tensor,
     center_crop_image_pil,
     resized_crop_bounding_box,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 602f865f724..23f5f7d03ae 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -530,6 +530,17 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
 
 
+def center_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    output_size: List[int],
+    image_size: Tuple[int, int],
+):
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
+    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
+
+
 def resized_crop_image_tensor(
     img: torch.Tensor,
     top: int,

From ceb2cd92011282f2534d6e1f3cfea78284c80031 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 9 May 2022 09:59:35 +0000
Subject: [PATCH 2/4] Fixed mypy issue

---
 torchvision/prototype/transforms/functional/_geometry.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 23f5f7d03ae..51961520ce9 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -535,7 +535,7 @@ def center_crop_bounding_box(
     format: features.BoundingBoxFormat,
     output_size: List[int],
     image_size: Tuple[int, int],
-):
+) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
     crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
     return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)

From 291d18104d1e9568fe0f08fbc35bb3e1900c08ce Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 9 May 2022 10:12:29 +0000
Subject: [PATCH 3/4] Added one more test case

---
 test/test_prototype_transforms_functional.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index aeca7bb9295..5f4e485ec32 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -1098,7 +1098,7 @@ def parse_padding():
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "output_size",
-    [(18, 18), (18, 15), (16, 19)],
+    [(18, 18), [18, 15], (16, 19), [12]],
 )
 def test_correctness_center_crop_bounding_box(device, output_size):
     def _compute_expected_bbox(bbox, output_size_):
@@ -1106,6 +1106,9 @@ def _compute_expected_bbox(bbox, output_size_):
         image_size_ = bbox.image_size
         bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
 
+        if len(output_size_) == 1:
+            output_size_.append(output_size_[-1])
+
         cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
         cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
         out_bbox = [

From 28c380dc3261afe15bb6ceb74bd5fa71d80ff141 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 9 May 2022 10:20:15 +0000
Subject: [PATCH 4/4] More test cases

---
 test/test_prototype_transforms_functional.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 5f4e485ec32..23c9d2a3fd7 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -382,7 +382,7 @@ def pad_segmentation_mask():
 
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_bounding_box():
-    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), (16, 18)]):
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
         yield SampleInput(
             bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
         )
@@ -1098,7 +1098,7 @@ def parse_padding():
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
     "output_size",
-    [(18, 18), [18, 15], (16, 19), [12]],
+    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
 )
 def test_correctness_center_crop_bounding_box(device, output_size):
     def _compute_expected_bbox(bbox, output_size_):
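
---

The series gives bounding boxes a direct analogue of the image center-crop kernels: the crop anchor is derived from `image_size` and `output_size`, and the boxes are shifted by that anchor through `crop_bounding_box`. Below is a minimal usage sketch, assuming the prototype `features.BoundingBox` constructor as it existed at the time of these patches; the box values and sizes are illustrative only.

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# One XYXY box on a 32x32 image (illustrative values, not from the tests above).
bbox = features.BoundingBox(
    [10.0, 12.0, 24.0, 28.0],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)

# Center-cropping (32, 32) down to (18, 18) removes round((32 - 18) / 2) = 7
# pixels on each side, so the kernel shifts every coordinate by -7 while the
# box extent stays the same: [10, 12, 24, 28] -> [3, 5, 17, 21].
out = F.center_crop_bounding_box(bbox, bbox.format, output_size=[18, 18], image_size=bbox.image_size)

Note that the kernel only translates coordinates; it does not clip boxes that fall partly or fully outside the crop window, which matches the behavior of `crop_bounding_box` it delegates to.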