diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py
index 4bf0236a1d1..7f16279e8c2 100644
--- a/torchvision/prototype/transforms/_auto_augment.py
+++ b/torchvision/prototype/transforms/_auto_augment.py
@@ -3,28 +3,23 @@
 
 import PIL.Image
 import torch
+
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
-from ._utils import get_image_dimensions
+from ._utils import get_image_dimensions, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
 
 
-def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
-    if not id:
-        return item
-
-    parent = sample
-    for key in id[:-1]:
-        parent = parent[key]
-
-    parent[id[-1]] = item
-    return sample
+def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
+    sample_flat, spec = tree_flatten(sample)
+    sample_flat[id] = item
+    return tree_unflatten(sample_flat, spec)
 
 
 class _AutoAugmentBase(Transform):
@@ -47,18 +42,15 @@ def _extract_image(
         self,
         sample: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
-    ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
-        def fn(
-            id: Tuple[Any, ...], inpt: Any
-        ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
-            if type(inpt) in {torch.Tensor, features.Image} or isinstance(inpt, PIL.Image.Image):
-                return id, inpt
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
+                images.append((id, inpt))
             elif isinstance(inpt, unsupported_types):
                 raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
-            else:
-                return None
 
-        images = list(query_recursively(fn, sample))
         if not images:
             raise TypeError("Found no image in the sample.")
         if len(images) > 1:
diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py
index 8ddf2aa6178..3dee4b59a7a 100644
--- a/torchvision/prototype/utils/_internal.py
+++ b/torchvision/prototype/utils/_internal.py
@@ -3,7 +3,7 @@
 import io
 import mmap
 import platform
-from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
 
 import numpy as np
 import torch
@@ -14,7 +14,6 @@
     "add_suggestion",
     "fromfile",
     "ReadOnlyTensorBuffer",
-    "query_recursively",
 ]
 
 
@@ -125,20 +124,3 @@ def read(self, size: int = -1) -> bytes:
         cursor = self.tell()
         offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
         return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
-
-
-def query_recursively(
-    fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
-) -> Iterator[D]:
-    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
-    # "a" == "a"[0][0]...
-    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-        for idx, item in enumerate(obj):
-            yield from query_recursively(fn, item, id=(*id, idx))
-    elif isinstance(obj, collections.abc.Mapping):
-        for key, item in obj.items():
-            yield from query_recursively(fn, item, id=(*id, key))
-    else:
-        result = fn(id, obj)
-        if result is not None:
-            yield result
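Note on the replacement: `torch.utils._pytree.tree_flatten` flattens an arbitrarily nested sample (dicts, lists, tuples) into a flat list of leaves plus a spec, and `tree_unflatten` rebuilds the original structure from them. That is why `_extract_image` can now return a plain integer index into the flat list instead of a tuple key path, and `_put_into_sample` can write the transformed image back through that index. A minimal sketch of the round-trip (the sample values below are made up for illustration):

    from torch.utils._pytree import tree_flatten, tree_unflatten

    # hypothetical nested sample, stand-ins for image tensors etc.
    sample = {"image": "raw-image", "target": {"label": 0, "tags": ["a", "b"]}}
    leaves, spec = tree_flatten(sample)  # leaves == ["raw-image", 0, "a", "b"]
    leaves[0] = "augmented-image"        # overwrite the leaf found by _extract_image
    sample = tree_unflatten(leaves, spec)
    # sample == {"image": "augmented-image", "target": {"label": 0, "tags": ["a", "b"]}}

Because `tree_flatten` visits leaves in a deterministic order, the index computed in `_extract_image` is still valid when `_put_into_sample` flattens the same sample again, which is what makes changing the `id` parameter from a key path (`Tuple[Any, ...]`) to an `int` sound, and lets the hand-rolled `query_recursively` traversal be deleted.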