From 073ef7bc8add4450d736b469b4f35d1035589d03 Mon Sep 17 00:00:00 2001 From: mpearce25 Date: Thu, 23 Feb 2023 21:36:20 -0500 Subject: [PATCH 1/2] singular sanitize bounding box --- gallery/plot_transforms_v2_e2e.py | 4 +-- test/test_transforms_v2.py | 26 ++++++++++---------- test/test_transforms_v2_consistency.py | 2 +- torchvision/prototype/transforms/_augment.py | 2 +- torchvision/transforms/v2/__init__.py | 2 +- torchvision/transforms/v2/_geometry.py | 4 +-- torchvision/transforms/v2/_misc.py | 8 +++--- 7 files changed, 24 insertions(+), 24 deletions(-) diff --git a/gallery/plot_transforms_v2_e2e.py b/gallery/plot_transforms_v2_e2e.py index 938578e4af9..189239de289 100644 --- a/gallery/plot_transforms_v2_e2e.py +++ b/gallery/plot_transforms_v2_e2e.py @@ -106,13 +106,13 @@ def load_example_coco_detection_dataset(**kwargs): transforms.RandomHorizontalFlip(), transforms.ToImageTensor(), transforms.ConvertImageDtype(torch.float32), - transforms.SanitizeBoundingBoxes(), + transforms.SanitizeBoundingBox(), ] ) ######################################################################################################################## # .. note:: -# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, but it +# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, but it # should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as # the corresponding labels and optionally masks. It is particularly critical to add it if # :class:`~torchvision.transforms.v2.RandomIoUCrop` was used. diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 9173ec14f2c..93d5f17fcbe 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -275,7 +275,7 @@ def test_common(self, transform, adapter, container_type, image_or_video, device boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)), labels=torch.tensor([3]), ) - assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4) + assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4) @parametrize( [ @@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): transforms.ConvertImageDtype(torch.float), ] if sanitize: - t += [transforms.SanitizeBoundingBoxes()] + t += [transforms.SanitizeBoundingBox()] t = transforms.Compose(t) num_boxes = 5 @@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It # doesn't remove them strictly speaking, it just marks some boxes as # degenerate and those boxes will be later removed by - # SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize + # SanitizeBoundingBox(), which we add to the pipelines if the sanitize # param is True. # Note that the values below are probably specific to the random seed # set above (which is fine). @@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): img = sample.pop("image") sample = (img, sample) - out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) + out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample) if sample_type is tuple: out_image = out[0] @@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type): sample = {key: labels, "another_key": "whatever"} if sample_type is tuple: sample = (None, sample, "whatever_again") - assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels + assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels if key.lower() != "labels": # If "labels" is in the dict (case-insensitive), # it takes precedence over other keys which would otherwise be a match d = {key: "something_else", "labels": labels} - assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels + assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels def test_sanitize_bounding_boxes_errors(): @@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors(): ) with pytest.raises(ValueError, match="min_size must be >= 1"): - transforms.SanitizeBoundingBoxes(min_size=0) + transforms.SanitizeBoundingBox(min_size=0) with pytest.raises(ValueError, match="labels_getter should either be a str"): - transforms.SanitizeBoundingBoxes(labels_getter=12) + transforms.SanitizeBoundingBox(labels_getter=12) with pytest.raises(ValueError, match="Could not infer where the labels are"): bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])} - transforms.SanitizeBoundingBoxes()(bad_labels_key) + transforms.SanitizeBoundingBox()(bad_labels_key) with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"): not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0])) - transforms.SanitizeBoundingBoxes()(not_a_dict) + transforms.SanitizeBoundingBox()(not_a_dict) with pytest.raises(ValueError, match="must be a tensor"): not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()} - transforms.SanitizeBoundingBoxes()(not_a_tensor) + transforms.SanitizeBoundingBox()(not_a_tensor) with pytest.raises(ValueError, match="Number of boxes"): different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)} - transforms.SanitizeBoundingBoxes()(different_sizes) + transforms.SanitizeBoundingBox()(different_sizes) with pytest.raises(ValueError, match="boxes must be of shape"): bad_bbox = datapoints.BoundingBox( # batch with 2 elements @@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors(): spatial_size=(20, 20), ) different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])} - transforms.SanitizeBoundingBoxes()(different_sizes) + transforms.SanitizeBoundingBox()(different_sizes) @pytest.mark.parametrize( diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index 43f17c9b15a..059a230ee5c 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1099,7 +1099,7 @@ def make_label(extra_dims, categories): v2_transforms.Compose( [ v2_transforms.RandomIoUCrop(), - v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]), + v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]), ] ), {"with_mask": False}, diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index d04baf739d1..f796df48e2f 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -218,7 +218,7 @@ def _extract_image_targets( if not (len(images) == len(bboxes) == len(masks) == len(labels)): raise TypeError( f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " - "BoundingBoxes, Masks and Labels or OneHotLabels." + "BoundingBox, Masks and Labels or OneHotLabels." ) targets = [] diff --git a/torchvision/transforms/v2/__init__.py b/torchvision/transforms/v2/__init__.py index 7ad72c00934..6573446a33a 100644 --- a/torchvision/transforms/v2/__init__.py +++ b/torchvision/transforms/v2/__init__.py @@ -40,7 +40,7 @@ TenCrop, ) from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype -from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype +from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype from ._temporal import UniformTemporalSubsample from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index c3342eb9926..b2618bb892f 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform): .. warning:: In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop` - must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately + must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately after or later in the transforms pipeline. If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`, @@ -1222,7 +1222,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(output, datapoints.BoundingBox): # We "mark" the invalid boxes as degenreate, and they can be - # removed by a later call to SanitizeBoundingBoxes() + # removed by a later call to SanitizeBoundingBox() output[~params["is_within_crop_area"]] = 0 return output diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 8cc4aa6a3db..2d6c7ac958e 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -246,7 +246,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt.to(dtype=dtype) -class SanitizeBoundingBoxes(Transform): +class SanitizeBoundingBox(Transform): # This removes boxes and their corresponding labels: # - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1) # - boxes with any coordinate outside the range of the image (negative, or > spatial_size) @@ -269,7 +269,7 @@ def __init__( elif callable(labels_getter): self._labels_getter = labels_getter elif isinstance(labels_getter, str): - self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[ + self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[ labels_getter # type: ignore[index] ] elif labels_getter is None: @@ -300,7 +300,7 @@ def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]: def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive # Returns None if nothing is found - inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs) + inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs) candidate_key = None with suppress(StopIteration): candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") @@ -356,7 +356,7 @@ def forward(self, *inputs: Any) -> Any: params = dict(valid=valid, labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: - # _transform() will only care about BoundingBoxes and the labels + # _transform() will only care about BoundingBox and the labels self._transform(inpt, params) for inpt in flat_inputs ] From b17dbdeecc01a52119ce2b4cf112a306a9bc6a89 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 24 Feb 2023 09:34:53 +0000 Subject: [PATCH 2/2] Apply suggestions from code review --- torchvision/prototype/transforms/_augment.py | 2 +- torchvision/transforms/v2/_misc.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index f796df48e2f..d04baf739d1 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -218,7 +218,7 @@ def _extract_image_targets( if not (len(images) == len(bboxes) == len(masks) == len(labels)): raise TypeError( f"{type(self).__name__}() requires input sample to contain equal sized list of Images, " - "BoundingBox, Masks and Labels or OneHotLabels." + "BoundingBoxes, Masks and Labels or OneHotLabels." ) targets = [] diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index 2d6c7ac958e..53975a2ad2a 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -356,7 +356,7 @@ def forward(self, *inputs: Any) -> Any: params = dict(valid=valid, labels=labels) flat_outputs = [ # Even-though it may look like we're transforming all inputs, we don't: - # _transform() will only care about BoundingBox and the labels + # _transform() will only care about BoundingBoxes and the labels self._transform(inpt, params) for inpt in flat_inputs ]