From 02f518fef62c47aaa99a3e02826c18a8e49c0fc4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 11:33:02 +0200 Subject: [PATCH 1/8] expand has_any and has_all to also accept check callables --- torchvision/prototype/transforms/_utils.py | 23 +++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 9f2ef84ced5..cb6060bd7f0 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, Type, Union +from typing import Any, Callable, cast, List, Tuple, Type, Union import PIL.Image import torch @@ -30,14 +30,27 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im return channels, height, width -def has_any(sample: Any, *types: Type) -> bool: +def _parse_types_or_checks( + types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...] +) -> List[Callable[[Any], bool]]: + return [ + cast(Callable[[Any], bool], lambda obj, typ=type_or_check: isinstance(obj, typ)) + if isinstance(type_or_check, type) + else type_or_check + for type_or_check in types_or_checks + ] + + +def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) - return any(issubclass(type(obj), types) for obj in flat_sample) + checks = _parse_types_or_checks(types_or_checks) + return any(any([check(obj) for check in checks]) for obj in flat_sample) -def has_all(sample: Any, *types: Type) -> bool: +def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) - return not bool(set(types) - set([type(obj) for obj in flat_sample])) + checks = _parse_types_or_checks(types_or_checks) + return any(any([check(obj) for check in checks]) for obj in flat_sample) def is_simple_tensor(inpt: Any) -> bool: From 
1f12ca79b4c40a432ae8f944bec8046fb6330d4e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 14:47:37 +0200 Subject: [PATCH 2/8] add test and fix has_all --- test/test_prototype_transforms_utils.py | 51 ++++++++++++++++++++++ torchvision/prototype/transforms/_utils.py | 4 +- 2 files changed, 53 insertions(+), 2 deletions(-) create mode 100644 test/test_prototype_transforms_utils.py diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py new file mode 100644 index 00000000000..f20c45b1484 --- /dev/null +++ b/test/test_prototype_transforms_utils.py @@ -0,0 +1,51 @@ +import pytest +from torchvision.prototype.transforms._utils import has_all, has_any + + +@pytest.mark.parametrize( + ("sample", "types", "expected"), + [ + ((0, 0.0, ""), (int,), True), + ((0, 0.0, ""), (float,), True), + ((0, 0.0, ""), (str,), True), + ((0, 0.0, ""), (int, float), True), + ((0, 0.0, ""), (int, str), True), + ((0, 0.0, ""), (float, str), True), + (("",), (int, float), False), + ((0.0,), (int, str), False), + ((0,), (float, str), False), + ((0, 0.0, ""), (int, float, str), True), + ((), (int, float, str), False), + ((0, 0.0, ""), (lambda obj: isinstance(obj, int),), True), + ((0, 0.0, ""), (lambda _: False,), False), + ((0, 0.0, ""), (lambda _: True,), True), + ], +) +def test_has_any(sample, types, expected): + assert has_any(sample, *types) is expected + + +@pytest.mark.parametrize( + ("sample", "types", "expected"), + [ + ((0, 0.0, ""), (int,), True), + ((0, 0.0, ""), (float,), True), + ((0, 0.0, ""), (str,), True), + ((0, 0.0, ""), (int, float), True), + ((0, 0.0, ""), (int, str), True), + ((0, 0.0, ""), (float, str), True), + ((0, 0.0, ""), (int, float, str), True), + ((0.0, ""), (int, float), False), + ((0.0, ""), (int, str), False), + ((0, ""), (float, str), False), + ((0, 0.0, ""), (int, float, str), True), + ((0.0, ""), (int, float, str), False), + ((0, ""), (int, float, str), False), + ((0, 0.0), (int, float, str), False), + 
((0, 0.0, ""), (lambda obj: isinstance(obj, (int, float, str)),), True), + ((0, 0.0, ""), (lambda _: False,), False), + ((0, 0.0, ""), (lambda _: True,), True), + ], +) +def test_has_all(sample, types, expected): + assert has_all(sample, *types) is expected diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index cb6060bd7f0..ed5a79322b9 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -44,13 +44,13 @@ def _parse_types_or_checks( def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) checks = _parse_types_or_checks(types_or_checks) - return any(any([check(obj) for check in checks]) for obj in flat_sample) + return any(any([check(obj) for obj in flat_sample]) for check in checks) def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) checks = _parse_types_or_checks(types_or_checks) - return any(any([check(obj) for check in checks]) for obj in flat_sample) + return all(any([check(obj) for obj in flat_sample]) for check in checks) def is_simple_tensor(inpt: Any) -> bool: From ed39943f25c5dc5313780fdb06dcc9bb718ae482 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 14:52:54 +0200 Subject: [PATCH 3/8] add support for simple tensor images to CutMix, MixUp and RandomIoUCrop --- torchvision/prototype/transforms/_augment.py | 6 ++++-- torchvision/prototype/transforms/_geometry.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index fd8ae9ab378..2c95396dad9 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import functional as F, Transform from ._transform import _RandomApplyTransform 
-from ._utils import get_image_dimensions, has_all, has_any, is_simple_tensor, query_image +from ._utils import get_image_dimensions, has_any, is_simple_tensor, query_image class RandomErasing(_RandomApplyTransform): @@ -105,7 +105,9 @@ def __init__(self, *, alpha: float) -> None: def forward(self, *inpts: Any) -> Any: sample = inpts if len(inpts) > 1 else inpts[0] - if not has_all(sample, features.Image, features.OneHotLabel): + if not ( + has_any(sample, features.Image, PIL.Image.Image, is_simple_tensor) and has_any(sample, features.OneHotLabel) + ): raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): raise TypeError( diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index e0215caaf87..401b51bae8f 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -722,7 +722,7 @@ def forward(self, *inputs: Any) -> Any: # TODO: Allow image to be a torch.Tensor if not ( has_all(sample, features.BoundingBox) - and has_any(sample, PIL.Image.Image, features.Image) + and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor) and has_any(sample, features.Label, features.OneHotLabel) ): raise TypeError( From 4784b8be26228b6d7fd2feed453fd3860b2a8ec6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 14:58:10 +0200 Subject: [PATCH 4/8] remove TODO --- torchvision/prototype/transforms/_geometry.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 401b51bae8f..32f220f2f9f 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -719,7 +719,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def forward(self, *inputs: Any) -> Any: sample = inputs if 
len(inputs) > 1 else inputs[0] - # TODO: Allow image to be a torch.Tensor if not ( has_all(sample, features.BoundingBox) and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor) From 86a7cf274ad368dbfa681cdc11ae86c81b5fb45f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 15:44:48 +0200 Subject: [PATCH 5/8] remove pythonic syntax sugar --- torchvision/prototype/transforms/_utils.py | 40 ++++++++++++++++------ 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index db6d6d9324f..3eea6b8ee00 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -42,24 +42,44 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im def _parse_types_or_checks( types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...] ) -> List[Callable[[Any], bool]]: - return [ - cast(Callable[[Any], bool], lambda obj, typ=type_or_check: isinstance(obj, typ)) - if isinstance(type_or_check, type) - else type_or_check - for type_or_check in types_or_checks - ] + checks = [] + for type_or_check in types_or_checks: + if isinstance(type_or_check, type): + check = cast(Callable[[Any], bool], lambda obj, typ=type_or_check: isinstance(obj, typ)) + else: + check = type_or_check + checks.append(check) + return checks def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) - checks = _parse_types_or_checks(types_or_checks) - return any(any([check(obj) for obj in flat_sample]) for check in checks) + for check in _parse_types_or_checks(types_or_checks): + passed_check = False + for obj in flat_sample: + if check(obj): + passed_check = True + break + + if passed_check: + return True + + return False def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) - checks 
= _parse_types_or_checks(types_or_checks) - return all(any([check(obj) for obj in flat_sample]) for check in checks) + for check in _parse_types_or_checks(types_or_checks): + passed_check = False + for obj in flat_sample: + if check(obj): + passed_check = True + break + + if not passed_check: + return False + + return True def is_simple_tensor(inpt: Any) -> bool: From f276362b7c4e15888f8f66af4d258bd36ae269fe Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 16:02:45 +0200 Subject: [PATCH 6/8] simplify --- torchvision/prototype/transforms/_utils.py | 23 +++++----------------- 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 3eea6b8ee00..9918ad584ac 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, cast, List, Tuple, Type, Union +from typing import Any, Callable, Tuple, Type, Union import PIL.Image import torch @@ -39,25 +39,12 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im return channels, height, width -def _parse_types_or_checks( - types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...] 
-) -> List[Callable[[Any], bool]]: -    checks = [] -    for type_or_check in types_or_checks: -        if isinstance(type_or_check, type): -            check = cast(Callable[[Any], bool], lambda obj, typ=type_or_check: isinstance(obj, typ)) -        else: -            check = type_or_check -        checks.append(check) -    return checks - - def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) -    for check in _parse_types_or_checks(types_or_checks): +    for type_or_check in types_or_checks: passed_check = False for obj in flat_sample: -            if check(obj): +            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): passed_check = True break @@ -69,10 +56,10 @@ def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) - def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) -    for check in _parse_types_or_checks(types_or_checks): +    for type_or_check in types_or_checks: passed_check = False for obj in flat_sample: -            if check(obj): +            if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): passed_check = True break From 6f22eda410d01d1761e7bbdd84466fa60280d73c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 16:12:41 +0200 Subject: [PATCH 7/8] use concrete examples in test rather than abstract ones --- test/test_prototype_transforms_utils.py | 96 ++++++++++++++++--------- 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/test/test_prototype_transforms_utils.py b/test/test_prototype_transforms_utils.py index f20c45b1484..b83c4f3acb9 100644 --- a/test/test_prototype_transforms_utils.py +++ b/test/test_prototype_transforms_utils.py @@ -1,24 +1,44 @@ +import PIL.Image import pytest -from torchvision.prototype.transforms._utils import has_all, has_any + +import torch + +from test_prototype_transforms_functional import make_bounding_box, make_image,
make_segmentation_mask + +from torchvision.prototype import features +from torchvision.prototype.transforms._utils import has_all, has_any, is_simple_tensor +from torchvision.prototype.transforms.functional import to_image_pil + + +IMAGE = make_image(color_space=features.ColorSpace.RGB) +BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size) +SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size) @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((0, 0.0, ""), (int,), True), - ((0, 0.0, ""), (float,), True), - ((0, 0.0, ""), (str,), True), - ((0, 0.0, ""), (int, float), True), - ((0, 0.0, ""), (int, str), True), - ((0, 0.0, ""), (float, str), True), - (("",), (int, float), False), - ((0.0,), (int, str), False), - ((0,), (float, str), False), - ((0, 0.0, ""), (int, float, str), True), - ((), (int, float, str), False), - ((0, 0.0, ""), (lambda obj: isinstance(obj, int),), True), - ((0, 0.0, ""), (lambda _: False,), False), - ((0, 0.0, ""), (lambda _: True,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True), + ((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False), + ((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False), + ((IMAGE,), (features.BoundingBox, features.SegmentationMask), False), + ( + (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), + (features.Image, features.BoundingBox, features.SegmentationMask), + True, + ), + ((), (features.Image, features.BoundingBox, features.SegmentationMask), 
False), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True), + ((IMAGE,), (features.Image, PIL.Image.Image, is_simple_tensor), True), + ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), + ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, is_simple_tensor), True), ], ) def test_has_any(sample, types, expected): @@ -28,23 +48,35 @@ def test_has_any(sample, types, expected): @pytest.mark.parametrize( ("sample", "types", "expected"), [ - ((0, 0.0, ""), (int,), True), - ((0, 0.0, ""), (float,), True), - ((0, 0.0, ""), (str,), True), - ((0, 0.0, ""), (int, float), True), - ((0, 0.0, ""), (int, str), True), - ((0, 0.0, ""), (float, str), True), - ((0, 0.0, ""), (int, float, str), True), - ((0.0, ""), (int, float), False), - ((0.0, ""), (int, str), False), - ((0, ""), (float, str), False), - ((0, 0.0, ""), (int, float, str), True), - ((0.0, ""), (int, float, str), False), - ((0, ""), (int, float, str), False), - ((0, 0.0), (int, float, str), False), - ((0, 0.0, ""), (lambda obj: isinstance(obj, (int, float, str)),), True), - ((0, 0.0, ""), (lambda _: False,), False), - ((0, 0.0, ""), (lambda _: True,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True), + ( + (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), + (features.Image, features.BoundingBox, features.SegmentationMask), + 
True, + ), + ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False), + ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False), + ((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False), + ( + (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), + (features.Image, features.BoundingBox, features.SegmentationMask), + True, + ), + ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False), + ((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False), + ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False), + ( + (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), + (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),), + True, + ), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False), + ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True), ], ) def test_has_all(sample, types, expected): From 4ea9bef2e57af3ace1205864aea64a4f6381ce18 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 18 Aug 2022 16:38:39 +0200 Subject: [PATCH 8/8] simplify further --- torchvision/prototype/transforms/_utils.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 9918ad584ac..fe06132ca1c 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -42,30 +42,20 @@ def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Im def has_any(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) for type_or_check in types_or_checks: - passed_check = False for obj in flat_sample: if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else 
type_or_check(obj): - passed_check = True - break - - if passed_check: - return True - + return True return False def has_all(sample: Any, *types_or_checks: Union[Type, Callable[[Any], bool]]) -> bool: flat_sample, _ = tree_flatten(sample) for type_or_check in types_or_checks: - passed_check = False for obj in flat_sample: if isinstance(obj, type_or_check) if isinstance(type_or_check, type) else type_or_check(obj): - passed_check = True break - - if not passed_check: + else: return False - return True