Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
def test_sanitize_bounding_boxes(min_size, labels_getter):
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):

if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return

H, W = 256, 128

boxes_and_validity = [
Expand Down Expand Up @@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
)

masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))

whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
"image": input_img,
"labels": labels,
"boxes": boxes,
"whatever": torch.rand(10),
"whatever": whatever,
"None": None,
"masks": masks,
}

if sample_type is tuple:
img = sample.pop("image")
sample = (img, sample)

out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)

assert out["image"] is sample["image"]
assert out["whatever"] is sample["whatever"]
if sample_type is tuple:
out_image = out[0]
out_labels = out[1]["labels"]
out_boxes = out[1]["boxes"]
out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"]
else:
out_image = out["image"]
out_labels = out["labels"]
out_boxes = out["boxes"]
out_masks = out["masks"]
out_whatever = out["whatever"]

assert out_image is input_img
assert out_whatever is whatever

if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out["labels"] is sample["labels"]
assert out_labels is labels
else:
assert isinstance(out["labels"], torch.Tensor)
assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]
assert isinstance(out_labels, torch.Tensor)
assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
# This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices
assert out_labels.tolist() == valid_indices


@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key):
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10)
d = {key: labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
Comment on lines +2027 to +2028
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this show a bug in the implementation? Shouldn't we require the tuple to only have two elements?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can. I went with this because requiring exactly two elements felt like an artificial restriction, i.e. I don't see more things going wrong by allowing tuples with 3+ entries, and not raising an error actually makes the code simpler. But I don't have a strong opinion.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here. Maybe add a note about that to the comment in `_get_dict_or_second_tuple_entry`; so far it only uses a two-tuple as an example.

assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels

if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
Expand Down
31 changes: 22 additions & 9 deletions torchvision/transforms/v2/_misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union

import PIL.Image

Expand Down Expand Up @@ -269,7 +269,9 @@ def __init__(
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
self._labels_getter = None
else:
Expand All @@ -278,10 +280,27 @@ def __init__(
f"Got {labels_getter} of type {type(labels_getter)}."
)

@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
    """Return the mapping that is expected to hold the labels.

    Dataset outputs may be plain dicts like
    ``{"img": ..., "labels": ..., "bbox": ...}`` or tuples like
    ``(img, {"labels": ..., "bbox": ...})``. This helper accounts for both
    structures.

    Tuples are not required to have exactly two entries: for a tuple with
    three or more elements only the second one is looked at and the
    remaining entries are ignored.

    Args:
        inputs: The sample passed to ``forward()``; either a mapping or a
            tuple whose second element is a mapping.

    Returns:
        ``inputs`` itself if it is a mapping, otherwise ``inputs[1]``.

    Raises:
        ValueError: If ``inputs`` is neither a mapping nor a tuple whose
            second element is a mapping. This includes tuples with fewer
            than two elements.
    """
    candidate = inputs
    # Only index into the tuple when a second entry actually exists, so that
    # too-short tuples fall through to the descriptive ValueError below
    # instead of surfacing as a bare IndexError.
    if isinstance(inputs, tuple) and len(inputs) >= 2:
        candidate = inputs[1]

    if not isinstance(candidate, collections.abc.Mapping):
        raise ValueError(
            "If labels_getter is a str or 'default', "
            "then the input to forward() must be a dict or a tuple whose second element is a dict."
            f" Got {type(candidate)} instead."
        )
    return candidate

@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
Expand All @@ -298,12 +317,6 @@ def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Ten
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]

if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)

if self._labels_getter is None:
labels = None
else:
Expand Down