34 changes: 13 additions & 21 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -3,28 +3,23 @@
 
 import PIL.Image
 import torch
+
+from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
-from torchvision.prototype.utils._internal import query_recursively
 from torchvision.transforms.autoaugment import AutoAugmentPolicy
 from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image
 
-from ._utils import get_image_dimensions
+from ._utils import get_image_dimensions, is_simple_tensor
 
 K = TypeVar("K")
 V = TypeVar("V")
 
 
-def _put_into_sample(sample: Any, id: Tuple[Any, ...], item: Any) -> Any:
-    if not id:
-        return item
-
-    parent = sample
-    for key in id[:-1]:
-        parent = parent[key]
-
-    parent[id[-1]] = item
-    return sample
+def _put_into_sample(sample: Any, id: int, item: Any) -> Any:
+    sample_flat, spec = tree_flatten(sample)
+    sample_flat[id] = item
+    return tree_unflatten(sample_flat, spec)
 
 
 class _AutoAugmentBase(Transform):
@@ -47,18 +42,15 @@ def _extract_image(
         self,
         sample: Any,
         unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
-    ) -> Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]:
-        def fn(
-            id: Tuple[Any, ...], inpt: Any
-        ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
-            if type(inpt) in {torch.Tensor, features.Image} or isinstance(inpt, PIL.Image.Image):
-                return id, inpt
+    ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
+        sample_flat, _ = tree_flatten(sample)
+        images = []
+        for id, inpt in enumerate(sample_flat):
+            if isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt):
+                images.append((id, inpt))
             elif isinstance(inpt, unsupported_types):
                 raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")
-            else:
-                return None
 
-        images = list(query_recursively(fn, sample))
         if not images:
             raise TypeError("Found no image in the sample.")
         if len(images) > 1:
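Note: the hunks above replace the recursive key-path bookkeeping with `torch.utils._pytree`: the sample is flattened once, images are located by their integer position in the flat list, and the (possibly transformed) image is written back at that index before unflattening. A minimal standalone sketch of that round trip, with a plain tensor sample and an illustrative image search standing in for the library's actual type checks:

```python
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

sample = {"image": torch.rand(3, 32, 32), "label": 7}

# Flatten once; every leaf is now addressable by a single integer index.
sample_flat, spec = tree_flatten(sample)
image_id = next(i for i, leaf in enumerate(sample_flat) if isinstance(leaf, torch.Tensor))

# Transform the extracted image, then write it back at the same flat index
# (this is what the new _put_into_sample does via tree_unflatten).
sample_flat[image_id] = sample_flat[image_id].flip(-1)  # horizontal flip
new_sample = tree_unflatten(sample_flat, spec)
assert new_sample["label"] == 7 and new_sample["image"].shape == (3, 32, 32)
```

A flat integer id is sufficient here because `tree_unflatten` restores the original nesting from the spec, so the key-path tuples the old helper threaded through carried no extra information.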
20 changes: 1 addition & 19 deletions torchvision/prototype/utils/_internal.py
@@ -3,7 +3,7 @@
 import io
 import mmap
 import platform
-from typing import Any, BinaryIO, Callable, Collection, Iterator, Optional, Sequence, Tuple, TypeVar, Union
+from typing import BinaryIO, Callable, Collection, Sequence, TypeVar, Union
 
 import numpy as np
 import torch
@@ -14,7 +14,6 @@
     "add_suggestion",
     "fromfile",
     "ReadOnlyTensorBuffer",
-    "query_recursively",
 ]
 
 
@@ -125,20 +124,3 @@ def read(self, size: int = -1) -> bytes:
         cursor = self.tell()
         offset, whence = (0, io.SEEK_END) if size == -1 else (size, io.SEEK_CUR)
         return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()
-
-
-def query_recursively(
-    fn: Callable[[Tuple[Any, ...], Any], Optional[D]], obj: Any, *, id: Tuple[Any, ...] = ()
-) -> Iterator[D]:
-    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
-    # "a" == "a"[0][0]...
-    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-        for idx, item in enumerate(obj):
-            yield from query_recursively(fn, item, id=(*id, idx))
-    elif isinstance(obj, collections.abc.Mapping):
-        for key, item in obj.items():
-            yield from query_recursively(fn, item, id=(*id, key))
-    else:
-        result = fn(id, obj)
-        if result is not None:
-            yield result
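Note: with the pytree helpers in place, the traversal removed above is redundant. `tree_flatten` already walks the nested containers (and treats strings as leaves, sidestepping the self-reference guard the old comment describes). A rough sketch of the correspondence, using a hypothetical `query_flat` helper that is not part of this PR:

```python
from typing import Any, Callable, Iterator, Optional, TypeVar

from torch.utils._pytree import tree_flatten

D = TypeVar("D")


def query_flat(fn: Callable[[int, Any], Optional[D]], obj: Any) -> Iterator[D]:
    # Same contract as the removed query_recursively, except that leaves are
    # identified by a flat integer index instead of a tuple of keys/indices.
    leaves, _ = tree_flatten(obj)
    for idx, leaf in enumerate(leaves):
        result = fn(idx, leaf)
        if result is not None:
            yield result
```

One behavioral difference to be aware of: `tree_flatten` only recurses into registered container types such as `list`, `tuple`, and `dict`, whereas the removed helper walked any `Sequence` or `Mapping`.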