diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py
new file mode 100644
index 00000000000..b83c4f3acb9
--- /dev/null
+++ b/test/test_prototype_transforms_utils.py
@@ -0,0 +1,83 @@
+import PIL.Image
+import pytest
+
+import torch
+
+from test_prototype_transforms_functional import make_bounding_box, make_image, make_segmentation_mask
+
+from torchvision.prototype import features
+from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor
+from torchvision.prototype.transforms.functional import to_image_pil
+
+
+IMAGE = make_image(color_space=features.ColorSpace.RGB)
+BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
+SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
+
+
+@pytest.mark.parametrize(
+    ("sample", "types", "expected"),
+    [
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
+        ((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
+        ((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+        ((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+        ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+        ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True),
+    ],
+)
+def test_has_any(sample, types, expected):
+    assert has_any(sample, *types) is expected
+
+
+@pytest.mark.parametrize(
+    ("sample", "types", "expected"),
+    [
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
+        ((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (features.Image, features.BoundingBox, features.SegmentationMask),
+            True,
+        ),
+        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        (
+            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
+            (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
+            True,
+        ),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+    ],
+)
+def test_has_all(sample, types, expected):
+    assert has_all(sample, *types) is expected
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index d928e618c9f..bb884a6cb77 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -9,7 +9,7 @@
 from torchvision.prototype.transforms import functional as F
 
 from ._transform import _RandomApplyTransform
-from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image
+from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image
 
 
 class RandomErasing(_RandomApplyTransform):
@@ -105,7 +105,9 @@ def __init__(self, *, alpha: float, p: float = 0.5) -> None:
 
     def forward(self, *inpts: Any) -> Any:
         sample = inpts if len(inpts) > 1 else inpts[0]
-        if not has_all(sample, features.Image, features.OneHotLabel):
+        if not (
+            has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel)
+        ):
             raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
         if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label):
             raise TypeError(
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index e0215caaf87..32f220f2f9f 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -719,10 +719,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
 
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
-        # TODO: Allow image to be a torch.Tensor
         if not (
             has_all(sample, features.BoundingBox)
-            and has_any(sample, PIL.Image.Image, features.Image)
+            and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
             and has_any(sample, features.Label, features.OneHotLabel)
         ):
             raise TypeError(
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index 4cfe1da3649..fe06132ca1c 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, Type, Union
+from typing import Any, Callable, Tuple, Type, Union
 
 import PIL.Image
 import torch
@@ -39,14 +39,24 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im
     return channels, height, width
 
 
-def has_any(sample: Any, *types: Type) -> bool:
+def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     flat_sample, _ = tree_flatten(sample)
-    return any(issubclass(type(obj), types) for obj in flat_sample)
+    for type_or_check in types_or_checks:
+        for obj in flat_sample:
+            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+                return True
+    return False
 
 
-def has_all(sample: Any, *types: Type) -> bool:
+def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool:
     flat_sample, _ = tree_flatten(sample)
-    return not bool(set(types) - set([type(obj) for obj in flat_sample]))
+    for type_or_check in types_or_checks:
+        for obj in flat_sample:
+            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj):
+                break
+        else:
+            return False
+    return True
 
 
 def is_simple_tensor(inpt: Any) -> bool:
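
Usage note, with a minimal sketch of the reworked helpers. This assumes the prototype API exactly as in the diff above; the sample contents below are illustrative stand-ins, not taken from the tests. Both helpers now accept a mix of types and boolean callables and test each flattened object with isinstance, so subclass instances count; the old has_all compared exact types via sets and silently missed subclasses. In has_all, the inner loop breaks as soon as some object satisfies the current check, and the for/else branch returns False only when no object matched that check.

    import torch

    from torchvision.prototype import features
    from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor

    # Samples may be nested containers; tree_flatten sees through the dict.
    sample = {
        "image": torch.rand(3, 32, 32),  # plain tensor, not a features.Image
        "label": features.Label([0]),    # illustrative feature instance
    }

    # Types alone behave as before: there is no features.Image in the sample.
    assert not has_any(sample, features.Image)

    # A callable check lets "image-ish" inputs pass as one predicate; this is
    # what the _augment.py and _geometry.py guards above rely on.
    assert has_any(sample, features.Image, is_simple_tensor)

    # has_all: every type-or-check must be satisfied by at least one object.
    assert has_all(sample, is_simple_tensor, features.Label)
    assert not has_all(sample, is_simple_tensor, features.BoundingBox)

Note that the predicates are ordinary callables returning bool, so is_simple_tensor slots in unchanged, and ad-hoc lambdas (as in the tests above) work the same way.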