From 8916cdf5180f98ab6aad14cc975e06bceefb2647 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Mar 2022 15:14:57 +0100 Subject: [PATCH 1/5] port FiveCrop and TenCrop to prototype API --- torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_geometry.py | 77 +++++++++++++++ .../transforms/functional/__init__.py | 6 ++ .../transforms/functional/_geometry.py | 95 ++++++++++++++++++- torchvision/prototype/utils/_internal.py | 40 +++++++- 5 files changed, 213 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 98ad7ae0d74..4641cc5ab86 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -7,7 +7,7 @@ from ._augment import RandomErasing, RandomMixup, RandomCutmix from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix from ._container import Compose, RandomApply, RandomChoice, RandomOrder -from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop +from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4bc3c14070f..7cb8f597da5 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,3 +1,4 @@ +import functools import math import warnings from typing import Any, Dict, List, Union, Sequence, Tuple, cast @@ -6,6 +7,8 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F +from torchvision.prototype.utils._internal import apply_recursively +from torchvision.transforms.functional import pil_to_tensor from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from ._utils import query_image, get_image_dimensions, has_any @@ -168,3 +171,77 @@ def forward(self, *inputs: Any) -> Any: if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) + + +class FiveCrop(Transform): + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.five_crop_image_tensor(input, self.size) + return F._FiveCropResult(*[features.Image.new_like(input, o) for o in output]) + elif type(input) is torch.Tensor: + return F.five_crop_image_tensor(input, self.size) + elif isinstance(input, PIL.Image.Image): + return F.five_crop_image_pil(input, self.size) + else: + return input + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if has_any(sample, features.BoundingBox, features.SegmentationMask): + raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") + return super().forward(sample) + + +class TenCrop(Transform): + def __init__(self, size: Union[int, Sequence[int]], 
vertical_flip: bool = False) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip) + return F._TenCropResult(*[features.Image.new_like(input, o) for o in output]) + elif type(input) is torch.Tensor: + return F.ten_crop_image_tensor(input, self.size) + elif isinstance(input, PIL.Image.Image): + return F.five_crop_image_pil(input, self.size) + else: + return input + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if has_any(sample, features.BoundingBox, features.SegmentationMask): + raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") + return super().forward(sample) + + +class BatchMultiCrop(Transform): + _MULTI_CROP_TYPES = (F._FiveCropResult, F._TenCropResult) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, self._MULTI_CROP_TYPES): + crops = input + if isinstance(input[0], PIL.Image.Image): + crops = [pil_to_tensor(crop) for crop in crops] # type: ignore[assignment] + + batch = torch.stack(crops) + + if isinstance(input[0], features.Image): + batch = features.Image.new_like(input[0], batch) + + return batch + else: + return input + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + return apply_recursively( + functools.partial(self._transform, params=self._get_params(sample)), + sample, + exclude_sequence_types=(str, *self._MULTI_CROP_TYPES), + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index e3fe60a7919..9616874fa8a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -60,6 +60,12 @@ perspective_image_pil, vertical_flip_image_tensor, vertical_flip_image_pil, + _FiveCropResult, + five_crop_image_tensor, + five_crop_image_pil, + _TenCropResult, + ten_crop_image_tensor, + ten_crop_image_pil, ) from ._misc import normalize_image_tensor, gaussian_blur_image_tensor from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 080fe5da891..93624b7ee13 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,4 +1,5 @@ import numbers +from typing import NamedTuple from typing import Tuple, List, Optional, Sequence, Union import PIL.Image @@ -10,7 +11,6 @@ from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil - horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip @@ -314,3 +314,96 @@ def resized_crop_image_pil( ) -> PIL.Image.Image: img = crop_image_pil(img, top, left, height, width) return resize_image_pil(img, size, interpolation=interpolation) + + +class _FiveCropResult(NamedTuple): + top_left: torch.Tensor + top_right: torch.Tensor + bottom_left: torch.Tensor + bottom_right: torch.Tensor + center: torch.Tensor + + +def _parse_five_crop_size(size: List[int]) -> List[int]: + if isinstance(size, numbers.Number): + size = (int(size), 
int(size)) + elif isinstance(size, (tuple, list)) and len(size) == 1: + size = (size[0], size[0]) # type: ignore[assignment] + + if len(size) != 2: + raise ValueError("Please provide only two dimensions (h, w) for size.") + + return size + + +def five_crop_image_tensor(img: torch.Tensor, size: List[int]) -> _FiveCropResult: + crop_height, crop_width = _parse_five_crop_size(size) + _, image_height, image_width = get_dimensions_image_tensor(img) + + if crop_width > image_width or crop_height > image_height: + msg = "Requested crop size {} is bigger than input size {}" + raise ValueError(msg.format(size, (image_height, image_width))) + + tl = crop_image_tensor(img, 0, 0, crop_height, crop_width) + tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image_tensor(img, [crop_height, crop_width]) + + return _FiveCropResult(tl, tr, bl, br, center) + + +def five_crop_image_pil(img: PIL.Image.Image, size: List[int]) -> _FiveCropResult: + crop_height, crop_width = _parse_five_crop_size(size) + _, image_height, image_width = get_dimensions_image_pil(img) + + if crop_width > image_width or crop_height > image_height: + msg = "Requested crop size {} is bigger than input size {}" + raise ValueError(msg.format(size, (image_height, image_width))) + + tl = crop_image_pil(img, 0, 0, crop_height, crop_width) + tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width) + bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width) + br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) + center = center_crop_image_pil(img, [crop_height, crop_width]) + + return _FiveCropResult(tl, tr, bl, br, center) + + +class _TenCropResult(NamedTuple): + top_left: torch.Tensor + top_right: torch.Tensor + bottom_left: torch.Tensor + bottom_right: torch.Tensor + center: torch.Tensor + top_left_flip: torch.Tensor + top_right_flip: torch.Tensor + bottom_left_flip: torch.Tensor + bottom_right_flip: torch.Tensor + center_flip: torch.Tensor + + +def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> _TenCropResult: + tl, tr, bl, br, center = five_crop_image_tensor(img, size) + + if vertical_flip: + img = vertical_flip_image_tensor(img) + else: + img = horizontal_flip_image_tensor(img) + + tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size) + + return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + + +def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> _TenCropResult: + tl, tr, bl, br, center = five_crop_image_pil(img, size) + + if vertical_flip: + img = vertical_flip_image_pil(img) + else: + img = horizontal_flip_image_pil(img) + + tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size) + + return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index 366a19f2bbc..2e441836753 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -25,6 +25,7 @@ TypeVar, Union, Optional, + Type, ) import numpy as np @@ -301,13 +302,42 @@ def read(self, size: 
int = -1) -> bytes: return self._memory[slice(cursor, self.seek(offset, whence))].tobytes() -def apply_recursively(fn: Callable, obj: Any) -> Any: +def apply_recursively( + fn: Callable, + obj: Any, + *, + include_sequence_types: Collection[Type] = (collections.abc.Sequence,), # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... - if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): - return [apply_recursively(fn, item) for item in obj] - elif isinstance(obj, collections.abc.Mapping): - return {key: apply_recursively(fn, item) for key, item in obj.items()} + exclude_sequence_types: Collection[Type] = (str,), + include_mapping_types: Collection[Type] = (collections.abc.Mapping,), + exclude_mapping_types: Collection[Type] = (), +) -> Any: + if isinstance(obj, tuple(include_sequence_types)) and not isinstance(obj, tuple(exclude_sequence_types)): + return [ + apply_recursively( + fn, + item, + include_sequence_types=include_sequence_types, + exclude_sequence_types=exclude_sequence_types, + include_mapping_types=include_mapping_types, + exclude_mapping_types=exclude_mapping_types, + ) + for item in obj + ] + + if isinstance(obj, tuple(include_mapping_types)) and not isinstance(obj, tuple(exclude_mapping_types)): + return { + key: apply_recursively( + fn, + item, + include_sequence_types=include_sequence_types, + exclude_sequence_types=exclude_sequence_types, + include_mapping_types=include_mapping_types, + exclude_mapping_types=exclude_mapping_types, + ) + for key, item in obj.items() + } else: return fn(obj) From 4673727ddc80fcaa50816f630c96788a10dae759 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 2 Mar 2022 16:41:11 +0100 Subject: [PATCH 2/5] fix ten crop for pil --- torchvision/prototype/transforms/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 7cb8f597da5..4cd5c7beab9 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -209,7 +209,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: elif type(input) is torch.Tensor: return F.ten_crop_image_tensor(input, self.size) elif isinstance(input, PIL.Image.Image): - return F.five_crop_image_pil(input, self.size) + return F.ten_crop_image_pil(input, self.size) else: return input From 1d769c34507ea6546600aca8f0196bb9f172873c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 4 Mar 2022 12:29:23 +0000 Subject: [PATCH 3/5] Update torchvision/prototype/transforms/_geometry.py Co-authored-by: Philip Meier --- torchvision/prototype/transforms/_geometry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 7cb8f597da5..4cd5c7beab9 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -209,7 +209,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: elif type(input) is torch.Tensor: return F.ten_crop_image_tensor(input, self.size) elif isinstance(input, PIL.Image.Image): - return F.five_crop_image_pil(input, self.size) + return F.ten_crop_image_pil(input, self.size) else: return input From dec31cd1108db0979f281cdf14fd17dace8323cd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Mar 2022 11:34:21 +0100 Subject: [PATCH 4/5] simplify 
implementation --- torchvision/prototype/transforms/_geometry.py | 77 +++++++++++-------- .../transforms/functional/__init__.py | 2 - .../transforms/functional/_geometry.py | 42 +++------- torchvision/prototype/utils/_internal.py | 40 ++-------- 4 files changed, 61 insertions(+), 100 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 4da8d5c6d15..e04e9f819f3 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,4 +1,4 @@ -import functools +import collections.abc import math import warnings from typing import Any, Dict, List, Union, Sequence, Tuple, cast @@ -7,7 +7,6 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F -from torchvision.prototype.utils._internal import apply_recursively from torchvision.transforms.functional import pil_to_tensor from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int @@ -173,6 +172,17 @@ def forward(self, *inputs: Any) -> Any: return super().forward(sample) +class MultiCropResult(list): + """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`. + + Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and + `:class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by + :class:`~torchvision.prototype.transforms.BatchMultiCrop`. + """ + + pass + + class FiveCrop(Transform): def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() @@ -181,11 +191,11 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, features.Image): output = F.five_crop_image_tensor(input, self.size) - return F._FiveCropResult(*[features.Image.new_like(input, o) for o in output]) - elif type(input) is torch.Tensor: - return F.five_crop_image_tensor(input, self.size) + return MultiCropResult(features.Image.new_like(input, o) for o in output) + elif is_simple_tensor(input): + return MultiCropResult(F.five_crop_image_tensor(input, self.size)) elif isinstance(input, PIL.Image.Image): - return F.five_crop_image_pil(input, self.size) + return MultiCropResult(F.five_crop_image_pil(input, self.size)) else: return input @@ -205,11 +215,11 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if isinstance(input, features.Image): output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip) - return F._TenCropResult(*[features.Image.new_like(input, o) for o in output]) - elif type(input) is torch.Tensor: - return F.ten_crop_image_tensor(input, self.size) + return MultiCropResult(features.Image.new_like(input, o) for o in output) + elif is_simple_tensor(input): + return MultiCropResult(F.ten_crop_image_tensor(input, self.size)) elif isinstance(input, PIL.Image.Image): - return F.ten_crop_image_pil(input, self.size) + return MultiCropResult(F.ten_crop_image_pil(input, self.size)) else: return input @@ -221,27 +231,28 @@ def forward(self, *inputs: Any) -> Any: class BatchMultiCrop(Transform): - _MULTI_CROP_TYPES = (F._FiveCropResult, F._TenCropResult) - - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, self._MULTI_CROP_TYPES): - crops = input - if 
isinstance(input[0], PIL.Image.Image): - crops = [pil_to_tensor(crop) for crop in crops] # type: ignore[assignment] - - batch = torch.stack(crops) - - if isinstance(input[0], features.Image): - batch = features.Image.new_like(input[0], batch) - - return batch - else: - return input - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - return apply_recursively( - functools.partial(self._transform, params=self._get_params(sample)), - sample, - exclude_sequence_types=(str, *self._MULTI_CROP_TYPES), - ) + # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one + # significant difference: + # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from + # the sequence case. + def apply_recursively(obj: Any) -> Any: + if isinstance(obj, MultiCropResult): + crops = obj + if isinstance(obj[0], PIL.Image.Image): + crops = [pil_to_tensor(crop) for crop in crops] # type: ignore[assignment] + + batch = torch.stack(crops) + + if isinstance(obj[0], features.Image): + batch = features.Image.new_like(obj[0], batch) + + return batch + elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + return [apply_recursively(item) for item in obj] + elif isinstance(obj, collections.abc.Mapping): + return {key: apply_recursively(item) for key, item in obj.items()} + else: + return obj + + return apply_recursively(inputs if len(inputs) > 1 else inputs[0]) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 9616874fa8a..c0825784f66 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -60,10 +60,8 @@ perspective_image_pil, vertical_flip_image_tensor, vertical_flip_image_pil, - _FiveCropResult, five_crop_image_tensor, five_crop_image_pil, - _TenCropResult, ten_crop_image_tensor, ten_crop_image_pil, ) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 93624b7ee13..c8b5189355b 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,5 +1,4 @@ import numbers -from typing import NamedTuple from typing import Tuple, List, Optional, Sequence, Union import PIL.Image @@ -316,14 +315,6 @@ def resized_crop_image_pil( return resize_image_pil(img, size, interpolation=interpolation) -class _FiveCropResult(NamedTuple): - top_left: torch.Tensor - top_right: torch.Tensor - bottom_left: torch.Tensor - bottom_right: torch.Tensor - center: torch.Tensor - - def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): size = (int(size), int(size)) @@ -336,7 +327,9 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size -def five_crop_image_tensor(img: torch.Tensor, size: List[int]) -> _FiveCropResult: +def five_crop_image_tensor( + img: torch.Tensor, size: List[int] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: crop_height, crop_width = _parse_five_crop_size(size) _, image_height, image_width = get_dimensions_image_tensor(img) @@ -350,10 +343,12 @@ def five_crop_image_tensor(img: torch.Tensor, size: List[int]) -> _FiveCropResul br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) center = center_crop_image_tensor(img, 
[crop_height, crop_width]) - return _FiveCropResult(tl, tr, bl, br, center) + return tl, tr, bl, br, center -def five_crop_image_pil(img: PIL.Image.Image, size: List[int]) -> _FiveCropResult: +def five_crop_image_pil( + img: PIL.Image.Image, size: List[int] +) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: crop_height, crop_width = _parse_five_crop_size(size) _, image_height, image_width = get_dimensions_image_pil(img) @@ -367,23 +362,10 @@ def five_crop_image_pil(img: PIL.Image.Image, size: List[int]) -> _FiveCropResul br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width) center = center_crop_image_pil(img, [crop_height, crop_width]) - return _FiveCropResult(tl, tr, bl, br, center) - - -class _TenCropResult(NamedTuple): - top_left: torch.Tensor - top_right: torch.Tensor - bottom_left: torch.Tensor - bottom_right: torch.Tensor - center: torch.Tensor - top_left_flip: torch.Tensor - top_right_flip: torch.Tensor - bottom_left_flip: torch.Tensor - bottom_right_flip: torch.Tensor - center_flip: torch.Tensor + return tl, tr, bl, br, center -def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> _TenCropResult: +def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]: tl, tr, bl, br, center = five_crop_image_tensor(img, size) if vertical_flip: @@ -393,10 +375,10 @@ def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: boo tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size) - return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] -def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> _TenCropResult: +def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]: tl, tr, bl, br, center = five_crop_image_pil(img, size) if vertical_flip: @@ -406,4 +388,4 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: boo tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size) - return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip) + return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip] diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index d9faafed154..864bff9ce02 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -24,7 +24,6 @@ TypeVar, Union, Optional, - Type, ) import numpy as np @@ -289,42 +288,13 @@ def read(self, size: int = -1) -> bytes: return self._memory[slice(cursor, self.seek(offset, whence))].tobytes() -def apply_recursively( - fn: Callable, - obj: Any, - *, - include_sequence_types: Collection[Type] = (collections.abc.Sequence,), +def apply_recursively(fn: Callable, obj: Any) -> Any: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... 
- exclude_sequence_types: Collection[Type] = (str,), - include_mapping_types: Collection[Type] = (collections.abc.Mapping,), - exclude_mapping_types: Collection[Type] = (), -) -> Any: - if isinstance(obj, tuple(include_sequence_types)) and not isinstance(obj, tuple(exclude_sequence_types)): - return [ - apply_recursively( - fn, - item, - include_sequence_types=include_sequence_types, - exclude_sequence_types=exclude_sequence_types, - include_mapping_types=include_mapping_types, - exclude_mapping_types=exclude_mapping_types, - ) - for item in obj - ] - - if isinstance(obj, tuple(include_mapping_types)) and not isinstance(obj, tuple(exclude_mapping_types)): - return { - key: apply_recursively( - fn, - item, - include_sequence_types=include_sequence_types, - exclude_sequence_types=exclude_sequence_types, - include_mapping_types=include_mapping_types, - exclude_mapping_types=exclude_mapping_types, - ) - for key, item in obj.items() - } + if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): + return [apply_recursively(fn, item) for item in obj] + elif isinstance(obj, collections.abc.Mapping): + return {key: apply_recursively(fn, item) for key, item in obj.items()} else: return fn(obj) From 4302084a91c03afbfe58af40d8595f5c0db4f3af Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Mar 2022 11:36:00 +0100 Subject: [PATCH 5/5] minor cleanup --- torchvision/prototype/transforms/functional/_geometry.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c8b5189355b..6c9309749af 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -10,6 +10,7 @@ from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil + horizontal_flip_image_tensor = _FT.hflip horizontal_flip_image_pil = _FP.hflip
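

A minimal usage sketch of the API this series introduces, reflecting the state after the final patch. It assumes the prototype `features.Image` can wrap a plain tensor via its constructor and that the transform classes are exposed from `torchvision.prototype.transforms` as in the `__init__.py` hunk above; treat it as an illustration, not part of the patches.

    import torch
    from torchvision.prototype import features, transforms
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.rand(3, 256, 256))

    # FiveCrop returns a MultiCropResult (a list subclass) holding the four corner
    # crops plus the center crop; BatchMultiCrop stacks them into one batched Image.
    crops = transforms.FiveCrop(224)(image)
    batch = transforms.BatchMultiCrop()(crops)
    print(type(batch).__name__, tuple(batch.shape))  # Image (5, 3, 224, 224)

    # The low-level kernels can also be called directly. ten_crop_image_tensor
    # returns a list of ten tensors: the five crops plus their flipped counterparts.
    ten = F.ten_crop_image_tensor(torch.rand(3, 256, 256), [224, 224], vertical_flip=False)
    assert len(ten) == 10

Wrapping the multi-crop output in MultiCropResult, a plain list subclass, is what lets BatchMultiCrop special-case it during its recursive traversal while every other sequence is still recursed into element by element.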