From bf21a257bdcdf4fae757084199f59b47d7495f27 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 10 Aug 2022 17:03:18 +0200 Subject: [PATCH 01/12] [proto] Ported RandomIoUCrop from detection refs --- test/test_prototype_transforms.py | 10 ++ torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/_geometry.py | 95 ++++++++++++++++++- torchvision/prototype/transforms/_utils.py | 9 ++ .../transforms/functional/__init__.py | 1 + .../prototype/transforms/functional/_meta.py | 7 ++ 6 files changed, 122 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 76c0b7853c9..cb555420a09 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1102,3 +1102,13 @@ def test_ctor(self, trfms): inpt = torch.rand(1, 3, 32, 32) output = c(inpt) assert isinstance(output, torch.Tensor) + + +class TestRandomIoUCrop: + def test__get_params(self): + # TODO: + pass + + def test__transform(self): + # TODO: + pass diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5617c010e5f..8ff2cbf013a 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -24,6 +24,7 @@ RandomAffine, RandomCrop, RandomHorizontalFlip, + RandomIoUCrop, RandomPerspective, RandomResizedCrop, RandomRotation, diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index aa1ca109cc4..2cf2fa0f684 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,14 +6,16 @@ import PIL.Image import torch +from torchvision.ops.boxes import box_iou from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform from torchvision.transforms.functional import InterpolationMode, pil_to_tensor from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size + from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image +from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bboxes, query_image class RandomHorizontalFlip(_RandomApplyTransform): @@ -611,3 +613,94 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, interpolation=self.interpolation, ) + + +class RandomIoUCrop(Transform): + def __init__( + self, + min_scale: float = 0.3, + max_scale: float = 1.0, + min_aspect_ratio: float = 0.5, + max_aspect_ratio: float = 2.0, + sampler_options: Optional[List[float]] = None, + trials: int = 40, + ): + super().__init__() + # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174 + self.min_scale = min_scale + self.max_scale = max_scale + self.min_aspect_ratio = min_aspect_ratio + self.max_aspect_ratio = max_aspect_ratio + if sampler_options is None: + sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] + self.options = sampler_options + self.trials = trials + + def _get_params(self, sample: Any) -> Dict[str, Any]: + + image = query_image(sample) + _, orig_h, orig_w = get_image_dimensions(image) + bboxes = query_bboxes(sample) + + while True: + # sample an option + idx = int(torch.randint(low=0, high=len(self.options), size=(1,))) + min_jaccard_overlap = self.options[idx] + if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is 
option + return dict() + + for _ in range(self.trials): + # check the aspect ratio limitations + r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2) + new_w = int(orig_w * r[0]) + new_h = int(orig_h * r[1]) + aspect_ratio = new_w / new_h + if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio): + continue + + # check for 0 area crops + r = torch.rand(2) + left = int((orig_w - new_w) * r[0]) + top = int((orig_h - new_h) * r[1]) + right = left + new_w + bottom = top + new_h + if left == right or top == bottom: + continue + + # check for any valid boxes with centers within the crop area + xyxy_bboxes = F.convert_bounding_box_format( + bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True + ) + cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) + cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) + is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom) + if not is_within_crop_area.any(): + continue + + # check at least 1 box with jaccard limitations + xyxy_bboxes = xyxy_bboxes[is_within_crop_area] + ious = box_iou( + xyxy_bboxes, + torch.tensor([[left, top, right, bottom]], dtype=xyxy_bboxes.dtype, device=xyxy_bboxes.device), + ) + if ious.max() < min_jaccard_overlap: + continue + + return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if len(params) < 1: + return inpt + + is_within_crop_area = params["is_within_crop_area"] + if isinstance(inpt, features.Label): + return features.Label.new_like(inpt, inpt[is_within_crop_area]) + + output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + + if isinstance(output, features.BoundingBox): + bboxes = output[is_within_crop_area] + bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size) + output = features.BoundingBox.new_like(output, bboxes) + + return output diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 9f2ef84ced5..7a7430a1d69 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -17,6 +17,15 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im raise TypeError("No image was found in the sample") +def query_bboxes(sample: Any) -> features.BoundingBox: + flat_sample, _ = tree_flatten(sample) + for i in flat_sample: + if isinstance(i, features.BoundingBox): + return i + + raise TypeError("No bounding boxes were found in the sample") + + def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: if isinstance(image, features.Image): channels = image.num_channels diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 958f9103e06..bdd183ca24e 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,6 @@ from torchvision.transforms import InterpolationMode # usort: skip from ._meta import ( + clamp_bounding_box, convert_bounding_box_format, convert_image_color_space_tensor, convert_image_color_space_pil, diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index db7918558bc..10a6c123c5f 100644 --- a/torchvision/prototype/transforms/functional/_meta.py 
+++ b/torchvision/prototype/transforms/functional/_meta.py @@ -61,6 +61,13 @@ def convert_bounding_box_format( return bounding_box +def clamp_bounding_box(bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]): + xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY) + xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1]) + xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0]) + return convert_bounding_box_format(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False) + + def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: return image[..., :-1, :, :], image[..., -1:, :, :] From ed229cd043e5a156f825fff14e460625bb16ddf2 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 12 Aug 2022 13:20:04 +0000 Subject: [PATCH 02/12] Scope acceptable data types --- torchvision/prototype/transforms/_geometry.py | 10 +++++++++- torchvision/prototype/transforms/functional/_meta.py | 4 +++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 2cf2fa0f684..aaf68e812b8 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -15,7 +15,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_bboxes, query_image +from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bboxes, query_image class RandomHorizontalFlip(_RandomApplyTransform): @@ -704,3 +704,11 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: output = features.BoundingBox.new_like(output, bboxes) return output + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if not has_all(sample, features.Image, features.BoundingBox, features.Label): + raise TypeError(f"{type(self).__name__}() is only defined for Images, BoundingBoxes and Labels.") + if has_any(sample, features.OneHotLabel): + raise TypeError(f"{type(self).__name__}() does not support OneHotLabels.") + return super().forward(*inputs) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 10a6c123c5f..4fb477ff139 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -61,7 +61,9 @@ def convert_bounding_box_format( return bounding_box -def clamp_bounding_box(bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int]): +def clamp_bounding_box( + bounding_box: torch.Tensor, format: BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: xyxy_boxes = convert_bounding_box_format(bounding_box, format, BoundingBoxFormat.XYXY) xyxy_boxes[..., 0::2].clamp_(min=0, max=image_size[1]) xyxy_boxes[..., 1::2].clamp_(min=0, max=image_size[0]) From 5c2275e7405e82d3b279653372194f50310b5293 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 15 Aug 2022 12:59:03 +0200 Subject: [PATCH 03/12] Added get_params test --- test/test_prototype_transforms.py | 51 ++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 51d92f9f22a..abd49853f20 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -6,7 +6,7 @@ import pytest import torch -from common_utils import 
assert_equal +from common_utils import assert_equal, cpu_and_gpu from test_prototype_transforms_functional import ( make_bounding_box, make_bounding_boxes, @@ -15,6 +15,7 @@ make_one_hot_labels, make_segmentation_mask, ) +from torchvision.ops.boxes import box_iou from torchvision.prototype import features, transforms from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image @@ -1128,9 +1129,51 @@ def test_ctor(self, trfms): class TestRandomIoUCrop: - def test__get_params(self): - # TODO: - pass + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("options", [[0.5, 0.9], [2.0]]) + def test__get_params(self, device, options, mocker): + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + bboxes = features.BoundingBox( + torch.tensor( + [ + [1, 1, 10, 10], + [20, 20, 23, 23], + [1, 20, 10, 23], + [20, 1, 23, 10], + ] + ), + format="XYXY", + image_size=image.image_size, + device=device, + ) + sample = [image, bboxes] + + transform = transforms.RandomIoUCrop(sampler_options=options) + + n_samples = 5 + for _ in range(n_samples): + + params = transform._get_params(sample) + + if options == [2.0]: + assert len(params) == 0 + return + + assert len(params["is_within_crop_area"]) > 0 + orig_h = image.image_size[0] + orig_w = image.image_size[1] + assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) + assert int(transform.min_scale * orig_w) <= params["width"] <= int(transform.max_scale * orig_w) + + left, top = params["left"], params["top"] + new_h, new_w = params["height"], params["width"] + ious = box_iou( + bboxes, + torch.tensor([[left, top, left + new_w, top + new_h]], dtype=bboxes.dtype, device=bboxes.device), + ) + assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}" def test__transform(self): # TODO: From 84d2f093ddef083233de78ce5fbda8e3190c8b47 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 15 Aug 2022 13:36:51 +0200 Subject: [PATCH 04/12] Added test__transform_empty_params --- test/test_prototype_transforms.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index abd49853f20..cfd5b250216 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1136,14 +1136,7 @@ def test__get_params(self, device, options, mocker): image.num_channels = 3 image.image_size = (24, 32) bboxes = features.BoundingBox( - torch.tensor( - [ - [1, 1, 10, 10], - [20, 20, 23, 23], - [1, 20, 10, 23], - [20, 1, 23, 10], - ] - ), + torch.tensor([[1, 1, 10, 10], [20, 20, 23, 23], [1, 20, 10, 23], [20, 1, 23, 10]]), format="XYXY", image_size=image.image_size, device=device, @@ -1175,6 +1168,17 @@ def test__get_params(self, device, options, mocker): ) assert ious.max() >= options[0] or ious.max() >= options[1], f"{ious} vs {options}" + def test__transform_empty_params(self, mocker): + transform = transforms.RandomIoUCrop(sampler_options=[2.0]) + image = features.Image(torch.rand(1, 3, 4, 4)) + bboxes = features.BoundingBox(torch.tensor([[1, 1, 2, 2]]), format="XYXY", image_size=(4, 4)) + label = features.Label(torch.tensor([1])) + sample = [image, bboxes, label] + # Let's mock transform._get_params to control the output: + transform._get_params = mocker.MagicMock(return_value={}) + output = transform(sample) + torch.testing.assert_close(output, sample) + def test__transform(self): # TODO: pass 
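Usage sketch (illustrative only, not part of the patch series): the snippet below shows how the RandomIoUCrop transform added above might be exercised with the prototype datapoints. It assumes the constructors already used in the tests (features.Image, features.BoundingBox, features.Label) and that the prototype Transform base class returns the sample with its original structure; exact behavior depends on the prototype API at this point in the series.

import torch
from torchvision.prototype import features, transforms

# Build a sample of the kind RandomIoUCrop.forward() accepts:
# an image, bounding boxes and labels.
image = features.Image(torch.rand(3, 32, 32))
bboxes = features.BoundingBox(
    torch.tensor([[1, 1, 10, 10], [15, 15, 30, 30]]),
    format="XYXY",
    image_size=(32, 32),
)
labels = features.Label(torch.tensor([1, 2]))

transform = transforms.RandomIoUCrop()
# The transform samples a crop whose IoU with at least one box meets the
# sampled threshold, crops the image and boxes, clamps the boxes to the
# crop, and drops boxes (and their labels) whose centers fall outside it.
out_image, out_bboxes, out_labels = transform([image, bboxes, labels])

Note that with the default sampler options the "leave as-is" option (a value larger than 1) may be drawn, in which case _get_params() returns an empty dict and the sample is passed through unchanged.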
From 3085dbf9ea584b31bb7ad5cb065e74369c386d6a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 17 Aug 2022 13:02:12 +0000 Subject: [PATCH 05/12] Added support for OneHotLabel and tests --- test/test_prototype_transforms.py | 14 ++++++++++++-- torchvision/prototype/transforms/_geometry.py | 9 +++++++-- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 0bf754b9476..04c6ac2d2b5 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1190,13 +1190,14 @@ def test__transform(self, mocker): image = features.Image(torch.rand(3, 32, 24)) bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,)) label = features.Label(torch.randint(0, 10, size=(6,))) - sample = [image, bboxes, label] + ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) + sample = [image, bboxes, label, ohe_label] fn = mocker.patch("torchvision.prototype.transforms.functional.crop") is_within_crop_area = torch.randint(0, 2, size=(6,)) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) transform._get_params = mocker.MagicMock(return_value=params) - _ = transform(sample) + output = transform(sample) assert fn.call_count == 2 # asserts the last call @@ -1208,6 +1209,15 @@ def test__transform(self, mocker): image, top=params["top"], left=params["left"], height=params["height"], width=params["width"] ) + # check labels + output_label = output[-2] + assert isinstance(output_label, features.Label) + torch.testing.assert_close(output_label, label[is_within_crop_area]) + + output_ohe_label = output[-1] + assert isinstance(output_ohe_label, features.OneHotLabel) + torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area]) + class TestScaleJitter: def test__get_params(self, mocker): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 33b04f026d4..8e22aee765b 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -703,6 +703,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features.Label): return features.Label.new_like(inpt, inpt[is_within_crop_area]) + if isinstance(inpt, features.OneHotLabel): + return features.OneHotLabel.new_like(inpt, inpt[is_within_crop_area, :]) + output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) if isinstance(output, features.BoundingBox): @@ -715,11 +718,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] if not ( - has_all(sample, features.Image, features.BoundingBox) + has_all(sample, features.BoundingBox) + and has_any(sample, PIL.Image.Image, features.Image) and has_any(sample, features.Label, features.OneHotLabel) ): raise TypeError( - f"{type(self).__name__}() is only defined for Images, BoundingBoxes and Labels or OneHotLabels." + f"{type(self).__name__}() is only defined for Images, PIL Images, " + "BoundingBoxes and Labels or OneHotLabels." 
) return super().forward(*inputs) From abf4381ca97f161aa978a91d62eb4396f1947eaa Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 17 Aug 2022 13:26:09 +0000 Subject: [PATCH 06/12] Added tests for mask --- test/test_prototype_transforms.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 04c6ac2d2b5..5768fe0162c 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1181,7 +1181,9 @@ def test__transform_empty_params(self, mocker): def test_forward_assertion(self): transform = transforms.RandomIoUCrop() - with pytest.raises(TypeError, match="only defined for Images, BoundingBoxes and Labels or OneHotLabels"): + with pytest.raises( + TypeError, match="only defined for Images, PIL Images, BoundingBoxes and Labels or OneHotLabels" + ): transform(torch.tensor(0)) def test__transform(self, mocker): @@ -1191,7 +1193,8 @@ def test__transform(self, mocker): bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,)) label = features.Label(torch.randint(0, 10, size=(6,))) ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) - sample = [image, bboxes, label, ohe_label] + masks = make_segmentation_mask((32, 24)) + sample = [image, bboxes, label, ohe_label, masks] fn = mocker.patch("torchvision.prototype.transforms.functional.crop") is_within_crop_area = torch.randint(0, 2, size=(6,)) @@ -1199,22 +1202,22 @@ def test__transform(self, mocker): transform._get_params = mocker.MagicMock(return_value=params) output = transform(sample) - assert fn.call_count == 2 - # asserts the last call - fn.assert_called_with( - bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"] - ) + assert fn.call_count == 3 - fn.assert_any_call( - image, top=params["top"], left=params["left"], height=params["height"], width=params["width"] - ) + expected_calls = [ + mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), + mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), + mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), + ] + + fn.assert_has_calls(expected_calls) # check labels - output_label = output[-2] + output_label = output[2] assert isinstance(output_label, features.Label) torch.testing.assert_close(output_label, label[is_within_crop_area]) - output_ohe_label = output[-1] + output_ohe_label = output[3] assert isinstance(output_ohe_label, features.OneHotLabel) torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area]) From 62316084d5788e43ea973e68eab602336cc532d4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 17 Aug 2022 14:24:08 +0000 Subject: [PATCH 07/12] Updated error message --- torchvision/prototype/transforms/_geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 8e22aee765b..0ddd38a2751 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -723,8 +723,8 @@ def forward(self, *inputs: Any) -> Any: and has_any(sample, features.Label, features.OneHotLabel) ): raise TypeError( - f"{type(self).__name__}() is only defined for Images, PIL Images, " - "BoundingBoxes and Labels or OneHotLabels." 
+ f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " + "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks." ) return super().forward(*inputs) From 0b61852f6789273931474e5af67e4ddd7e29c9d7 Mon Sep 17 00:00:00 2001 From: vfdev Date: Wed, 17 Aug 2022 16:26:45 +0200 Subject: [PATCH 08/12] Apply suggestions from code review Co-authored-by: Philip Meier --- torchvision/prototype/transforms/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7a7430a1d69..4cfe1da3649 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -17,13 +17,13 @@ def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Im raise TypeError("No image was found in the sample") -def query_bboxes(sample: Any) -> features.BoundingBox: +def query_bounding_box(sample: Any) -> features.BoundingBox: flat_sample, _ = tree_flatten(sample) for i in flat_sample: if isinstance(i, features.BoundingBox): return i - raise TypeError("No bounding boxes were found in the sample") + raise TypeError("No bounding box was found in the sample") def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: From 419ba8f360e707a56852c0d07a5adc61f9edab65 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 17 Aug 2022 17:16:18 +0000 Subject: [PATCH 09/12] Added support for OHE masks and tests --- test/test_prototype_transforms.py | 34 ++++++++++++++++--- torchvision/prototype/transforms/_geometry.py | 14 ++++---- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 5768fe0162c..2d16214b73b 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1155,6 +1155,8 @@ def test__get_params(self, device, options, mocker): return assert len(params["is_within_crop_area"]) > 0 + assert params["is_within_crop_area"].dtype == torch.bool + orig_h = image.image_size[0] orig_w = image.image_size[1] assert int(transform.min_scale * orig_h) <= params["height"] <= int(transform.max_scale * orig_h) @@ -1182,7 +1184,8 @@ def test__transform_empty_params(self, mocker): def test_forward_assertion(self): transform = transforms.RandomIoUCrop() with pytest.raises( - TypeError, match="only defined for Images, PIL Images, BoundingBoxes and Labels or OneHotLabels" + TypeError, + match="requires input sample to contain Images or PIL Images, BoundingBoxes and Labels or OneHotLabels", ): transform(torch.tensor(0)) @@ -1194,33 +1197,54 @@ def test__transform(self, mocker): label = features.Label(torch.randint(0, 10, size=(6,))) ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1)) masks = make_segmentation_mask((32, 24)) - sample = [image, bboxes, label, ohe_label, masks] + ohe_masks = features.SegmentationMask(torch.randint(0, 2, size=(6, 32, 24))) + sample = [image, bboxes, label, ohe_label, masks, ohe_masks] + + fn = mocker.patch("torchvision.prototype.transforms.functional.crop", side_effect=lambda x, **params: x) + is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool) - fn = mocker.patch("torchvision.prototype.transforms.functional.crop") - is_within_crop_area = torch.randint(0, 2, size=(6,)) params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area) transform._get_params = 
mocker.MagicMock(return_value=params) output = transform(sample) - assert fn.call_count == 3 + assert fn.call_count == 4 expected_calls = [ mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]), + mocker.call( + ohe_masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ), ] fn.assert_has_calls(expected_calls) + expected_within_targets = sum(is_within_crop_area) + + # check number of bboxes vs number of labels: + output_bboxes = output[1] + assert isinstance(output_bboxes, features.BoundingBox) + assert len(output_bboxes) == expected_within_targets + # check labels output_label = output[2] assert isinstance(output_label, features.Label) + assert len(output_label) == expected_within_targets torch.testing.assert_close(output_label, label[is_within_crop_area]) output_ohe_label = output[3] assert isinstance(output_ohe_label, features.OneHotLabel) torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area]) + output_masks = output[4] + assert isinstance(output_masks, features.SegmentationMask) + assert output_masks.shape[:-2] == masks.shape[:-2] + + output_ohe_masks = output[5] + assert isinstance(output_ohe_masks, features.SegmentationMask) + assert len(output_ohe_masks) == expected_within_targets + class TestScaleJitter: def test__get_params(self, mocker): diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 0ddd38a2751..1447a0eced5 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -15,7 +15,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bboxes, query_image +from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_bounding_box, query_image class RandomHorizontalFlip(_RandomApplyTransform): @@ -647,7 +647,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) _, orig_h, orig_w = get_image_dimensions(image) - bboxes = query_bboxes(sample) + bboxes = query_bounding_box(sample) while True: # sample an option @@ -700,11 +700,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt is_within_crop_area = params["is_within_crop_area"] - if isinstance(inpt, features.Label): - return features.Label.new_like(inpt, inpt[is_within_crop_area]) - if isinstance(inpt, features.OneHotLabel): - return features.OneHotLabel.new_like(inpt, inpt[is_within_crop_area, :]) + if isinstance(inpt, (features.Label, features.OneHotLabel)): + return inpt.new_like(inpt, inpt[is_within_crop_area]) output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) @@ -712,6 +710,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: bboxes = output[is_within_crop_area] bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size) output = features.BoundingBox.new_like(output, bboxes) + elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1: + # apply is_within_crop_area if mask is one-hot encoded + masks = output[is_within_crop_area] + output = features.SegmentationMask.new_like(output, 
masks) return output From f8253aafbe2895d556f261e998bd748819bf7ea1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 17 Aug 2022 17:37:04 +0000 Subject: [PATCH 10/12] Ignored mypy error --- torchvision/prototype/transforms/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 1447a0eced5..4ce0fde77cf 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -702,7 +702,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: is_within_crop_area = params["is_within_crop_area"] if isinstance(inpt, (features.Label, features.OneHotLabel)): - return inpt.new_like(inpt, inpt[is_within_crop_area]) + return inpt.new_like(inpt, inpt[is_within_crop_area]) # type: ignore[arg-type] output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) From 79bbe76c71b8dc5f6cbae2a41184b47a3bfb1a3d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 18 Aug 2022 12:08:49 +0200 Subject: [PATCH 11/12] Fixed forward call on sample --- torchvision/prototype/transforms/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4ce0fde77cf..366f24848cd 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -728,7 +728,7 @@ def forward(self, *inputs: Any) -> Any: f"{type(self).__name__}() requires input sample to contain Images or PIL Images, " "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks." ) - return super().forward(*inputs) + return super().forward(sample) class ScaleJitter(Transform): From 1353cfe5f4f4984ad4904843c900465e4cfa9f83 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 18 Aug 2022 12:10:57 +0200 Subject: [PATCH 12/12] Added a todo --- torchvision/prototype/transforms/_geometry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 366f24848cd..e0215caaf87 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -719,6 +719,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] + # TODO: Allow image to be a torch.Tensor if not ( has_all(sample, features.BoundingBox) and has_any(sample, PIL.Image.Image, features.Image)