From 5e52ed26108115dee0ef40dbbea70087cf12621e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Feb 2022 16:30:00 +0100 Subject: [PATCH 01/25] add prototype transforms that don't need dispatchers --- test/test_prototype_builtin_datasets.py | 1 - torchvision/prototype/transforms/__init__.py | 13 +- torchvision/prototype/transforms/_augment.py | 136 ++++++++ .../prototype/transforms/_auto_augment.py | 317 ++++++++++++++++++ .../prototype/transforms/_container.py | 63 ++++ torchvision/prototype/transforms/_geometry.py | 105 ++++++ .../prototype/transforms/_meta_conversion.py | 58 ++++ torchvision/prototype/transforms/_misc.py | 63 ++++ .../prototype/transforms/_transform.py | 20 ++ .../prototype/transforms/_type_conversion.py | 38 +++ .../transforms/functional/__init__.py | 14 - .../transforms/functional/_augment.py | 57 ---- .../prototype/transforms/functional/_color.py | 119 ------- .../transforms/functional/_geometry.py | 95 ------ .../prototype/transforms/functional/_misc.py | 21 -- .../prototype/transforms/functional/_utils.py | 89 ----- torchvision/prototype/transforms/utils.py | 91 +++++ torchvision/prototype/utils/_internal.py | 36 +- 18 files changed, 920 insertions(+), 416 deletions(-) create mode 100644 torchvision/prototype/transforms/_augment.py create mode 100644 torchvision/prototype/transforms/_auto_augment.py create mode 100644 torchvision/prototype/transforms/_container.py create mode 100644 torchvision/prototype/transforms/_geometry.py create mode 100644 torchvision/prototype/transforms/_meta_conversion.py create mode 100644 torchvision/prototype/transforms/_misc.py create mode 100644 torchvision/prototype/transforms/_transform.py create mode 100644 torchvision/prototype/transforms/_type_conversion.py delete mode 100644 torchvision/prototype/transforms/functional/__init__.py delete mode 100644 torchvision/prototype/transforms/functional/_augment.py delete mode 100644 torchvision/prototype/transforms/functional/_color.py delete mode 100644 torchvision/prototype/transforms/functional/_geometry.py delete mode 100644 torchvision/prototype/transforms/functional/_misc.py delete mode 100644 torchvision/prototype/transforms/functional/_utils.py create mode 100644 torchvision/prototype/transforms/utils.py diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 067359cac2b..eaa92094ad7 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -99,7 +99,6 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config): f"{sequence_to_str(sorted(vanilla_tensors), separate_last='and ')} contained vanilla tensors." ) - @pytest.mark.xfail @parametrize_dataset_mocks(DATASET_MOCKS) def test_transformable(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c9988be1930..3efd7a4130f 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,5 +1,14 @@ +from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip + from . import kernels # usort: skip -from . 
import functional # usort: skip -from .kernels import InterpolationMode # usort: skip +from ._transform import Transform # usort: skip + +from ._augment import RandomErasing, RandomMixup, RandomCutmix +from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment +from ._container import Compose, RandomApply, RandomChoice, RandomOrder +from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop +from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace +from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval +from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py new file mode 100644 index 00000000000..c7324f2b2a4 --- /dev/null +++ b/torchvision/prototype/transforms/_augment.py @@ -0,0 +1,136 @@ +import math +from typing import Any, Dict, Tuple + +import torch +from torchvision import transforms as _transforms +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.transforms import functional as _F + +from .utils import Query + + +class RandomErasing(Transform): + _LEGACY_TRANSFORM_CLS = _transforms.RandomErasing + + def __init__( + self, + p: float = 0.5, + scale: Tuple[float, float] = (0.02, 0.33), + ratio: Tuple[float, float] = (0.3, 3.3), + value: float = 0, + ): + super().__init__() + legacy_transform = self._LEGACY_TRANSFORM_CLS(p=p, scale=scale, ratio=ratio, value=value, inplace=False) + # TODO: deprecate p in favor of wrapping the transform in a RandomApply + self.p = legacy_transform.p + self.scale = legacy_transform.scale + self.ratio = legacy_transform.ratio + self.value = legacy_transform.value + + def get_params(self, sample: Any) -> Dict[str, Any]: + image = Query(sample).image_for_size_and_channels_extraction() + + if isinstance(self.value, (int, float)): + value = [self.value] + elif isinstance(self.value, str): + value = None + elif isinstance(self.value, tuple): + value = list(self.value) + else: + value = self.value + + if value is not None and not (len(value) in (1, image.shape[-3])): + raise ValueError( + "If value is a sequence, it should have either a single value or " + f"{image.shape[-3]} (number of input channels)" + ) + + return dict( + zip("ijhwv", self._LEGACY_TRANSFORM_CLS.get_params(image, scale=self.scale, ratio=self.ratio, value=value)) + ) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if torch.rand(1) >= self.p: + return input + + if type(input) is torch.Tensor: + return _F.erase(input, **params) + elif type(input) is features.Image: + return features.Image.new_like(input, K.erase_image(input, **params)) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("p", "scale", "ratio", "value") + + +class RandomMixup(Transform): + def __init__(self, *, alpha: float) -> None: + super().__init__() + self.alpha = alpha + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict(lam=float(self._dist.sample(()))) + + def _supports(self, obj: Any) -> bool: + return type(obj) in {features.Image, features.OneHotLabel} + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.mixup_image(input, **params) + return 
features.Image.new_like(input, output) + elif type(input) is features.OneHotLabel: + output = K.mixup_one_hot_label(input, **params) + return features.OneHotLabel.new_like(input, output) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("alpha") + + +class RandomCutmix(Transform): + def __init__(self, *, alpha: float) -> None: + super().__init__() + self.alpha = alpha + self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + + def get_params(self, sample: Any) -> Dict[str, Any]: + lam = float(self._dist.sample(())) + + H, W = Query(sample).image_size() + + r_x = torch.randint(W, ()) + r_y = torch.randint(H, ()) + + r = 0.5 * math.sqrt(1.0 - lam) + r_w_half = int(r * W) + r_h_half = int(r * H) + + x1 = int(torch.clamp(r_x - r_w_half, min=0)) + y1 = int(torch.clamp(r_y - r_h_half, min=0)) + x2 = int(torch.clamp(r_x + r_w_half, max=W)) + y2 = int(torch.clamp(r_y + r_h_half, max=H)) + box = (x1, y1, x2, y2) + + lam_adjusted = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H)) + + return dict(box=box, lam_adjusted=lam_adjusted) + + def _supports(self, obj: Any) -> bool: + return type(obj) in {features.Image, features.OneHotLabel} + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.cutmix_image(input, box=params["box"]) + return features.Image.new_like(input, output) + elif type(input) is features.OneHotLabel: + output = K.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) + return features.OneHotLabel.new_like(input, output) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("alpha") diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py new file mode 100644 index 00000000000..77a0d470e29 --- /dev/null +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -0,0 +1,317 @@ +import dataclasses +import math +from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Iterator + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, kernels as K +from torchvision.prototype.utils._internal import apply_recursively +from torchvision.transforms import AutoAugment as _AutoAugment, functional as _F + +from .utils import Query + + +@dataclasses.dataclass +class AutoAugmentDispatcher: + kernel: Callable + legacy_kernel: Callable + magnitude_fn: Optional[Callable[[float], Dict[str, Any]]] = None + extra_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) + takes_interpolation_kwargs: bool = False + + def __call__( + self, input: Any, *, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]] + ) -> Any: + kwargs = self.extra_kwargs.copy() + if self.magnitude_fn is not None: + kwargs.update(self.magnitude_fn(magnitude)) + if self.takes_interpolation_kwargs: + kwargs.update(dict(interpolation=interpolation, fill=fill)) + + kernel = self.kernel if type(input) is features.Image else self.legacy_kernel + return kernel(input, **kwargs) + + +class _AutoAugmentBase(Transform): + def __init__( + self, *, interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None + ) -> None: + super().__init__() + self.interpolation = interpolation + self.fill = fill + + _DISPATCHER_MAP = { + "ShearX": AutoAugmentDispatcher( + K.affine_image, + _F.affine, + magnitude_fn=lambda magnitude: 
dict(shear=[math.degrees(magnitude), 0]), + extra_kwargs=dict(angle=0.0, translate=[0, 0], scale=1.0), + takes_interpolation_kwargs=True, + ), + "ShearY": AutoAugmentDispatcher( + K.affine_image, + _F.affine, + magnitude_fn=lambda magnitude: dict(shear=[0, math.degrees(magnitude)]), + extra_kwargs=dict(angle=0.0, translate=[0, 0], scale=1.0), + takes_interpolation_kwargs=True, + ), + "TranslateX": AutoAugmentDispatcher( + K.affine_image, + _F.affine, + magnitude_fn=lambda magnitude: dict(translate=[int(magnitude), 0]), + extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), + takes_interpolation_kwargs=True, + ), + "TranslateY": AutoAugmentDispatcher( + K.affine_image, + _F.affine, + magnitude_fn=lambda magnitude: dict(translate=[0, int(magnitude)]), + extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), + takes_interpolation_kwargs=True, + ), + "Rotate": AutoAugmentDispatcher( + K.rotate_image, _F.rotate, magnitude_fn=lambda magnitude: dict(angle=magnitude) + ), + "Brightness": AutoAugmentDispatcher( + K.adjust_brightness_image, + _F.adjust_brightness, + magnitude_fn=lambda magnitude: dict(brightness_factor=1.0 + magnitude), + ), + "Color": AutoAugmentDispatcher( + K.adjust_saturation_image, + _F.adjust_saturation, + magnitude_fn=lambda magnitude: dict(saturation_factor=1.0 + magnitude), + ), + "Contrast": AutoAugmentDispatcher( + K.adjust_contrast_image, + _F.adjust_contrast, + magnitude_fn=lambda magnitude: dict(contrast_factor=1.0 + magnitude), + ), + "Sharpness": AutoAugmentDispatcher( + K.adjust_sharpness_image, + _F.adjust_sharpness, + magnitude_fn=lambda magnitude: dict(sharpness_factor=1.0 + magnitude), + ), + "Posterize": AutoAugmentDispatcher( + K.posterize_image, _F.posterize, magnitude_fn=lambda magnitude: dict(bits=int(magnitude)) + ), + "Solarize": AutoAugmentDispatcher( + K.solarize_image, _F.solarize, magnitude_fn=lambda magnitude: dict(threshold=magnitude) + ), + "AutoContrast": AutoAugmentDispatcher(K.autocontrast_image, _F.autocontrast), + "Equalize": AutoAugmentDispatcher(K.equalize_image, _F.equalize), + "Invert": AutoAugmentDispatcher(K.invert_image, _F.invert), + } + + def get_params(self, sample: Any) -> Dict[str, Any]: + image = Query(sample).image_for_size_and_channels_extraction() + + fill = self.fill + if isinstance(fill, (int, float)): + fill = [float(fill)] * image.num_channels + elif fill is not None: + fill = [float(f) for f in fill] + + return dict(fill=fill) + + def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + raise NotImplementedError + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self.get_params(sample) + + for transform_id, magnitude in self.get_transforms_meta(Query(sample).image_size()): + dispatcher = self._DISPATCHER_MAP[transform_id] + + def transform(input: Any) -> Any: + if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): + return dispatcher( # type: ignore[arg-type] + input, + magnitude=magnitude, + interpolation=self.interpolation, + **params, + ) + else: + return input + + sample = apply_recursively(transform, sample) + + return sample + + def _randbool(self, p: float = 0.5) -> bool: + """Randomly returns either ``True`` or ``False``. + + Args: + p: Probability to return ``True``. Defaults to ``0.5``. 
+ """ + return float(torch.rand(())) <= p + + +@dataclasses.dataclass +class AugmentationMeta: + dispatcher_id: str + magnitudes_fn: Callable[[int, Tuple[int, int]], Optional[torch.Tensor]] + signed: bool + + +class AutoAugment(_AutoAugmentBase): + _LEGACY_CLS = _AutoAugment + _AUGMENTATION_SPACE = ( + AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + AugmentationMeta( + "TranslateX", + lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), + True, + ), + AugmentationMeta( + "TranslateY", + lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), + True, + ), + AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta( + "Posterize", + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), + AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), + AugmentationMeta("Invert", lambda num_bins, image_size: None, False), + ) + _AUGMENTATION_SPACE = { + augmentation_meta.dispatcher_id: augmentation_meta for augmentation_meta in _AUGMENTATION_SPACE + } + + def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.policy = policy + self._policies = self._LEGACY_CLS._get_policies(None, policy) # type: ignore[arg-type] + + def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + policy = self._policies[int(torch.randint(len(self._policies), ()))] + + for dispatcher_id, probability, magnitude_idx in policy: + if not self._randbool(probability): + continue + + augmentation_meta = self._AUGMENTATION_SPACE[dispatcher_id] + + magnitudes = augmentation_meta.magnitudes_fn(10, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[magnitude_idx]) + if augmentation_meta.signed and self._randbool(): + magnitude *= -1 + else: + magnitude = 0.0 + + yield augmentation_meta.dispatcher_id, magnitude + + +class RandAugment(_AutoAugmentBase): + _AUGMENTATION_SPACE = ( + AugmentationMeta("Identity", lambda num_bins, image_size: None, False), + AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + AugmentationMeta( + "TranslateX", + lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), + True, + ), + AugmentationMeta( + "TranslateY", + lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), + True, + ), + AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + 
AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + AugmentationMeta( + "Posterize", + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) + .round() + .int(), + False, + ), + AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), + AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), + ) + + def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.num_ops = num_ops + self.magnitude = magnitude + self.num_magnitude_bins = num_magnitude_bins + + def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + for _ in range(self.num_ops): + augmentation_meta = self._AUGMENTATION_SPACE[int(torch.randint(len(self._AUGMENTATION_SPACE), ()))] + if augmentation_meta.dispatcher_id == "Identity": + continue + + magnitudes = augmentation_meta.magnitudes_fn(self.num_magnitude_bins, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if augmentation_meta.signed and self._randbool(): + magnitude *= -1 + else: + magnitude = 0.0 + + yield augmentation_meta.dispatcher_id, magnitude + + +class TrivialAugmentWide(_AutoAugmentBase): + _AUGMENTATION_SPACE = ( + AugmentationMeta("Identity", lambda num_bins, image_size: None, False), + AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta("TranslateX", lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + AugmentationMeta("TranslateY", lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), + AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + AugmentationMeta( + "Posterize", + lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) + .round() + .int(), + False, + ), + AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), + AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), + ) + + def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): + super().__init__(**kwargs) + self.num_magnitude_bins = num_magnitude_bins + + def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + augmentation_meta = 
self._AUGMENTATION_SPACE[int(torch.randint(len(self._AUGMENTATION_SPACE), ()))] + + if augmentation_meta.dispatcher_id == "Identity": + return + + magnitudes = augmentation_meta.magnitudes_fn(self.num_magnitude_bins, image_size) + if magnitudes is not None: + magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) + if augmentation_meta.signed and self._randbool(): + magnitude *= -1 + else: + magnitude = 0.0 + + yield augmentation_meta.dispatcher_id, magnitude diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py new file mode 100644 index 00000000000..3a183a7b884 --- /dev/null +++ b/torchvision/prototype/transforms/_container.py @@ -0,0 +1,63 @@ +from typing import Any, Optional, Dict + +import torch + +from ._transform import Transform + + +class Compose(Transform): + def __init__(self, *transforms: Transform) -> None: + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + sample = inputs if len(inputs) > 1 else inputs[0] + for transform in self.transforms: + sample = transform(sample) + return sample + + +class RandomApply(Transform): + def __init__(self, transform: Transform, *, p: float = 0.5) -> None: + super().__init__() + self.transform = transform + self.p = p + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if float(torch.rand(())) < self.p: + return sample + + return self.transform(sample, params=params) + + def extra_repr(self) -> str: + return f"p={self.p}" + + +class RandomChoice(Transform): + def __init__(self, *transforms: Transform) -> None: + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + idx = int(torch.randint(len(self.transforms), size=())) + transform = self.transforms[idx] + return transform(*inputs) + + +class RandomOrder(Transform): + def __init__(self, *transforms: Transform) -> None: + super().__init__() + self.transforms = transforms + for idx, transform in enumerate(transforms): + self.add_module(str(idx), transform) + + def forward(self, *inputs: Any) -> Any: # type: ignore[override] + for idx in torch.randperm(len(self.transforms)): + transform = self.transforms[idx] + inputs = transform(*inputs) + return inputs diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py new file mode 100644 index 00000000000..60906e0f841 --- /dev/null +++ b/torchvision/prototype/transforms/_geometry.py @@ -0,0 +1,105 @@ +from typing import Any, Dict, List, Union, Sequence, Tuple + +from torchvision import transforms as _transforms +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, InterpolationMode, kernels as K +from torchvision.transforms import functional as _F + +from .utils import Query, legacy_transform + + +class HorizontalFlip(Transform): + @legacy_transform(_F.hflip) + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.horizontal_flip_image(input) + return features.Image.new_like(input, output) + elif type(input) is features.BoundingBox: + output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + 
return features.BoundingBox.new_like(input, output) + else: + return input + + +class Resize(Transform): + def __init__( + self, + size: Union[int, Sequence[int]], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.size = [size, size] if isinstance(size, int) else list(size) + self.interpolation = interpolation + + @legacy_transform(_F.resize, "size", "interpolation") + def _dispatch(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.resize_image(input, size=self.size, interpolation=self.interpolation) + return features.Image.new_like(input, output) + elif type(input) is features.SegmentationMask: + return features.SegmentationMask.new_like(input, K.resize_segmentation_mask(input, size=self.size)) + elif type(input) is features.BoundingBox: + output = K.resize_bounding_box(input, size=self.size, image_size=input.image_size) + return features.BoundingBox.new_like(input, output, image_size=self.size) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("size", "interpolation") + + +class CenterCrop(Transform): + def __init__(self, output_size: List[int]): + super().__init__() + self.output_size = output_size + + @legacy_transform(_F.center_crop, "output_size") + def _dispatch(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.center_crop_image(input, **params) + return features.Image.new_like(input, output) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("output_size") + + +class RandomResizedCrop(Transform): + _LEGACY_CLS = _transforms.RandomResizedCrop + + def __init__( + self, + size: Union[int, Sequence[int]], + scale: Tuple[float, float] = (0.08, 1.0), + ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + legacy_transform = self._LEGACY_CLS(size=size, scale=scale, ratio=ratio, interpolation=interpolation) + self.size = legacy_transform.size + self.scale = legacy_transform.scale + self.ratio = legacy_transform.ratio + self.interpolation = legacy_transform.interpolation + + def get_params(self, sample: Any) -> Dict[str, Any]: + image = Query(sample).image_for_size_extraction() + top, left, height, width = _transforms.RandomResizedCrop.get_params( + image, scale=list(self.scale), ratio=list(self.ratio) + ) + return dict( + top=top, + left=left, + height=height, + width=width, + ) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.resized_crop_image(input, size=self.size, interpolation=self.interpolation, **params) + return features.Image.new_like(input, output) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("size", "scale", "ratio", "interpolation") diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py new file mode 100644 index 00000000000..9ed830cdadd --- /dev/null +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -0,0 +1,58 @@ +from typing import Union, Any, Dict + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.transforms import functional as _F + + +class ConvertBoundingBoxFormat(Transform): + def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: 
+ super().__init__() + if isinstance(format, str): + format = features.BoundingBoxFormat[format] + self.format = format + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.BoundingBox: + output = K.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"]) + return features.BoundingBox.new_like(input, output, format=params["format"]) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("format") + + +class ConvertImageDtype(Transform): + def __init__(self, dtype: torch.dtype = torch.float32) -> None: + super().__init__() + self.dtype = dtype + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = _F.convert_image_dtype(input, dtype=self.dtype) + return features.Image.new_like(input, output, dtype=self.dtype) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("dtype") + + +class ConvertColorSpace(Transform): + def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: + super().__init__() + if isinstance(color_space, str): + color_space = features.ColorSpace[color_space] + self.color_space = color_space + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=self.color_space) + return features.Image.new_like(input, output, color_space=self.color_space) + else: + return input + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("color_space") diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py new file mode 100644 index 00000000000..701b7441da3 --- /dev/null +++ b/torchvision/prototype/transforms/_misc.py @@ -0,0 +1,63 @@ +import functools +from typing import Any, List, Type, Callable, Dict + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.transforms import functional as _F + +from .utils import legacy_transform + + +class Identity(Transform): + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + return input + + +class Lambda(Transform): + def __init__(self, fn: Callable[[Any], Any], *types: Type): + super().__init__() + self.fn = fn + self.types = types + + def _supports(self, obj: Any) -> bool: + return type(obj) in self.types + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) in self.types: + return self.fn(input) + else: + return input + + def extra_repr(self) -> str: + extras = [] + name = getattr(self.fn, "__name__", None) + if name: + extras.append(name) + extras.append(f"types={[type.__name__ for type in self.types]}") + return ", ".join(extras) + + +class Normalize(Transform): + def __init__(self, mean: List[float], std: List[float]): + super().__init__() + self.mean = mean + self.std = std + + @legacy_transform(_F.normalize, "mean", "std") + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Image: + output = K.normalize_image(input, **params) + return features.Image.new_like(input, output) + + def extra_repr(self) -> str: + return self._extra_repr_from_attrs("mean", "std") + + +class ToDtype(Lambda): + def __init__(self, dtype: torch.dtype, *types: Type) -> None: + self.dtype = dtype + super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types) + + 
def extra_repr(self) -> str: + return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"]) diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py new file mode 100644 index 00000000000..3e535c25a0d --- /dev/null +++ b/torchvision/prototype/transforms/_transform.py @@ -0,0 +1,20 @@ +import functools +from typing import Any, Dict, Optional + +from torch import nn +from torchvision.prototype.utils._internal import apply_recursively + + +class Transform(nn.Module): + def get_params(self, sample: Any) -> Dict[str, Any]: + return dict() + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + raise NotImplementedError + + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + return apply_recursively(functools.partial(self._transform, params=params or self.get_params(sample)), sample) + + def _extra_repr_from_attrs(self, *names: str) -> str: + return ", ".join(f"{name}={getattr(self, name)}" for name in names) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py new file mode 100644 index 00000000000..8efd356768a --- /dev/null +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -0,0 +1,38 @@ +from typing import Any, Dict + +from torchvision.prototype import features +from torchvision.prototype.transforms import Transform, kernels as K + + +class DecodeImage(Transform): + def _supports(self, obj: Any) -> bool: + return isinstance(obj, features.EncodedImage) + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.EncodedImage: + output = K.decode_image_with_pil(input) + return features.Image(output) + else: + return input + + +class LabelToOneHot(Transform): + def __init__(self, num_categories: int = -1): + super().__init__() + self.num_categories = num_categories + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is features.Label: + num_categories = self.num_categories + if num_categories == -1 and input.categories is not None: + num_categories = len(input.categories) + output = K.label_to_one_hot(input, num_categories=num_categories) + return features.OneHotLabel(output, categories=input.categories) + else: + return input + + def extra_repr(self) -> str: + if self.num_categories == -1: + return "" + + return f"num_categories={self.num_categories}" diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py deleted file mode 100644 index 9f05f16df2d..00000000000 --- a/torchvision/prototype/transforms/functional/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from ._augment import erase, mixup, cutmix -from ._color import ( - adjust_brightness, - adjust_contrast, - adjust_saturation, - adjust_sharpness, - posterize, - solarize, - autocontrast, - equalize, - invert, -) -from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate -from ._misc import normalize diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py deleted file mode 100644 index 2eafe0d3c1f..00000000000 --- a/torchvision/prototype/transforms/functional/_augment.py +++ /dev/null @@ -1,57 +0,0 @@ -from typing import TypeVar, Any - -import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as 
K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - -T = TypeVar("T", bound=features._Feature) - - -@dispatch( - { - torch.Tensor: _F.erase, - features.Image: K.erase_image, - } -) -def erase(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - features.Image: K.mixup_image, - features.OneHotLabel: K.mixup_one_hot_label, - } -) -def mixup(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - features.Image: K.cutmix_image, - features.OneHotLabel: K.cutmix_one_hot_label, - } -) -def cutmix(input: T, *args: Any, **kwargs: Any) -> T: - """Perform the CutMix operation as introduced in the paper - `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. - - Dispatch to the corresponding kernels happens according to this table: - - .. table:: - :widths: 30 70 - - ==================================================== ================================================================ - :class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image` - :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label` - ==================================================== ================================================================ - - Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. - """ - ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py deleted file mode 100644 index 23e128b7856..00000000000 --- a/torchvision/prototype/transforms/functional/_color.py +++ /dev/null @@ -1,119 +0,0 @@ -from typing import TypeVar, Any - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - -T = TypeVar("T", bound=features._Feature) - - -@dispatch( - { - torch.Tensor: _F.adjust_brightness, - PIL.Image.Image: _F.adjust_brightness, - features.Image: K.adjust_brightness_image, - } -) -def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.adjust_saturation, - PIL.Image.Image: _F.adjust_saturation, - features.Image: K.adjust_saturation_image, - } -) -def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.adjust_contrast, - PIL.Image.Image: _F.adjust_contrast, - features.Image: K.adjust_contrast_image, - } -) -def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.adjust_sharpness, - PIL.Image.Image: _F.adjust_sharpness, - features.Image: K.adjust_sharpness_image, - } -) -def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.posterize, - PIL.Image.Image: _F.posterize, - features.Image: K.posterize_image, - } -) -def posterize(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.solarize, - PIL.Image.Image: _F.solarize, - features.Image: K.solarize_image, - } -) -def solarize(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... 
- - -@dispatch( - { - torch.Tensor: _F.autocontrast, - PIL.Image.Image: _F.autocontrast, - features.Image: K.autocontrast_image, - } -) -def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.equalize, - PIL.Image.Image: _F.equalize, - features.Image: K.equalize_image, - } -) -def equalize(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.invert, - PIL.Image.Image: _F.invert, - features.Image: K.invert_image, - } -) -def invert(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py deleted file mode 100644 index 147baa3a066..00000000000 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import TypeVar, Any, cast - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - -T = TypeVar("T", bound=features._Feature) - - -@dispatch( - { - torch.Tensor: _F.hflip, - PIL.Image.Image: _F.hflip, - features.Image: K.horizontal_flip_image, - features.BoundingBox: None, - }, -) -def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - if isinstance(input, features.BoundingBox): - output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return cast(T, features.BoundingBox.new_like(input, output)) - - raise RuntimeError - - -@dispatch( - { - torch.Tensor: _F.resize, - PIL.Image.Image: _F.resize, - features.Image: K.resize_image, - features.SegmentationMask: K.resize_segmentation_mask, - features.BoundingBox: None, - } -) -def resize(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - if isinstance(input, features.BoundingBox): - size = kwargs.pop("size") - output = K.resize_bounding_box(input, size=size, image_size=input.image_size) - return cast(T, features.BoundingBox.new_like(input, output, image_size=size)) - - raise RuntimeError - - -@dispatch( - { - torch.Tensor: _F.center_crop, - PIL.Image.Image: _F.center_crop, - features.Image: K.center_crop_image, - } -) -def center_crop(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.resized_crop, - PIL.Image.Image: _F.resized_crop, - features.Image: K.resized_crop_image, - } -) -def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.affine, - PIL.Image.Image: _F.affine, - features.Image: K.affine_image, - } -) -def affine(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... - - -@dispatch( - { - torch.Tensor: _F.rotate, - PIL.Image.Image: _F.rotate, - features.Image: K.rotate_image, - } -) -def rotate(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... 
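
A note on the dispatch model this patch switches to: the functional helpers deleted above relied on a generic @dispatch decorator, whereas the new classes select the kernel inside each Transform._transform based on the input type. A single transform call can therefore be applied to a heterogeneous sample, with prototype features routed to the new kernels, plain tensors and PIL images routed to the stable torchvision.transforms.functional ops, and anything else passed through unchanged. The sketch below is illustrative only and not part of the patch; the constructors of features.Image, features.BoundingBox, and the BoundingBoxFormat enum live in torchvision.prototype.features (outside this diff) and their keyword arguments are assumed here, and the dict-based sample layout is arbitrary.

    import torch
    from torchvision.prototype import features, transforms as T

    # Compose takes transforms as *args in the prototype API added above.
    transform = T.Compose(
        T.HorizontalFlip(),
        T.Resize([256, 256]),
    )

    sample = {
        "image": features.Image(torch.rand(3, 224, 224)),
        "bbox": features.BoundingBox(
            torch.tensor([[10.0, 10.0, 100.0, 100.0]]),
            format=features.BoundingBoxFormat.XYXY,  # assumed constructor signature
            image_size=(224, 224),
        ),
    }

    # Transform.forward walks the sample recursively, so the image and the bounding
    # box are each flipped and resized by their matching kernel in one call, while
    # unrecognized values would be returned unchanged.
    output = transform(sample)
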
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py deleted file mode 100644 index 7cf0765105a..00000000000 --- a/torchvision/prototype/transforms/functional/_misc.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import TypeVar, Any - -import torch -from torchvision.prototype import features -from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F - -from ._utils import dispatch - -T = TypeVar("T", bound=features._Feature) - - -@dispatch( - { - torch.Tensor: _F.normalize, - features.Image: K.normalize_image, - } -) -def normalize(input: T, *args: Any, **kwargs: Any) -> T: - """ADDME""" - ... diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py deleted file mode 100644 index 591f9a83101..00000000000 --- a/torchvision/prototype/transforms/functional/_utils.py +++ /dev/null @@ -1,89 +0,0 @@ -import functools -import inspect -from typing import Any, Optional, Callable, TypeVar, Dict - -import torch -import torch.overrides -from torchvision.prototype import features - -F = TypeVar("F", bound=features._Feature) - - -def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]: - """Decorates a function to automatically dispatch to registered kernels based on the call arguments. - - The dispatch function should have this signature - - .. code:: python - - @dispatch( - ... - ) - def dispatch_fn(input, *args, **kwargs): - ... - - where ``input`` is used to determine which kernel to dispatch to. - - Args: - kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for - exact type matches first and if none is found falls back to checking for subclasses. If a value is - ``None``, the decorated function is called. - - Raises: - TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``. - TypeError: If the decorated function is called with an input that cannot be dispatched. - """ - - def check_kernel(kernel: Any) -> bool: - if kernel is None: - return True - - if not callable(kernel): - return False - - params = list(inspect.signature(kernel).parameters.values()) - if not params: - return False - - return params[0].kind != inspect.Parameter.KEYWORD_ONLY - - for feature_type, kernel in kernels.items(): - if not check_kernel(kernel): - raise TypeError( - f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)." - ) - - def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: - @functools.wraps(dispatch_fn) - def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: - feature_type = type(input) - try: - kernel = kernels[feature_type] - except KeyError: - try: - feature_type, kernel = next( - (feature_type, kernel) - for feature_type, kernel in kernels.items() - if isinstance(input, feature_type) - ) - except StopIteration: - raise TypeError(f"No support for {type(input).__name__}") from None - - if kernel is None: - output = dispatch_fn(input, *args, **kwargs) - if output is None: - raise RuntimeError( - f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} " - f"although it was configured to do so." 
- ) - else: - output = kernel(input, *args, **kwargs) - - if issubclass(feature_type, features._Feature) and type(output) is torch.Tensor: - output = feature_type.new_like(input, output) - - return output - - return inner_wrapper - - return outer_wrapper diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py new file mode 100644 index 00000000000..cc5949a17af --- /dev/null +++ b/torchvision/prototype/transforms/utils.py @@ -0,0 +1,91 @@ +import functools +from typing import Callable, Tuple, TypeVar, Optional, Any, cast, Dict + +import PIL.Image +import torch +from torchvision.prototype import features, transforms +from torchvision.prototype.utils._internal import query_recursively + +T = TypeVar("T") + + +def legacy_transform(kernel: Callable, *attrs: str) -> Callable[[Callable], Callable]: + def outer_wrapper(fn: Callable) -> Any: + @functools.wraps(fn) + def inner_wrapper(self: transforms.Transform, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + return kernel(input, **params, **{attr: getattr(self, attr) for attr in attrs}) + + return fn(self, input, params) + + return inner_wrapper + + return outer_wrapper + + +class Query: + def __init__(self, sample: Any) -> None: + self.sample = sample + + def query(self, fn: Callable[[Any], Optional[T]]) -> T: + try: + return next(query_recursively(fn, self.sample)) + except StopIteration: + raise RuntimeError from None + + def image(self) -> features.Image: + def fn(sample: Any) -> Optional[features.Image]: + if isinstance(sample, features.Image): + return sample + else: + return None + + return self.query(fn) + + def image_size(self) -> Tuple[int, int]: + def fn(sample: Any) -> Optional[Tuple[int, int]]: + if isinstance(sample, (features.Image, features.BoundingBox)): + return sample.image_size + elif isinstance(sample, torch.Tensor): + return cast(Tuple[int, int], sample.shape[-2:]) + elif isinstance(sample, PIL.Image.Image): + return sample.height, sample.width + else: + return None + + return self.query(fn) + + def image_for_size_extraction(self) -> features.Image: + def fn(sample: Any) -> Optional[features.Image]: + if isinstance(sample, features.Image): + return sample + + if isinstance(sample, features.BoundingBox): + height, width = sample.image_size + elif isinstance(sample, torch.Tensor): + height, width = sample.shape[-2:] + elif isinstance(sample, PIL.Image.Image): + height, width = sample.height, sample.width + else: + return None + + return features.Image(torch.empty(0, height, width)) + + return self.query(fn) + + def image_for_size_and_channels_extraction(self) -> features.Image: + def fn(sample: Any) -> Optional[features.Image]: + if isinstance(sample, features.Image): + return sample + + if isinstance(sample, torch.Tensor): + num_channels, height, width = sample.shape[-3:] + elif isinstance(sample, PIL.Image.Image): + height, width = sample.height, sample.width + num_channels = len(sample.num_bands()) + else: + return None + + return features.Image(torch.empty(0, num_channels, height, width)) + + return self.query(fn) diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index fe75c19eb75..2e38471ea65 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -24,8 +24,7 @@ Tuple, TypeVar, Union, - List, - Dict, + Optional, ) import numpy as np @@ -42,6 +41,7 @@ "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", + 
"query_recursively", ] @@ -305,22 +305,22 @@ def apply_recursively(fn: Callable, obj: Any) -> Any: # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: # "a" == "a"[0][0]... if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str): - sequence: List[Any] = [] - for item in obj: - result = apply_recursively(fn, item) - if isinstance(result, collections.abc.Sequence) and hasattr(result, "__inline__"): - sequence.extend(result) - else: - sequence.append(result) - return sequence + return [apply_recursively(fn, item) for item in obj] elif isinstance(obj, collections.abc.Mapping): - mapping: Dict[Any, Any] = {} - for name, item in obj.items(): - result = apply_recursively(fn, item) - if isinstance(result, collections.abc.Mapping) and hasattr(result, "__inline__"): - mapping.update(result) - else: - mapping[name] = result - return mapping + return {key: apply_recursively(fn, item) for key, item in obj.items()} else: return fn(obj) + + +def query_recursively(fn: Callable[[Any], Optional[D]], obj: Any) -> Iterator[D]: + # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop: + # "a" == "a"[0][0]... + if (isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str)) or isinstance( + obj, collections.abc.Mapping + ): + for item in obj.values() if isinstance(obj, collections.abc.Mapping) else obj: + yield from query_recursively(fn, item) + else: + result = fn(obj) + if result is not None: + yield result From fe82e9411584feb2ace652cf7adc0731d20f7cfc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 14 Feb 2022 22:54:41 +0100 Subject: [PATCH 02/25] cleanup --- torchvision/prototype/transforms/_geometry.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 60906e0f841..bdffcb90c55 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -32,12 +32,13 @@ def __init__( self.interpolation = interpolation @legacy_transform(_F.resize, "size", "interpolation") - def _dispatch(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: output = K.resize_image(input, size=self.size, interpolation=self.interpolation) return features.Image.new_like(input, output) elif type(input) is features.SegmentationMask: - return features.SegmentationMask.new_like(input, K.resize_segmentation_mask(input, size=self.size)) + output = K.resize_segmentation_mask(input, size=self.size) + return features.SegmentationMask.new_like(input, output) elif type(input) is features.BoundingBox: output = K.resize_bounding_box(input, size=self.size, image_size=input.image_size) return features.BoundingBox.new_like(input, output, image_size=self.size) @@ -54,7 +55,7 @@ def __init__(self, output_size: List[int]): self.output_size = output_size @legacy_transform(_F.center_crop, "output_size") - def _dispatch(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: output = K.center_crop_image(input, **params) return features.Image.new_like(input, output) From 36f3e0d1c06267238bbfe58c197fd5b4b77158d9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 10:26:01 +0100 Subject: [PATCH 03/25] remove legacy_transform decorator --- 
torchvision/prototype/transforms/_geometry.py | 19 ++++++++++++------- torchvision/prototype/transforms/_misc.py | 5 ++--- torchvision/prototype/transforms/utils.py | 19 ++----------------- 3 files changed, 16 insertions(+), 27 deletions(-) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index bdffcb90c55..05145bf69f3 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,17 +1,20 @@ from typing import Any, Dict, List, Union, Sequence, Tuple +import PIL.Image +import torch from torchvision import transforms as _transforms from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, kernels as K from torchvision.transforms import functional as _F -from .utils import Query, legacy_transform +from .utils import Query class HorizontalFlip(Transform): - @legacy_transform(_F.hflip) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + return _F.hflip(input) + elif type(input) is features.Image: output = K.horizontal_flip_image(input) return features.Image.new_like(input, output) elif type(input) is features.BoundingBox: @@ -31,9 +34,10 @@ def __init__( self.size = [size, size] if isinstance(size, int) else list(size) self.interpolation = interpolation - @legacy_transform(_F.resize, "size", "interpolation") def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + return _F.resize(input, size=self.size, interpolation=self.interpolation) + elif type(input) is features.Image: output = K.resize_image(input, size=self.size, interpolation=self.interpolation) return features.Image.new_like(input, output) elif type(input) is features.SegmentationMask: @@ -54,9 +58,10 @@ def __init__(self, output_size: List[int]): super().__init__() self.output_size = output_size - @legacy_transform(_F.center_crop, "output_size") def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + return _F.center_crop(input, output_size=self.output_size) + elif type(input) is features.Image: output = K.center_crop_image(input, **params) return features.Image.new_like(input, output) else: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 701b7441da3..d9e9d963c3f 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -6,8 +6,6 @@ from torchvision.prototype.transforms import Transform, kernels as K from torchvision.transforms import functional as _F -from .utils import legacy_transform - class Identity(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: @@ -44,8 +42,9 @@ def __init__(self, mean: List[float], std: List[float]): self.mean = mean self.std = std - @legacy_transform(_F.normalize, "mean", "std") def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if type(input) is torch.Tensor: + return _F.normalize(input, mean=self.mean, std=self.std) if type(input) is features.Image: output = K.normalize_image(input, **params) return features.Image.new_like(input, output) diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py 
index cc5949a17af..796f4496882 100644 --- a/torchvision/prototype/transforms/utils.py +++ b/torchvision/prototype/transforms/utils.py @@ -1,28 +1,13 @@ -import functools -from typing import Callable, Tuple, TypeVar, Optional, Any, cast, Dict +from typing import Callable, Tuple, TypeVar, Optional, Any, cast import PIL.Image import torch -from torchvision.prototype import features, transforms +from torchvision.prototype import features from torchvision.prototype.utils._internal import query_recursively T = TypeVar("T") -def legacy_transform(kernel: Callable, *attrs: str) -> Callable[[Callable], Callable]: - def outer_wrapper(fn: Callable) -> Any: - @functools.wraps(fn) - def inner_wrapper(self: transforms.Transform, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): - return kernel(input, **params, **{attr: getattr(self, attr) for attr in attrs}) - - return fn(self, input, params) - - return inner_wrapper - - return outer_wrapper - - class Query: def __init__(self, sample: Any) -> None: self.sample = sample From 757fbedcc4fa51645a95e2ca646676681dbc5a88 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 11:22:09 +0100 Subject: [PATCH 04/25] remove legacy classes --- torchvision/prototype/transforms/_augment.py | 65 ++++++++++--- .../prototype/transforms/_auto_augment.py | 95 ++++++++++++++++++- torchvision/prototype/transforms/_geometry.py | 83 ++++++++++++---- 3 files changed, 208 insertions(+), 35 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index c7324f2b2a4..129640f07c1 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -1,8 +1,9 @@ import math +import numbers +import warnings from typing import Any, Dict, Tuple import torch -from torchvision import transforms as _transforms from torchvision.prototype import features from torchvision.prototype.transforms import Transform, kernels as K from torchvision.transforms import functional as _F @@ -11,8 +12,6 @@ class RandomErasing(Transform): - _LEGACY_TRANSFORM_CLS = _transforms.RandomErasing - def __init__( self, p: float = 0.5, @@ -21,15 +20,29 @@ def __init__( value: float = 0, ): super().__init__() - legacy_transform = self._LEGACY_TRANSFORM_CLS(p=p, scale=scale, ratio=ratio, value=value, inplace=False) + if not isinstance(value, (numbers.Number, str, tuple, list)): + raise TypeError("Argument value should be either a number or str or a sequence") + if isinstance(value, str) and value != "random": + raise ValueError("If value is str, it should be 'random'") + if not isinstance(scale, (tuple, list)): + raise TypeError("Scale should be a sequence") + if not isinstance(ratio, (tuple, list)): + raise TypeError("Ratio should be a sequence") + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + if scale[0] < 0 or scale[1] > 1: + raise ValueError("Scale should be between 0 and 1") + if p < 0 or p > 1: + raise ValueError("Random erasing probability should be between 0 and 1") # TODO: deprecate p in favor of wrapping the transform in a RandomApply - self.p = legacy_transform.p - self.scale = legacy_transform.scale - self.ratio = legacy_transform.ratio - self.value = legacy_transform.value + self.p = p + self.scale = scale + self.ratio = ratio + self.value = value def get_params(self, sample: Any) -> Dict[str, Any]: - image = 
Query(sample).image_for_size_and_channels_extraction() + image = Query(sample).image() + img_c, (img_h, img_w) = image.num_channels, image.image_size if isinstance(self.value, (int, float)): value = [self.value] @@ -40,15 +53,41 @@ def get_params(self, sample: Any) -> Dict[str, Any]: else: value = self.value - if value is not None and not (len(value) in (1, image.shape[-3])): + if value is not None and not (len(value) in (1, img_c)): raise ValueError( "If value is a sequence, it should have either a single value or " f"{image.shape[-3]} (number of input channels)" ) - return dict( - zip("ijhwv", self._LEGACY_TRANSFORM_CLS.get_params(image, scale=self.scale, ratio=self.ratio, value=value)) - ) + area = img_h * img_w + + log_ratio = torch.log(torch.tensor(self.ratio)) + for _ in range(10): + erase_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + h = int(round(math.sqrt(erase_area * aspect_ratio))) + w = int(round(math.sqrt(erase_area / aspect_ratio))) + if not (h < img_h and w < img_w): + continue + + if value is None: + v = torch.empty([img_c, h, w], dtype=torch.float32).normal_() + else: + v = torch.tensor(value)[:, None, None] + + i = torch.randint(0, img_h - h + 1, size=(1,)).item() + j = torch.randint(0, img_w - w + 1, size=(1,)).item() + break + else: + i, j, h, w, v = 0, 0, img_h, img_w, image + + return dict(zip("ijhwv", (i, j, h, w, v))) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if torch.rand(1) >= self.p: diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 77a0d470e29..8d11d5d54d4 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -7,7 +7,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, kernels as K from torchvision.prototype.utils._internal import apply_recursively -from torchvision.transforms import AutoAugment as _AutoAugment, functional as _F +from torchvision.transforms import functional as _F from .utils import Query @@ -157,7 +157,6 @@ class AugmentationMeta: class AutoAugment(_AutoAugmentBase): - _LEGACY_CLS = _AutoAugment _AUGMENTATION_SPACE = ( AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), @@ -195,7 +194,97 @@ class AutoAugment(_AutoAugmentBase): def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: super().__init__(**kwargs) self.policy = policy - self._policies = self._LEGACY_CLS._get_policies(None, policy) # type: ignore[arg-type] + self._policies = self._get_policies(policy) + + def _get_policies( + self, policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), 
("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: + raise ValueError(f"The provided policy {policy} is not 
recognized.") def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: policy = self._policies[int(torch.randint(len(self._policies), ()))] diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 05145bf69f3..c8437ab1ec3 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -1,11 +1,13 @@ -from typing import Any, Dict, List, Union, Sequence, Tuple +import math +import warnings +from typing import Any, Dict, List, Union, Sequence, Tuple, cast import PIL.Image import torch -from torchvision import transforms as _transforms from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, kernels as K from torchvision.transforms import functional as _F +from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from .utils import Query @@ -72,8 +74,6 @@ def extra_repr(self) -> str: class RandomResizedCrop(Transform): - _LEGACY_CLS = _transforms.RandomResizedCrop - def __init__( self, size: Union[int, Sequence[int]], @@ -82,23 +82,68 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - legacy_transform = self._LEGACY_CLS(size=size, scale=scale, ratio=ratio, interpolation=interpolation) - self.size = legacy_transform.size - self.scale = legacy_transform.scale - self.ratio = legacy_transform.ratio - self.interpolation = legacy_transform.interpolation + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + if not isinstance(scale, Sequence): + raise TypeError("Scale should be a sequence") + scale = cast(Tuple[float, float], scale) + if not isinstance(ratio, Sequence): + raise TypeError("Ratio should be a sequence") + ratio = cast(Tuple[float, float], ratio) + if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): + warnings.warn("Scale and ratio should be of kind (min, max)") + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationMode instead of int. " + "Please, use InterpolationMode enum." 
+ ) + interpolation = _interpolation_modes_from_int(interpolation) + + self.size = size + self.scale = scale + self.ratio = ratio + self.interpolation = interpolation def get_params(self, sample: Any) -> Dict[str, Any]: - image = Query(sample).image_for_size_extraction() - top, left, height, width = _transforms.RandomResizedCrop.get_params( - image, scale=list(self.scale), ratio=list(self.ratio) - ) - return dict( - top=top, - left=left, - height=height, - width=width, - ) + image = Query(sample).image() + height, width = image.image_size + area = height * width + + log_ratio = torch.log(torch.tensor(self.ratio)) + for _ in range(10): + target_area = area * torch.empty(1).uniform_(self.scale[0], self.scale[1]).item() + aspect_ratio = torch.exp( + torch.empty(1).uniform_( + log_ratio[0], # type: ignore[arg-type] + log_ratio[1], # type: ignore[arg-type] + ) + ).item() + + w = int(round(math.sqrt(target_area * aspect_ratio))) + h = int(round(math.sqrt(target_area / aspect_ratio))) + + if 0 < w <= width and 0 < h <= height: + i = torch.randint(0, height - h + 1, size=(1,)).item() + j = torch.randint(0, width - w + 1, size=(1,)).item() + break + else: + # Fallback to central crop + in_ratio = float(width) / float(height) + if in_ratio < min(self.ratio): + w = width + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = height + w = int(round(h * max(self.ratio))) + else: # whole image + w = width + h = height + i = (height - h) // 2 + j = (width - w) // 2 + + return dict(top=i, left=j, height=h, width=w) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: From dc6127124886272be81b724d8b0eb64438440925 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 18:30:11 +0100 Subject: [PATCH 05/25] remove explicit param passing --- torchvision/prototype/transforms/_augment.py | 6 +++--- torchvision/prototype/transforms/_auto_augment.py | 4 ++-- torchvision/prototype/transforms/_container.py | 12 ++++++------ torchvision/prototype/transforms/_geometry.py | 2 +- torchvision/prototype/transforms/_transform.py | 13 +++++++++---- 5 files changed, 21 insertions(+), 16 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 129640f07c1..946e5a28f8f 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -40,7 +40,7 @@ def __init__( self.ratio = ratio self.value = value - def get_params(self, sample: Any) -> Dict[str, Any]: + def _get_params(self, sample: Any) -> Dict[str, Any]: image = Query(sample).image() img_c, (img_h, img_w) = image.num_channels, image.image_size @@ -110,7 +110,7 @@ def __init__(self, *, alpha: float) -> None: self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - def get_params(self, sample: Any) -> Dict[str, Any]: + def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) def _supports(self, obj: Any) -> bool: @@ -136,7 +136,7 @@ def __init__(self, *, alpha: float) -> None: self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) - def get_params(self, sample: Any) -> Dict[str, Any]: + def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) H, W = Query(sample).image_size() diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 
8d11d5d54d4..5652bccbe9b 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -104,7 +104,7 @@ def __init__( "Invert": AutoAugmentDispatcher(K.invert_image, _F.invert), } - def get_params(self, sample: Any) -> Dict[str, Any]: + def _get_params(self, sample: Any) -> Dict[str, Any]: image = Query(sample).image_for_size_and_channels_extraction() fill = self.fill @@ -120,7 +120,7 @@ def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self.get_params(sample) + params = params or self._get_params(sample) for transform_id, magnitude in self.get_transforms_meta(Query(sample).image_size()): dispatcher = self._DISPATCHER_MAP[transform_id] diff --git a/torchvision/prototype/transforms/_container.py b/torchvision/prototype/transforms/_container.py index 3a183a7b884..bd20d0c701a 100644 --- a/torchvision/prototype/transforms/_container.py +++ b/torchvision/prototype/transforms/_container.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Dict +from typing import Any import torch @@ -12,7 +12,7 @@ def __init__(self, *transforms: Transform) -> None: for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] for transform in self.transforms: sample = transform(sample) @@ -25,12 +25,12 @@ def __init__(self, transform: Transform, *, p: float = 0.5) -> None: self.transform = transform self.p = p - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] if float(torch.rand(())) < self.p: return sample - return self.transform(sample, params=params) + return self.transform(sample) def extra_repr(self) -> str: return f"p={self.p}" @@ -43,7 +43,7 @@ def __init__(self, *transforms: Transform) -> None: for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: idx = int(torch.randint(len(self.transforms), size=())) transform = self.transforms[idx] return transform(*inputs) @@ -56,7 +56,7 @@ def __init__(self, *transforms: Transform) -> None: for idx, transform in enumerate(transforms): self.add_module(str(idx), transform) - def forward(self, *inputs: Any) -> Any: # type: ignore[override] + def forward(self, *inputs: Any) -> Any: for idx in torch.randperm(len(self.transforms)): transform = self.transforms[idx] inputs = transform(*inputs) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index c8437ab1ec3..91aeb44ef93 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -106,7 +106,7 @@ def __init__( self.ratio = ratio self.interpolation = interpolation - def get_params(self, sample: Any) -> Dict[str, Any]: + def _get_params(self, sample: Any) -> Dict[str, Any]: image = Query(sample).image() height, width = image.image_size area = height * width diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index 3e535c25a0d..c8881a817ef 100644 --- 
a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,20 +1,25 @@ import functools -from typing import Any, Dict, Optional +from typing import Any, Dict from torch import nn from torchvision.prototype.utils._internal import apply_recursively +from torchvision.utils import _log_api_usage_once class Transform(nn.Module): - def get_params(self, sample: Any) -> Dict[str, Any]: + def __init__(self) -> None: + super().__init__() + _log_api_usage_once(self) + + def _get_params(self, sample: Any) -> Dict[str, Any]: return dict() def _transform(self, input: Any, params: Dict[str, Any]) -> Any: raise NotImplementedError - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - return apply_recursively(functools.partial(self._transform, params=params or self.get_params(sample)), sample) + return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) def _extra_repr_from_attrs(self, *names: str) -> str: return ", ".join(f"{name}={getattr(self, name)}" for name in names) From c7c46084e4499b1f8ff26866e41f33609b303b1a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 18:32:12 +0100 Subject: [PATCH 06/25] streamline extra_repr --- torchvision/prototype/transforms/_augment.py | 9 --------- torchvision/prototype/transforms/_geometry.py | 9 --------- .../prototype/transforms/_meta_conversion.py | 9 --------- torchvision/prototype/transforms/_misc.py | 3 --- torchvision/prototype/transforms/_transform.py | 15 +++++++++++++-- 5 files changed, 13 insertions(+), 32 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 946e5a28f8f..e009e17d898 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -100,9 +100,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("p", "scale", "ratio", "value") - class RandomMixup(Transform): def __init__(self, *, alpha: float) -> None: @@ -126,9 +123,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("alpha") - class RandomCutmix(Transform): def __init__(self, *, alpha: float) -> None: @@ -170,6 +164,3 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return features.OneHotLabel.new_like(input, output) else: return input - - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("alpha") diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 91aeb44ef93..9bfb112bb56 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -51,9 +51,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("size", "interpolation") - class CenterCrop(Transform): def __init__(self, output_size: List[int]): @@ -69,9 +66,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("output_size") - class RandomResizedCrop(Transform): def __init__( @@ -151,6 +145,3 @@ def _transform(self, input: Any, params: 
Dict[str, Any]) -> Any: return features.Image.new_like(input, output) else: return input - - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("size", "scale", "ratio", "interpolation") diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py index 9ed830cdadd..41027af7b34 100644 --- a/torchvision/prototype/transforms/_meta_conversion.py +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -20,9 +20,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("format") - class ConvertImageDtype(Transform): def __init__(self, dtype: torch.dtype = torch.float32) -> None: @@ -36,9 +33,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("dtype") - class ConvertColorSpace(Transform): def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: @@ -53,6 +47,3 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: return features.Image.new_like(input, output, color_space=self.color_space) else: return input - - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("color_space") diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index d9e9d963c3f..e519c376a8a 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -49,9 +49,6 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: output = K.normalize_image(input, **params) return features.Image.new_like(input, output) - def extra_repr(self) -> str: - return self._extra_repr_from_attrs("mean", "std") - class ToDtype(Lambda): def __init__(self, dtype: torch.dtype, *types: Type) -> None: diff --git a/torchvision/prototype/transforms/_transform.py b/torchvision/prototype/transforms/_transform.py index c8881a817ef..923b90c6777 100644 --- a/torchvision/prototype/transforms/_transform.py +++ b/torchvision/prototype/transforms/_transform.py @@ -1,3 +1,4 @@ +import enum import functools from typing import Any, Dict @@ -21,5 +22,15 @@ def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) - def _extra_repr_from_attrs(self, *names: str) -> str: - return ", ".join(f"{name}={getattr(self, name)}" for name in names) + def extra_repr(self) -> str: + extra = [] + for name, value in self.__dict__.items(): + if name.startswith("_") or name == "training": + continue + + if not isinstance(value, (bool, int, float, str, tuple, list, enum.Enum)): + continue + + extra.append(f"{name}={value}") + + return ", ".join(extra) From 13d49cb08a75dcb8378346f3c9067a05504047f8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 18:36:36 +0100 Subject: [PATCH 07/25] remove obsolete ._supports() method --- torchvision/prototype/transforms/_augment.py | 6 ------ torchvision/prototype/transforms/_misc.py | 3 --- torchvision/prototype/transforms/_type_conversion.py | 3 --- 3 files changed, 12 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index e009e17d898..cb5adc5039d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -110,9 +110,6 @@ def __init__(self, *, alpha: float) -> 
None: def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) - def _supports(self, obj: Any) -> bool: - return type(obj) in {features.Image, features.OneHotLabel} - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: output = K.mixup_image(input, **params) @@ -152,9 +149,6 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _supports(self, obj: Any) -> bool: - return type(obj) in {features.Image, features.OneHotLabel} - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: output = K.cutmix_image(input, box=params["box"]) diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index e519c376a8a..aa95e29b69e 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -18,9 +18,6 @@ def __init__(self, fn: Callable[[Any], Any], *types: Type): self.fn = fn self.types = types - def _supports(self, obj: Any) -> bool: - return type(obj) in self.types - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) in self.types: return self.fn(input) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 8efd356768a..26120d660d6 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -5,9 +5,6 @@ class DecodeImage(Transform): - def _supports(self, obj: Any) -> bool: - return isinstance(obj, features.EncodedImage) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.EncodedImage: output = K.decode_image_with_pil(input) From 4771e2518a436dd48d492b54881e508154db3537 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 15 Feb 2022 18:39:12 +0100 Subject: [PATCH 08/25] cleanup --- torchvision/prototype/transforms/kernels/_geometry.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py index 72afc2e62a3..9880ca6a685 100644 --- a/torchvision/prototype/transforms/kernels/_geometry.py +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -1,4 +1,4 @@ -from typing import Tuple, List, Optional, TypeVar +from typing import Tuple, List, Optional import torch from torchvision.prototype import features @@ -7,9 +7,6 @@ from ._meta_conversion import convert_bounding_box_format -T = TypeVar("T", bound=features._Feature) - - horizontal_flip_image = _F.hflip From c393a430999d1772539b816e3cc40842fc757c6f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 14:18:30 +0100 Subject: [PATCH 09/25] remove Query --- torchvision/prototype/transforms/_augment.py | 10 ++- .../prototype/transforms/_auto_augment.py | 13 ++-- torchvision/prototype/transforms/_geometry.py | 6 +- torchvision/prototype/transforms/_utils.py | 24 +++++- torchvision/prototype/transforms/utils.py | 76 ------------------- 5 files changed, 40 insertions(+), 89 deletions(-) delete mode 100644 torchvision/prototype/transforms/utils.py diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index cb5adc5039d..a8c1c647121 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -8,7 +8,7 @@ from torchvision.prototype.transforms 
import Transform, kernels as K from torchvision.transforms import functional as _F -from .utils import Query +from ._utils import query_image, get_image_size, get_image_num_channels class RandomErasing(Transform): @@ -41,8 +41,9 @@ def __init__( self.value = value def _get_params(self, sample: Any) -> Dict[str, Any]: - image = Query(sample).image() - img_c, (img_h, img_w) = image.num_channels, image.image_size + image = query_image(sample) + img_c = get_image_num_channels(image) + img_h, img_w = get_image_size(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -130,7 +131,8 @@ def __init__(self, *, alpha: float) -> None: def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) - H, W = Query(sample).image_size() + image = query_image(sample) + H, W = get_image_size(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 5652bccbe9b..a7dbcee6230 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.utils._internal import apply_recursively from torchvision.transforms import functional as _F -from .utils import Query +from ._utils import query_image, get_image_size, get_image_num_channels @dataclasses.dataclass @@ -105,11 +105,11 @@ def __init__( } def _get_params(self, sample: Any) -> Dict[str, Any]: - image = Query(sample).image_for_size_and_channels_extraction() - fill = self.fill if isinstance(fill, (int, float)): - fill = [float(fill)] * image.num_channels + image = query_image(sample) + num_channels = get_image_num_channels(image) + fill = [float(fill)] * num_channels elif fill is not None: fill = [float(f) for f in fill] @@ -122,7 +122,10 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] params = params or self._get_params(sample) - for transform_id, magnitude in self.get_transforms_meta(Query(sample).image_size()): + image = query_image(sample) + image_size = get_image_size(image) + + for transform_id, magnitude in self.get_transforms_meta(image_size): dispatcher = self._DISPATCHER_MAP[transform_id] def transform(input: Any) -> Any: diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 9bfb112bb56..046504982ba 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -9,7 +9,7 @@ from torchvision.transforms import functional as _F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from .utils import Query +from ._utils import query_image, get_image_size class HorizontalFlip(Transform): @@ -101,8 +101,8 @@ def __init__( self.interpolation = interpolation def _get_params(self, sample: Any) -> Dict[str, Any]: - image = Query(sample).image() - height, width = image.image_size + image = query_image(sample) + height, width = get_image_size(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 7f29d817499..4e900d38caf 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Union, Optional +from typing import Any, Optional, Tuple, cast, Union import PIL.Image 
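# Editorial sketch (not part of the patch): how the helpers in this _utils.py hunk are
# assumed to be used by the transforms changed above. query_image walks an arbitrarily
# nested sample and returns the first image-like entry; get_image_size and
# get_image_num_channels (added just below) then work uniformly on PIL images, plain
# tensors and features.Image. The sample dict is made up for illustration.
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import query_image, get_image_size, get_image_num_channels

sample = {"image": features.Image(torch.rand(3, 16, 24)), "label": 7}
image = query_image(sample)
assert get_image_size(image) == (16, 24)       # (height, width)
assert get_image_num_channels(image) == 3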
import torch @@ -17,3 +17,25 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima return next(query_recursively(fn, sample)) except StopIteration: raise TypeError("No image was found in the sample") + + +def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: + if type(image) is torch.Tensor: + return cast(Tuple[int, int], image.shape[-2:]) + elif isinstance(image, PIL.Image.Image): + return image.height, image.width + elif type(image) in {features.Image, features.BoundingBox}: + return cast(Union[features.Image, features.BoundingBox], image).image_size + else: + raise TypeError(f"unable to get image size from object of type {type(image).__name__}") + + +def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: + if type(image) is torch.Tensor: + return image.shape[-3] + elif isinstance(image, PIL.Image.Image): + return len(image.getbands()) + elif type(image) is features.Image: + return image.num_channels + else: + raise TypeError(f"unable to get image size from object of type {type(image).__name__}") diff --git a/torchvision/prototype/transforms/utils.py b/torchvision/prototype/transforms/utils.py deleted file mode 100644 index 796f4496882..00000000000 --- a/torchvision/prototype/transforms/utils.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Callable, Tuple, TypeVar, Optional, Any, cast - -import PIL.Image -import torch -from torchvision.prototype import features -from torchvision.prototype.utils._internal import query_recursively - -T = TypeVar("T") - - -class Query: - def __init__(self, sample: Any) -> None: - self.sample = sample - - def query(self, fn: Callable[[Any], Optional[T]]) -> T: - try: - return next(query_recursively(fn, self.sample)) - except StopIteration: - raise RuntimeError from None - - def image(self) -> features.Image: - def fn(sample: Any) -> Optional[features.Image]: - if isinstance(sample, features.Image): - return sample - else: - return None - - return self.query(fn) - - def image_size(self) -> Tuple[int, int]: - def fn(sample: Any) -> Optional[Tuple[int, int]]: - if isinstance(sample, (features.Image, features.BoundingBox)): - return sample.image_size - elif isinstance(sample, torch.Tensor): - return cast(Tuple[int, int], sample.shape[-2:]) - elif isinstance(sample, PIL.Image.Image): - return sample.height, sample.width - else: - return None - - return self.query(fn) - - def image_for_size_extraction(self) -> features.Image: - def fn(sample: Any) -> Optional[features.Image]: - if isinstance(sample, features.Image): - return sample - - if isinstance(sample, features.BoundingBox): - height, width = sample.image_size - elif isinstance(sample, torch.Tensor): - height, width = sample.shape[-2:] - elif isinstance(sample, PIL.Image.Image): - height, width = sample.height, sample.width - else: - return None - - return features.Image(torch.empty(0, height, width)) - - return self.query(fn) - - def image_for_size_and_channels_extraction(self) -> features.Image: - def fn(sample: Any) -> Optional[features.Image]: - if isinstance(sample, features.Image): - return sample - - if isinstance(sample, torch.Tensor): - num_channels, height, width = sample.shape[-3:] - elif isinstance(sample, PIL.Image.Image): - height, width = sample.height, sample.width - num_channels = len(sample.num_bands()) - else: - return None - - return features.Image(torch.empty(0, num_channels, height, width)) - - return self.query(fn) From e7502edeebd0a6665b3f4b4ad7806b106d647503 Mon 
Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 14:27:10 +0100 Subject: [PATCH 10/25] cleanup --- test/test_prototype_transforms.py | 29 -------------------- torchvision/prototype/transforms/_augment.py | 9 ++++-- torchvision/prototype/transforms/_misc.py | 2 +- 3 files changed, 7 insertions(+), 33 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 9aa0688e7a0..42a679a613e 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -141,35 +141,6 @@ def test_auto_augment(self, transform, input): def test_normalize(self, transform, input): transform(input) - @parametrize( - [ - ( - transforms.ConvertColorSpace("grayscale"), - itertools.chain( - make_images(), - make_vanilla_tensor_images(color_spaces=["rgb"]), - make_pil_images(color_spaces=["rgb"]), - ), - ) - ] - ) - def test_convert_bounding_color_space(self, transform, input): - transform(input) - - @parametrize( - [ - ( - transforms.ConvertBoundingBoxFormat("xyxy", old_format="xywh"), - itertools.chain( - make_bounding_boxes(), - make_vanilla_tensor_bounding_boxes(formats=["xywh"]), - ), - ) - ] - ) - def test_convert_bounding_box_format(self, transform, input): - transform(input) - @parametrize( [ ( diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index a8c1c647121..afccfc58279 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -91,9 +91,6 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(zip("ijhwv", (i, j, h, w, v))) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if torch.rand(1) >= self.p: - return input - if type(input) is torch.Tensor: return _F.erase(input, **params) elif type(input) is features.Image: @@ -101,6 +98,12 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: else: return input + def forward(self, *inputs: Any) -> Any: + if torch.rand(1) >= self.p: + return inputs if len(inputs) > 1 else inputs[0] + + return super().forward(*inputs) + class RandomMixup(Transform): def __init__(self, *, alpha: float) -> None: diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index aa95e29b69e..f8c5e4c41b6 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -43,7 +43,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor: return _F.normalize(input, mean=self.mean, std=self.std) if type(input) is features.Image: - output = K.normalize_image(input, **params) + output = K.normalize_image(input, mean=self.mean, std=self.std) return features.Image.new_like(input, output) From fd752a6cf46fa20fa67cd5a5b0106828c368a2b2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 16 Feb 2022 14:59:05 +0100 Subject: [PATCH 11/25] fix tests --- test/test_prototype_transforms.py | 32 ++++++++++------------ torchvision/prototype/transforms/_utils.py | 4 +-- 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 42a679a613e..3156e584a8d 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -1,6 +1,5 @@ import itertools -import PIL.Image import pytest import torch from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels @@ -25,15 +24,6 @@ def make_vanilla_tensor_bounding_boxes(*args, 
**kwargs): yield bounding_box.data -INPUT_CREATIONS_FNS = { - features.Image: make_images, - features.BoundingBox: make_bounding_boxes, - features.OneHotLabel: make_one_hot_labels, - torch.Tensor: make_vanilla_tensor_images, - PIL.Image.Image: make_pil_images, -} - - def parametrize(transforms_with_inputs): return pytest.mark.parametrize( ("transform", "input"), @@ -52,15 +42,21 @@ def parametrize(transforms_with_inputs): def parametrize_from_transforms(*transforms): transforms_with_inputs = [] for transform in transforms: - dispatcher = transform._DISPATCHER - if dispatcher is None: - continue - - for type_ in dispatcher._kernels: + for creation_fn in [ + make_images, + make_bounding_boxes, + make_one_hot_labels, + make_vanilla_tensor_images, + make_pil_images, + ]: + inputs = list(creation_fn()) try: - inputs = INPUT_CREATIONS_FNS[type_]() - except KeyError: + output = transform(inputs[0]) + except Exception: continue + else: + if output is inputs[0]: + continue transforms_with_inputs.append((transform, inputs)) @@ -69,7 +65,7 @@ def parametrize_from_transforms(*transforms): class TestSmoke: @parametrize_from_transforms( - transforms.RandomErasing(), + transforms.RandomErasing(p=1.0), transforms.HorizontalFlip(), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 4e900d38caf..31f2575c3d3 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -24,8 +24,8 @@ def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) return cast(Tuple[int, int], image.shape[-2:]) elif isinstance(image, PIL.Image.Image): return image.height, image.width - elif type(image) in {features.Image, features.BoundingBox}: - return cast(Union[features.Image, features.BoundingBox], image).image_size + elif type(image) is features.Image: + return image.image_size else: raise TypeError(f"unable to get image size from object of type {type(image).__name__}") From 283c4749133b3c726a36c994c161e4f32de8e7cc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Feb 2022 09:09:41 +0100 Subject: [PATCH 12/25] kernels -> functional --- test/test_prototype_transforms.py | 2 +- ...> test_prototype_transforms_functional.py} | 38 +++++++++---------- .../prototype/features/_bounding_box.py | 2 +- torchvision/prototype/features/_encoded.py | 2 +- torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_augment.py | 12 +++--- .../prototype/transforms/_auto_augment.py | 30 +++++++-------- torchvision/prototype/transforms/_geometry.py | 16 ++++---- .../prototype/transforms/_meta_conversion.py | 6 +-- torchvision/prototype/transforms/_misc.py | 4 +- .../prototype/transforms/_type_conversion.py | 6 +-- .../{kernels => functional}/__init__.py | 0 .../{kernels => functional}/_augment.py | 0 .../{kernels => functional}/_color.py | 0 .../{kernels => functional}/_geometry.py | 0 .../_meta_conversion.py | 0 .../{kernels => functional}/_misc.py | 0 .../_type_conversion.py | 0 18 files changed, 60 insertions(+), 60 deletions(-) rename test/{test_prototype_transforms_kernels.py => test_prototype_transforms_functional.py} (85%) rename torchvision/prototype/transforms/{kernels => functional}/__init__.py (100%) rename torchvision/prototype/transforms/{kernels => functional}/_augment.py (100%) rename torchvision/prototype/transforms/{kernels => functional}/_color.py (100%) rename torchvision/prototype/transforms/{kernels => 
functional}/_geometry.py (100%) rename torchvision/prototype/transforms/{kernels => functional}/_meta_conversion.py (100%) rename torchvision/prototype/transforms/{kernels => functional}/_misc.py (100%) rename torchvision/prototype/transforms/{kernels => functional}/_type_conversion.py (100%) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 3156e584a8d..190867523eb 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -2,7 +2,7 @@ import pytest import torch -from test_prototype_transforms_kernels import make_images, make_bounding_boxes, make_one_hot_labels +from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels from torchvision.prototype import transforms, features from torchvision.transforms.functional import to_pil_image diff --git a/test/test_prototype_transforms_kernels.py b/test/test_prototype_transforms_functional.py similarity index 85% rename from test/test_prototype_transforms_kernels.py rename to test/test_prototype_transforms_functional.py index fb436a6a830..e04add0bd60 100644 --- a/test/test_prototype_transforms_kernels.py +++ b/test/test_prototype_transforms_functional.py @@ -3,7 +3,7 @@ import pytest import torch.testing -import torchvision.prototype.transforms.kernels as K +import torchvision.prototype.transforms.functional as F from torch import jit from torch.nn.functional import one_hot from torchvision.prototype import features @@ -134,10 +134,10 @@ def __init__(self, *args, **kwargs): self.kwargs = kwargs -class KernelInfo: +class FunctionalInfo: def __init__(self, name, *, sample_inputs_fn): self.name = name - self.kernel = getattr(K, name) + self.functional = getattr(F, name) self._sample_inputs_fn = sample_inputs_fn def sample_inputs(self): @@ -146,16 +146,16 @@ def sample_inputs(self): def __call__(self, *args, **kwargs): if len(args) == 1 and not kwargs and isinstance(args[0], SampleInput): sample_input = args[0] - return self.kernel(*sample_input.args, **sample_input.kwargs) + return self.functional(*sample_input.args, **sample_input.kwargs) - return self.kernel(*args, **kwargs) + return self.functional(*args, **kwargs) -KERNEL_INFOS = [] +FUNCTIONAL_INFOS = [] def register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): - KERNEL_INFOS.append(KernelInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) + FUNCTIONAL_INFOS.append(FunctionalInfo(sample_inputs_fn.__name__, sample_inputs_fn=sample_inputs_fn)) return sample_inputs_fn @@ -176,8 +176,8 @@ def resize_image(): for image, interpolation in itertools.product( make_images(), [ - K.InterpolationMode.BILINEAR, - K.InterpolationMode.NEAREST, + F.InterpolationMode.BILINEAR, + F.InterpolationMode.NEAREST, ], ): height, width = image.shape[-2:] @@ -200,20 +200,20 @@ def resize_bounding_box(): class TestKernelsCommon: - @pytest.mark.parametrize("kernel_info", KERNEL_INFOS, ids=lambda kernel_info: kernel_info.name) - def test_scriptable(self, kernel_info): - jit.script(kernel_info.kernel) + @pytest.mark.parametrize("functional_info", FUNCTIONAL_INFOS, ids=lambda functional_info: functional_info.name) + def test_scriptable(self, functional_info): + jit.script(functional_info.functional) @pytest.mark.parametrize( - ("kernel_info", "sample_input"), + ("functional_info", "sample_input"), [ - pytest.param(kernel_info, sample_input, id=f"{kernel_info.name}-{idx}") - for kernel_info in KERNEL_INFOS - for idx, sample_input in enumerate(kernel_info.sample_inputs()) + 
pytest.param(functional_info, sample_input, id=f"{functional_info.name}-{idx}") + for functional_info in FUNCTIONAL_INFOS + for idx, sample_input in enumerate(functional_info.sample_inputs()) ], ) - def test_eager_vs_scripted(self, kernel_info, sample_input): - eager = kernel_info(sample_input) - scripted = jit.script(kernel_info.kernel)(*sample_input.args, **sample_input.kwargs) + def test_eager_vs_scripted(self, functional_info, sample_input): + eager = functional_info(sample_input) + scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs) torch.testing.assert_close(eager, scripted) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index fbe19549dca..5b60d7ee55c 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -41,7 +41,7 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: # promote this out of the prototype state # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.kernels import convert_bounding_box_format + from torchvision.prototype.transforms.functional import convert_bounding_box_format if isinstance(format, str): format = BoundingBoxFormat[format] diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index ab6b821d673..276aeec2529 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -43,7 +43,7 @@ def decode(self) -> Image: # promote this out of the prototype state # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.kernels import decode_image_with_pil + from torchvision.prototype.transforms.functional import decode_image_with_pil return Image(decode_image_with_pil(self)) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 3efd7a4130f..3dd3158f3fc 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,6 +1,6 @@ from torchvision.transforms import InterpolationMode, AutoAugmentPolicy # usort: skip -from . import kernels # usort: skip +from . 
import functional # usort: skip from ._transform import Transform # usort: skip diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 8d563e4c9a1..31c8d180f5d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -6,7 +6,7 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms import functional as _F from ._utils import query_image, get_image_size, get_image_num_channels @@ -95,7 +95,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor: return _F.erase(input, **params) elif type(input) is features.Image: - return features.Image.new_like(input, K.erase_image(input, **params)) + return features.Image.new_like(input, F.erase_image(input, **params)) elif type(input) in {features.BoundingBox, features.SegmentationMask} or isinstance(input, PIL.Image.Image): raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") else: @@ -119,10 +119,10 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: - output = K.mixup_image(input, **params) + output = F.mixup_image(input, **params) return features.Image.new_like(input, output) elif type(input) is features.OneHotLabel: - output = K.mixup_one_hot_label(input, **params) + output = F.mixup_one_hot_label(input, **params) return features.OneHotLabel.new_like(input, output) elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") @@ -161,10 +161,10 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: - output = K.cutmix_image(input, box=params["box"]) + output = F.cutmix_image(input, box=params["box"]) return features.Image.new_like(input, output) elif type(input) is features.OneHotLabel: - output = K.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) + output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) return features.OneHotLabel.new_like(input, output) elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index a7dbcee6230..da4130cb0ff 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -5,7 +5,7 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, kernels as K +from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively from torchvision.transforms import functional as _F @@ -43,65 +43,65 @@ def __init__( _DISPATCHER_MAP = { "ShearX": AutoAugmentDispatcher( - K.affine_image, + F.affine_image, _F.affine, magnitude_fn=lambda magnitude: dict(shear=[math.degrees(magnitude), 0]), extra_kwargs=dict(angle=0.0, translate=[0, 0], 
scale=1.0), takes_interpolation_kwargs=True, ), "ShearY": AutoAugmentDispatcher( - K.affine_image, + F.affine_image, _F.affine, magnitude_fn=lambda magnitude: dict(shear=[0, math.degrees(magnitude)]), extra_kwargs=dict(angle=0.0, translate=[0, 0], scale=1.0), takes_interpolation_kwargs=True, ), "TranslateX": AutoAugmentDispatcher( - K.affine_image, + F.affine_image, _F.affine, magnitude_fn=lambda magnitude: dict(translate=[int(magnitude), 0]), extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), takes_interpolation_kwargs=True, ), "TranslateY": AutoAugmentDispatcher( - K.affine_image, + F.affine_image, _F.affine, magnitude_fn=lambda magnitude: dict(translate=[0, int(magnitude)]), extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), takes_interpolation_kwargs=True, ), "Rotate": AutoAugmentDispatcher( - K.rotate_image, _F.rotate, magnitude_fn=lambda magnitude: dict(angle=magnitude) + F.rotate_image, _F.rotate, magnitude_fn=lambda magnitude: dict(angle=magnitude) ), "Brightness": AutoAugmentDispatcher( - K.adjust_brightness_image, + F.adjust_brightness_image, _F.adjust_brightness, magnitude_fn=lambda magnitude: dict(brightness_factor=1.0 + magnitude), ), "Color": AutoAugmentDispatcher( - K.adjust_saturation_image, + F.adjust_saturation_image, _F.adjust_saturation, magnitude_fn=lambda magnitude: dict(saturation_factor=1.0 + magnitude), ), "Contrast": AutoAugmentDispatcher( - K.adjust_contrast_image, + F.adjust_contrast_image, _F.adjust_contrast, magnitude_fn=lambda magnitude: dict(contrast_factor=1.0 + magnitude), ), "Sharpness": AutoAugmentDispatcher( - K.adjust_sharpness_image, + F.adjust_sharpness_image, _F.adjust_sharpness, magnitude_fn=lambda magnitude: dict(sharpness_factor=1.0 + magnitude), ), "Posterize": AutoAugmentDispatcher( - K.posterize_image, _F.posterize, magnitude_fn=lambda magnitude: dict(bits=int(magnitude)) + F.posterize_image, _F.posterize, magnitude_fn=lambda magnitude: dict(bits=int(magnitude)) ), "Solarize": AutoAugmentDispatcher( - K.solarize_image, _F.solarize, magnitude_fn=lambda magnitude: dict(threshold=magnitude) + F.solarize_image, _F.solarize, magnitude_fn=lambda magnitude: dict(threshold=magnitude) ), - "AutoContrast": AutoAugmentDispatcher(K.autocontrast_image, _F.autocontrast), - "Equalize": AutoAugmentDispatcher(K.equalize_image, _F.equalize), - "Invert": AutoAugmentDispatcher(K.invert_image, _F.invert), + "AutoContrast": AutoAugmentDispatcher(F.autocontrast_image, _F.autocontrast), + "Equalize": AutoAugmentDispatcher(F.equalize_image, _F.equalize), + "Invert": AutoAugmentDispatcher(F.invert_image, _F.invert), } def _get_params(self, sample: Any) -> Dict[str, Any]: diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 169b08e88e8..af4e53ce6da 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -5,7 +5,7 @@ import PIL.Image import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, InterpolationMode, kernels as K +from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F from torchvision.transforms import functional as _F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int @@ -17,10 +17,10 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): return _F.hflip(input) elif type(input) is features.Image: - output = 
K.horizontal_flip_image(input) + output = F.horizontal_flip_image(input) return features.Image.new_like(input, output) elif type(input) is features.BoundingBox: - output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) return features.BoundingBox.new_like(input, output) else: return input @@ -40,13 +40,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): return _F.resize(input, size=self.size, interpolation=self.interpolation) elif type(input) is features.Image: - output = K.resize_image(input, size=self.size, interpolation=self.interpolation) + output = F.resize_image(input, size=self.size, interpolation=self.interpolation) return features.Image.new_like(input, output) elif type(input) is features.SegmentationMask: - output = K.resize_segmentation_mask(input, size=self.size) + output = F.resize_segmentation_mask(input, size=self.size) return features.SegmentationMask.new_like(input, output) elif type(input) is features.BoundingBox: - output = K.resize_bounding_box(input, size=self.size, image_size=input.image_size) + output = F.resize_bounding_box(input, size=self.size, image_size=input.image_size) return features.BoundingBox.new_like(input, output, image_size=self.size) else: return input @@ -61,7 +61,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): return _F.center_crop(input, output_size=self.output_size) elif type(input) is features.Image: - output = K.center_crop_image(input, **params) + output = F.center_crop_image(input, **params) return features.Image.new_like(input, output) elif type(input) in {features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") @@ -143,7 +143,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: - output = K.resized_crop_image(input, size=self.size, interpolation=self.interpolation, **params) + output = F.resized_crop_image(input, size=self.size, interpolation=self.interpolation, **params) return features.Image.new_like(input, output) elif type(input) in {features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py index 41027af7b34..e33df450ce3 100644 --- a/torchvision/prototype/transforms/_meta_conversion.py +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -2,7 +2,7 @@ import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms import functional as _F @@ -15,7 +15,7 @@ def __init__(self, format: Union[str, features.BoundingBoxFormat]) -> None: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.BoundingBox: - output = K.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"]) + output = F.convert_bounding_box_format(input, old_format=input.format, new_format=params["format"]) return features.BoundingBox.new_like(input, output, 
format=params["format"]) else: return input @@ -43,7 +43,7 @@ def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: - output = K.convert_color_space(input, old_color_space=input.color_space, new_color_space=self.color_space) + output = F.convert_color_space(input, old_color_space=input.color_space, new_color_space=self.color_space) return features.Image.new_like(input, output, color_space=self.color_space) else: return input diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index f8c5e4c41b6..b1b57d4d63a 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -3,7 +3,7 @@ import torch from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms import functional as _F @@ -43,7 +43,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is torch.Tensor: return _F.normalize(input, mean=self.mean, std=self.std) if type(input) is features.Image: - output = K.normalize_image(input, mean=self.mean, std=self.std) + output = F.normalize_image(input, mean=self.mean, std=self.std) return features.Image.new_like(input, output) diff --git a/torchvision/prototype/transforms/_type_conversion.py b/torchvision/prototype/transforms/_type_conversion.py index 26120d660d6..fa49c35265e 100644 --- a/torchvision/prototype/transforms/_type_conversion.py +++ b/torchvision/prototype/transforms/_type_conversion.py @@ -1,13 +1,13 @@ from typing import Any, Dict from torchvision.prototype import features -from torchvision.prototype.transforms import Transform, kernels as K +from torchvision.prototype.transforms import Transform, functional as F class DecodeImage(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.EncodedImage: - output = K.decode_image_with_pil(input) + output = F.decode_image_with_pil(input) return features.Image(output) else: return input @@ -23,7 +23,7 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: num_categories = self.num_categories if num_categories == -1 and input.categories is not None: num_categories = len(input.categories) - output = K.label_to_one_hot(input, num_categories=num_categories) + output = F.label_to_one_hot(input, num_categories=num_categories) return features.OneHotLabel(output, categories=input.categories) else: return input diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/functional/__init__.py similarity index 100% rename from torchvision/prototype/transforms/kernels/__init__.py rename to torchvision/prototype/transforms/functional/__init__.py diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/functional/_augment.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_augment.py rename to torchvision/prototype/transforms/functional/_augment.py diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/functional/_color.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_color.py rename to torchvision/prototype/transforms/functional/_color.py diff --git a/torchvision/prototype/transforms/kernels/_geometry.py 
b/torchvision/prototype/transforms/functional/_geometry.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_geometry.py rename to torchvision/prototype/transforms/functional/_geometry.py diff --git a/torchvision/prototype/transforms/kernels/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_meta_conversion.py rename to torchvision/prototype/transforms/functional/_meta_conversion.py diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/functional/_misc.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_misc.py rename to torchvision/prototype/transforms/functional/_misc.py diff --git a/torchvision/prototype/transforms/kernels/_type_conversion.py b/torchvision/prototype/transforms/functional/_type_conversion.py similarity index 100% rename from torchvision/prototype/transforms/kernels/_type_conversion.py rename to torchvision/prototype/transforms/functional/_type_conversion.py From b3c0452767433c5500795457fe9d4ae2073a83b0 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Feb 2022 09:13:33 +0100 Subject: [PATCH 13/25] move image size and num channels extraction to functional --- torchvision/prototype/transforms/_augment.py | 8 +++--- .../prototype/transforms/_auto_augment.py | 6 ++--- torchvision/prototype/transforms/_geometry.py | 4 +-- torchvision/prototype/transforms/_utils.py | 24 +---------------- .../transforms/functional/__init__.py | 1 + .../prototype/transforms/functional/_utils.py | 27 +++++++++++++++++++ 6 files changed, 38 insertions(+), 32 deletions(-) create mode 100644 torchvision/prototype/transforms/functional/_utils.py diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 31c8d180f5d..cd6df9aca95 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms import functional as _F -from ._utils import query_image, get_image_size, get_image_num_channels +from ._utils import query_image class RandomErasing(Transform): @@ -43,8 +43,8 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - img_c = get_image_num_channels(image) - img_h, img_w = get_image_size(image) + img_c = F.get_image_num_channels(image) + img_h, img_w = F.get_image_size(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -140,7 +140,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - H, W = get_image_size(image) + H, W = F.get_image_size(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index da4130cb0ff..6a41bb11984 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.utils._internal import apply_recursively from torchvision.transforms import functional as _F -from ._utils import query_image, get_image_size, get_image_num_channels +from ._utils import query_image @dataclasses.dataclass @@ -108,7 +108,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: fill = self.fill if isinstance(fill, (int, float)): image 
= query_image(sample) - num_channels = get_image_num_channels(image) + num_channels = F.get_image_num_channels(image) fill = [float(fill)] * num_channels elif fill is not None: fill = [float(f) for f in fill] @@ -123,7 +123,7 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: params = params or self._get_params(sample) image = query_image(sample) - image_size = get_image_size(image) + image_size = F.get_image_size(image) for transform_id, magnitude in self.get_transforms_meta(image_size): dispatcher = self._DISPATCHER_MAP[transform_id] diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index af4e53ce6da..e78bab22251 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -9,7 +9,7 @@ from torchvision.transforms import functional as _F from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int -from ._utils import query_image, get_image_size +from ._utils import query_image class HorizontalFlip(Transform): @@ -104,7 +104,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - height, width = get_image_size(image) + height, width = F.get_image_size(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py index 31f2575c3d3..24d794a2cb4 100644 --- a/torchvision/prototype/transforms/_utils.py +++ b/torchvision/prototype/transforms/_utils.py @@ -1,4 +1,4 @@ -from typing import Any, Optional, Tuple, cast, Union +from typing import Any, Optional, Union import PIL.Image import torch @@ -17,25 +17,3 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Ima return next(query_recursively(fn, sample)) except StopIteration: raise TypeError("No image was found in the sample") - - -def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: - if type(image) is torch.Tensor: - return cast(Tuple[int, int], image.shape[-2:]) - elif isinstance(image, PIL.Image.Image): - return image.height, image.width - elif type(image) is features.Image: - return image.image_size - else: - raise TypeError(f"unable to get image size from object of type {type(image).__name__}") - - -def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: - if type(image) is torch.Tensor: - return image.shape[-3] - elif isinstance(image, PIL.Image.Image): - return len(image.getbands()) - elif type(image) is features.Image: - return image.num_channels - else: - raise TypeError(f"unable to get image size from object of type {type(image).__name__}") diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 1cac91d29c1..fb29703118a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,4 +1,5 @@ from torchvision.transforms import InterpolationMode # usort: skip +from ._utils import get_image_size, get_image_num_channels # usort: skip from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip from ._augment import ( diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py new file mode 100644 index 00000000000..117808a9e73 --- /dev/null +++ 
b/torchvision/prototype/transforms/functional/_utils.py @@ -0,0 +1,27 @@ +from typing import Tuple, cast, Union + +import PIL.Image +import torch +from torchvision.prototype import features + + +def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: + if type(image) is torch.Tensor: + return cast(Tuple[int, int], image.shape[-2:]) + elif isinstance(image, PIL.Image.Image): + return image.height, image.width + elif type(image) is features.Image: + return image.image_size + else: + raise TypeError(f"unable to get image size from object of type {type(image).__name__}") + + +def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: + if type(image) is torch.Tensor: + return image.shape[-3] + elif isinstance(image, PIL.Image.Image): + return len(image.getbands()) + elif type(image) is features.Image: + return image.num_channels + else: + raise TypeError(f"unable to get num channels from object of type {type(image).__name__}") From c129dea89b126d007ad3db62cd7920ec4abcf42f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Feb 2022 09:16:30 +0100 Subject: [PATCH 14/25] extend legacy function to extract image size and num channels --- .../prototype/transforms/functional/_utils.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 117808a9e73..5f54247a169 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -3,13 +3,12 @@ import PIL.Image import torch from torchvision.prototype import features +from torchvision.transforms import functional as _F def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: - if type(image) is torch.Tensor: - return cast(Tuple[int, int], image.shape[-2:]) - elif isinstance(image, PIL.Image.Image): - return image.height, image.width + if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): + return cast(Tuple[int, int], tuple(_F.get_image_size(image))) elif type(image) is features.Image: return image.image_size else: @@ -17,10 +16,8 @@ def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: - if type(image) is torch.Tensor: - return image.shape[-3] - elif isinstance(image, PIL.Image.Image): - return len(image.getbands()) + if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): + return _F.get_image_num_channels(image) elif type(image) is features.Image: return image.num_channels else: From 9b18c284ab525d7924414235fce99355223ac6a3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Feb 2022 08:40:47 +0100 Subject: [PATCH 15/25] implement dispatching for auto augment --- .../prototype/transforms/_auto_augment.py | 389 ++++++++---------- 1 file changed, 178 insertions(+), 211 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 6a41bb11984..2036f2e8198 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -1,6 +1,5 @@ -import dataclasses import math -from typing import Any, Dict, Tuple, Optional, Callable, List, cast, Iterator +from typing import Any, Dict, Tuple, Optional, Callable, List, cast, TypeVar import PIL.Image import 
torch @@ -11,26 +10,8 @@ from ._utils import query_image - -@dataclasses.dataclass -class AutoAugmentDispatcher: - kernel: Callable - legacy_kernel: Callable - magnitude_fn: Optional[Callable[[float], Dict[str, Any]]] = None - extra_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) - takes_interpolation_kwargs: bool = False - - def __call__( - self, input: Any, *, magnitude: float, interpolation: InterpolationMode, fill: Optional[List[float]] - ) -> Any: - kwargs = self.extra_kwargs.copy() - if self.magnitude_fn is not None: - kwargs.update(self.magnitude_fn(magnitude)) - if self.takes_interpolation_kwargs: - kwargs.update(dict(interpolation=interpolation, fill=fill)) - - kernel = self.kernel if type(input) is features.Image else self.legacy_kernel - return kernel(input, **kwargs) +K = TypeVar("K") +V = TypeVar("V") class _AutoAugmentBase(Transform): @@ -41,157 +22,136 @@ def __init__( self.interpolation = interpolation self.fill = fill - _DISPATCHER_MAP = { - "ShearX": AutoAugmentDispatcher( - F.affine_image, - _F.affine, - magnitude_fn=lambda magnitude: dict(shear=[math.degrees(magnitude), 0]), - extra_kwargs=dict(angle=0.0, translate=[0, 0], scale=1.0), - takes_interpolation_kwargs=True, - ), - "ShearY": AutoAugmentDispatcher( - F.affine_image, - _F.affine, - magnitude_fn=lambda magnitude: dict(shear=[0, math.degrees(magnitude)]), - extra_kwargs=dict(angle=0.0, translate=[0, 0], scale=1.0), - takes_interpolation_kwargs=True, - ), - "TranslateX": AutoAugmentDispatcher( - F.affine_image, - _F.affine, - magnitude_fn=lambda magnitude: dict(translate=[int(magnitude), 0]), - extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), - takes_interpolation_kwargs=True, - ), - "TranslateY": AutoAugmentDispatcher( - F.affine_image, - _F.affine, - magnitude_fn=lambda magnitude: dict(translate=[0, int(magnitude)]), - extra_kwargs=dict(angle=0.0, scale=1.0, shear=[0.0, 0.0]), - takes_interpolation_kwargs=True, - ), - "Rotate": AutoAugmentDispatcher( - F.rotate_image, _F.rotate, magnitude_fn=lambda magnitude: dict(angle=magnitude) - ), - "Brightness": AutoAugmentDispatcher( - F.adjust_brightness_image, - _F.adjust_brightness, - magnitude_fn=lambda magnitude: dict(brightness_factor=1.0 + magnitude), + _DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = { + "Identity": lambda input, magnitude, interpolation, fill: input, + "ShearX": lambda input, magnitude, interpolation, fill: ( + F.affine_image if type(input) is features.Image else _F.affine + )( + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, ), - "Color": AutoAugmentDispatcher( - F.adjust_saturation_image, - _F.adjust_saturation, - magnitude_fn=lambda magnitude: dict(saturation_factor=1.0 + magnitude), + "ShearY": lambda input, magnitude, interpolation, fill: ( + F.affine_image if type(input) is features.Image else _F.affine + )( + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, ), - "Contrast": AutoAugmentDispatcher( - F.adjust_contrast_image, - _F.adjust_contrast, - magnitude_fn=lambda magnitude: dict(contrast_factor=1.0 + magnitude), + "TranslateX": lambda input, magnitude, interpolation, fill: ( + F.affine_image if type(input) is features.Image else _F.affine + )( + input, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, ), - 
"Sharpness": AutoAugmentDispatcher( - F.adjust_sharpness_image, - _F.adjust_sharpness, - magnitude_fn=lambda magnitude: dict(sharpness_factor=1.0 + magnitude), + "TranslateY": lambda input, magnitude, interpolation, fill: ( + F.affine_image if type(input) is features.Image else _F.affine + )( + input, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, ), - "Posterize": AutoAugmentDispatcher( - F.posterize_image, _F.posterize, magnitude_fn=lambda magnitude: dict(bits=int(magnitude)) - ), - "Solarize": AutoAugmentDispatcher( - F.solarize_image, _F.solarize, magnitude_fn=lambda magnitude: dict(threshold=magnitude) - ), - "AutoContrast": AutoAugmentDispatcher(F.autocontrast_image, _F.autocontrast), - "Equalize": AutoAugmentDispatcher(F.equalize_image, _F.equalize), - "Invert": AutoAugmentDispatcher(F.invert_image, _F.invert), + "Rotate": lambda input, magnitude, interpolation, fill: ( + F.rotate_image if type(input) is features.Image else _F.rotate + )(input, angle=magnitude), + "Brightness": lambda input, magnitude, interpolation, fill: ( + F.adjust_brightness_image if type(input) is features.Image else _F.adjust_brightness + )(input, brightness_factor=1.0 + magnitude), + "Color": lambda input, magnitude, interpolation, fill: ( + F.adjust_saturation_image if type(input) is features.Image else _F.adjust_saturation + )(input, saturation_factor=1.0 + magnitude), + "Contrast": lambda input, magnitude, interpolation, fill: ( + F.adjust_contrast_image if type(input) is features.Image else _F.adjust_contrast + )(input, contrast_factor=1.0 + magnitude), + "Sharpness": lambda input, magnitude, interpolation, fill: ( + F.adjust_sharpness_image if type(input) is features.Image else _F.adjust_sharpness + )(input, sharpness_factor=1.0 + magnitude), + "Posterize": lambda input, magnitude, interpolation, fill: ( + F.posterize_image if type(input) is features.Image else _F.posterize + )(input, bits=int(magnitude)), + "Solarize": lambda input, magnitude, interpolation, fill: ( + F.solarize_image if type(input) is features.Image else _F.solarize + )(input, threshold=magnitude), + "AutoContrast": lambda input, magnitude, interpolation, fill: ( + F.autocontrast_image if type(input) is features.Image else _F.autocontrast + )(input), + "Equalize": lambda input, magnitude, interpolation, fill: ( + F.equalize_image if type(input) is features.Image else _F.equalize + )(input), + "Invert": lambda input, magnitude, interpolation, fill: ( + F.invert_image if type(input) is features.Image else _F.invert + )(input), } def _get_params(self, sample: Any) -> Dict[str, Any]: + image = query_image(sample) + num_channels = F.get_image_num_channels(image) + fill = self.fill if isinstance(fill, (int, float)): - image = query_image(sample) - num_channels = F.get_image_num_channels(image) fill = [float(fill)] * num_channels elif fill is not None: fill = [float(f) for f in fill] - return dict(fill=fill) + return dict(interpolation=self.interpolation, fill=fill) - def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: - raise NotImplementedError + def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: + keys = tuple(dct.keys()) + key = keys[int(torch.randint(len(keys), ()))] + return key, dct[key] - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) - - image = query_image(sample) - image_size = 
F.get_image_size(image) - - for transform_id, magnitude in self.get_transforms_meta(image_size): - dispatcher = self._DISPATCHER_MAP[transform_id] - - def transform(input: Any) -> Any: - if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image): - return dispatcher( # type: ignore[arg-type] - input, - magnitude=magnitude, - interpolation=self.interpolation, - **params, - ) - else: - return input - - sample = apply_recursively(transform, sample) - - return sample - - def _randbool(self, p: float = 0.5) -> bool: - """Randomly returns either ``True`` or ``False``. - - Args: - p: Probability to return ``True``. Defaults to ``0.5``. - """ - return float(torch.rand(())) <= p + def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any: + dispatcher = self._DISPATCHER_MAP[transform_id] + def transform(input: Any) -> Any: + if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image): + return dispatcher(input, magnitude, params["interpolation"], params["fill"]) + elif type(input) in {features.BoundingBox, features.SegmentationMask}: + raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + else: + return input -@dataclasses.dataclass -class AugmentationMeta: - dispatcher_id: str - magnitudes_fn: Callable[[int, Tuple[int, int]], Optional[torch.Tensor]] - signed: bool + return apply_recursively(transform, sample) class AutoAugment(_AutoAugmentBase): - _AUGMENTATION_SPACE = ( - AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - AugmentationMeta( - "TranslateX", - lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), - True, - ), - AugmentationMeta( - "TranslateY", - lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), - True, - ), - AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), - AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta( - "Posterize", + _AUGMENTATION_SPACE = { + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, 
), - AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), - AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), - AugmentationMeta("Invert", lambda num_bins, image_size: None, False), - ) - _AUGMENTATION_SPACE = { - augmentation_meta.dispatcher_id: augmentation_meta for augmentation_meta in _AUGMENTATION_SPACE + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + "Invert": (lambda num_bins, image_size: None, False), } def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, **kwargs: Any) -> None: @@ -289,57 +249,56 @@ def _get_policies( else: raise ValueError(f"The provided policy {policy} is not recognized.") - def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) + policy = self._policies[int(torch.randint(len(self._policies), ()))] - for dispatcher_id, probability, magnitude_idx in policy: - if not self._randbool(probability): + for transform_id, probability, magnitude_idx in policy: + if not torch.rand(()) <= probability: continue - augmentation_meta = self._AUGMENTATION_SPACE[dispatcher_id] + magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id] - magnitudes = augmentation_meta.magnitudes_fn(10, image_size) + magnitudes = magnitudes_fn(10, image_size) if magnitudes is not None: magnitude = float(magnitudes[magnitude_idx]) - if augmentation_meta.signed and self._randbool(): + if signed and torch.rand(()) <= 0.5: magnitude *= -1 else: magnitude = 0.0 - yield augmentation_meta.dispatcher_id, magnitude + sample = self._apply_transform(sample, params, transform_id, magnitude) + + return sample class RandAugment(_AutoAugmentBase): - _AUGMENTATION_SPACE = ( - AugmentationMeta("Identity", lambda num_bins, image_size: None, False), - AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - AugmentationMeta( - "TranslateX", - lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), - True, - ), - AugmentationMeta( - "TranslateY", - lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), - True, - ), - AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), - AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - AugmentationMeta( - "Posterize", + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, 
num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": ( lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 4))) .round() .int(), False, ), - AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), - AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), - ) + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: int = 31, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -347,63 +306,71 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins - def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) + for _ in range(self.num_ops): - augmentation_meta = self._AUGMENTATION_SPACE[int(torch.randint(len(self._AUGMENTATION_SPACE), ()))] - if augmentation_meta.dispatcher_id == "Identity": - continue + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = augmentation_meta.magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) - if augmentation_meta.signed and self._randbool(): + if signed and torch.rand(()) <= 0.5: magnitude *= -1 else: magnitude = 0.0 - yield augmentation_meta.dispatcher_id, magnitude + sample = self._apply_transform(sample, params, transform_id, magnitude) + + return sample class TrivialAugmentWide(_AutoAugmentBase): - _AUGMENTATION_SPACE = ( - AugmentationMeta("Identity", lambda num_bins, image_size: None, False), - AugmentationMeta("ShearX", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - AugmentationMeta("ShearY", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - AugmentationMeta("TranslateX", lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), - AugmentationMeta("TranslateY", lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), - AugmentationMeta("Rotate", lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), - AugmentationMeta("Brightness", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - 
AugmentationMeta("Color", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - AugmentationMeta("Contrast", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - AugmentationMeta("Sharpness", lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - AugmentationMeta( - "Posterize", + _AUGMENTATION_SPACE = { + "Identity": (lambda num_bins, image_size: None, False), + "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), + "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), + "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Posterize": ( lambda num_bins, image_size: cast(torch.Tensor, 8 - (torch.arange(num_bins) / ((num_bins - 1) / 6))) .round() .int(), False, ), - AugmentationMeta("Solarize", lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - AugmentationMeta("AutoContrast", lambda num_bins, image_size: None, False), - AugmentationMeta("Equalize", lambda num_bins, image_size: None, False), - ) + "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (lambda num_bins, image_size: None, False), + "Equalize": (lambda num_bins, image_size: None, False), + } def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): super().__init__(**kwargs) self.num_magnitude_bins = num_magnitude_bins - def get_transforms_meta(self, image_size: Tuple[int, int]) -> Iterator[Tuple[str, float]]: - augmentation_meta = self._AUGMENTATION_SPACE[int(torch.randint(len(self._AUGMENTATION_SPACE), ()))] + def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + params = params or self._get_params(sample) + + image = query_image(sample) + image_size = F.get_image_size(image) - if augmentation_meta.dispatcher_id == "Identity": - return + transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE) - magnitudes = augmentation_meta.magnitudes_fn(self.num_magnitude_bins, image_size) + magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size) if magnitudes is not None: magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))]) - if augmentation_meta.signed and self._randbool(): + if signed and torch.rand(()) <= 0.5: magnitude *= -1 else: magnitude = 0.0 - yield augmentation_meta.dispatcher_id, magnitude + return self._apply_transform(sample, params, transform_id, magnitude) From 3348f89d9784c36bb62f36a7b85bc826ad4e9aa6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Feb 2022 09:18:18 +0100 Subject: [PATCH 16/25] fix auto augment dispatch --- .../prototype/transforms/_auto_augment.py | 250 +++++++++--------- 1 file changed, 128 insertions(+), 122 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 2036f2e8198..3c22815177a 100644 --- 
a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -22,84 +22,6 @@ def __init__( self.interpolation = interpolation self.fill = fill - _DISPATCHER_MAP: Dict[str, Callable[[Any, float, InterpolationMode, Optional[List[float]]], Any]] = { - "Identity": lambda input, magnitude, interpolation, fill: input, - "ShearX": lambda input, magnitude, interpolation, fill: ( - F.affine_image if type(input) is features.Image else _F.affine - )( - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[math.degrees(magnitude), 0.0], - interpolation=interpolation, - fill=fill, - ), - "ShearY": lambda input, magnitude, interpolation, fill: ( - F.affine_image if type(input) is features.Image else _F.affine - )( - input, - angle=0.0, - translate=[0, 0], - scale=1.0, - shear=[0.0, math.degrees(magnitude)], - interpolation=interpolation, - fill=fill, - ), - "TranslateX": lambda input, magnitude, interpolation, fill: ( - F.affine_image if type(input) is features.Image else _F.affine - )( - input, - angle=0.0, - translate=[int(magnitude), 0], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ), - "TranslateY": lambda input, magnitude, interpolation, fill: ( - F.affine_image if type(input) is features.Image else _F.affine - )( - input, - angle=0.0, - translate=[0, int(magnitude)], - scale=1.0, - shear=[0.0, 0.0], - interpolation=interpolation, - fill=fill, - ), - "Rotate": lambda input, magnitude, interpolation, fill: ( - F.rotate_image if type(input) is features.Image else _F.rotate - )(input, angle=magnitude), - "Brightness": lambda input, magnitude, interpolation, fill: ( - F.adjust_brightness_image if type(input) is features.Image else _F.adjust_brightness - )(input, brightness_factor=1.0 + magnitude), - "Color": lambda input, magnitude, interpolation, fill: ( - F.adjust_saturation_image if type(input) is features.Image else _F.adjust_saturation - )(input, saturation_factor=1.0 + magnitude), - "Contrast": lambda input, magnitude, interpolation, fill: ( - F.adjust_contrast_image if type(input) is features.Image else _F.adjust_contrast - )(input, contrast_factor=1.0 + magnitude), - "Sharpness": lambda input, magnitude, interpolation, fill: ( - F.adjust_sharpness_image if type(input) is features.Image else _F.adjust_sharpness - )(input, sharpness_factor=1.0 + magnitude), - "Posterize": lambda input, magnitude, interpolation, fill: ( - F.posterize_image if type(input) is features.Image else _F.posterize - )(input, bits=int(magnitude)), - "Solarize": lambda input, magnitude, interpolation, fill: ( - F.solarize_image if type(input) is features.Image else _F.solarize - )(input, threshold=magnitude), - "AutoContrast": lambda input, magnitude, interpolation, fill: ( - F.autocontrast_image if type(input) is features.Image else _F.autocontrast - )(input), - "Equalize": lambda input, magnitude, interpolation, fill: ( - F.equalize_image if type(input) is features.Image else _F.equalize - )(input), - "Invert": lambda input, magnitude, interpolation, fill: ( - F.invert_image if type(input) is features.Image else _F.invert - )(input), - } - def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) num_channels = F.get_image_num_channels(image) @@ -110,24 +32,111 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: elif fill is not None: fill = [float(f) for f in fill] - return dict(interpolation=self.interpolation, fill=fill) + return dict(fill=fill) def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, 
V]: keys = tuple(dct.keys()) key = keys[int(torch.randint(len(keys), ()))] return key, dct[key] - def _apply_transform(self, sample: Any, params: Dict[str, Any], transform_id: str, magnitude: float) -> Any: - dispatcher = self._DISPATCHER_MAP[transform_id] - - def transform(input: Any) -> Any: - if type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image): - return dispatcher(input, magnitude, params["interpolation"], params["fill"]) + def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any: + def dispatch(image_kernel: Callable, legacy_kernel: Callable, input: Any, *args: Any, **kwargs: Any) -> Any: + if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): + return legacy_kernel(input, *args, **kwargs) + elif type(input) is features.Image: + output = image_kernel(input, *args, **kwargs) + return features.Image.new_like(input, output) elif type(input) in {features.BoundingBox, features.SegmentationMask}: raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") else: return input + interpolation = self.interpolation + fill = self._get_params(sample)["fill"] + + def transform(input: Any) -> Any: + if type(input) in {features.BoundingBox, features.SegmentationMask}: + raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + elif not (type(input) in {features.Image, torch.Tensor} or isinstance(input, PIL.Image.Image)): + return input + + if transform_id == "Identity": + return input + elif transform_id == "ShearX": + return dispatch( + F.affine_image, + _F.affine, + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "ShearY": + return dispatch( + F.affine_image, + _F.affine, + input, + angle=0.0, + translate=[0, 0], + scale=1.0, + shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateX": + return dispatch( + F.affine_image, + _F.affine, + input, + angle=0.0, + translate=[int(magnitude), 0], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "TranslateY": + return dispatch( + F.affine_image, + _F.affine, + input, + angle=0.0, + translate=[0, int(magnitude)], + scale=1.0, + shear=[0.0, 0.0], + interpolation=interpolation, + fill=fill, + ) + elif transform_id == "Rotate": + return dispatch(F.rotate_image, _F.rotate, input, angle=magnitude) + elif transform_id == "Brightness": + return dispatch( + F.adjust_brightness_image, _F.adjust_brightness, input, brightness_factor=1.0 + magnitude + ) + elif transform_id == "Saturation": + return dispatch( + F.adjust_saturation_image, _F.adjust_saturation, input, saturation_factor=1.0 + magnitude + ) + elif transform_id == "Contrast": + return dispatch(F.adjust_contrast_image, _F.adjust_contrast, input, contrast_factor=1.0 + magnitude) + elif transform_id == "Sharpness": + return dispatch(F.adjust_sharpness_image, _F.adjust_sharpness, input, sharpness_factor=1.0 + magnitude) + elif transform_id == "Posterize": + return dispatch(F.posterize_image, _F.posterize, input, bits=int(magnitude)) + elif transform_id == "Solarize": + return dispatch(F.solarize_image, _F.solarize, input, threshold=magnitude) + elif transform_id == "Autocontrast": + return dispatch(F.autocontrast_image, _F.autocontrast, input) + elif transform_id == "Equalize": + return dispatch(F.equalize_image, _F.equalize, input) + elif transform_id == "Invert": + 
return dispatch(F.invert_image, _F.invert, input) + else: + raise ValueError(f"No transform available for {transform_id}") + return apply_recursively(transform, sample) @@ -136,10 +145,10 @@ class AutoAugment(_AutoAugmentBase): "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( @@ -149,7 +158,7 @@ class AutoAugment(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), + "Autocontrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), "Invert": (lambda num_bins, image_size: None, False), } @@ -165,7 +174,7 @@ def _get_policies( if policy == AutoAugmentPolicy.IMAGENET: return [ (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Solarize", 0.6, 5), ("Autocontrast", 0.6, None)), (("Equalize", 0.8, None), ("Equalize", 0.6, None)), (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), @@ -174,20 +183,20 @@ def _get_policies( (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.8, 8), ("Saturation", 0.4, 0)), (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), (("Equalize", 0.0, None), ("Equalize", 0.8, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Saturation", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Saturation", 1.0, 2)), + (("Saturation", 0.8, 8), ("Solarize", 0.8, 7)), (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Saturation", 0.4, 0), ("Equalize", 0.6, None)), (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Solarize", 0.6, 5), ("Autocontrast", 0.6, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Saturation", 0.6, 4), ("Contrast", 1.0, 8)), (("Equalize", 0.8, None), ("Equalize", 0.6, None)), ] elif policy == AutoAugmentPolicy.CIFAR10: @@ -196,27 +205,27 @@ def _get_policies( (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 
0.5, None), ("Equalize", 0.9, None)), + (("Autocontrast", 0.5, None), ("Equalize", 0.9, None)), (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Saturation", 0.4, 3), ("Brightness", 0.6, 7)), (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), (("Equalize", 0.6, None), ("Equalize", 0.5, None)), (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("Saturation", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("Autocontrast", 0.4, None)), (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Brightness", 0.9, 6), ("Saturation", 0.2, 8)), (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Autocontrast", 0.6, None)), (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("Saturation", 0.9, 9), ("Equalize", 0.6, None)), + (("Autocontrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Saturation", 0.7, 0)), + (("Solarize", 0.4, 5), ("Autocontrast", 0.9, None)), (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Autocontrast", 0.9, None), ("Solarize", 0.8, 3)), (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.7, 9), ("Autocontrast", 0.9, None)), ] elif policy == AutoAugmentPolicy.SVHN: return [ @@ -225,10 +234,10 @@ def _get_policies( (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearX", 0.9, 4), ("Autocontrast", 0.8, None)), (("ShearY", 0.9, 8), ("Invert", 0.4, None)), (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Invert", 0.9, None), ("Autocontrast", 0.8, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), (("ShearY", 0.8, 8), ("Invert", 0.7, None)), @@ -243,7 +252,7 @@ def _get_policies( (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), (("ShearY", 0.8, 4), ("Invert", 0.8, None)), (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearY", 0.8, 5), ("Autocontrast", 0.7, None)), (("ShearX", 0.7, 2), ("Invert", 0.1, None)), ] else: @@ -251,7 +260,6 @@ def _get_policies( def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -272,7 +280,7 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, params, transform_id, magnitude) + sample = self._apply_transform(sample, transform_id, magnitude) return sample @@ -283,10 +291,10 @@ class RandAugment(_AutoAugmentBase): "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "TranslateX": (lambda num_bins, image_size: 
torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( @@ -296,7 +304,7 @@ class RandAugment(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), + "Autocontrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), } @@ -308,7 +316,6 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -324,7 +331,7 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - sample = self._apply_transform(sample, params, transform_id, magnitude) + sample = self._apply_transform(sample, transform_id, magnitude) return sample @@ -338,7 +345,7 @@ class TrivialAugmentWide(_AutoAugmentBase): "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Posterize": ( @@ -348,7 +355,7 @@ class TrivialAugmentWide(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "AutoContrast": (lambda num_bins, image_size: None, False), + "Autocontrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), } @@ -358,7 +365,6 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] - params = params or self._get_params(sample) image = query_image(sample) image_size = F.get_image_size(image) @@ -373,4 +379,4 @@ def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: else: magnitude = 0.0 - return self._apply_transform(sample, params, transform_id, magnitude) + return self._apply_transform(sample, transform_id, magnitude) From 90f9fa7bdcff9b9adda0e4d690b5a0fa11336a5c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Feb 2022 09:58:20 +0100 Subject: [PATCH 
17/25] revert some naming changes --- .../prototype/transforms/_auto_augment.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 3c22815177a..f2c549b244c 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -116,7 +116,7 @@ def transform(input: Any) -> Any: return dispatch( F.adjust_brightness_image, _F.adjust_brightness, input, brightness_factor=1.0 + magnitude ) - elif transform_id == "Saturation": + elif transform_id == "Color": return dispatch( F.adjust_saturation_image, _F.adjust_saturation, input, saturation_factor=1.0 + magnitude ) @@ -128,7 +128,7 @@ def transform(input: Any) -> Any: return dispatch(F.posterize_image, _F.posterize, input, bits=int(magnitude)) elif transform_id == "Solarize": return dispatch(F.solarize_image, _F.solarize, input, threshold=magnitude) - elif transform_id == "Autocontrast": + elif transform_id == "AutoContrast": return dispatch(F.autocontrast_image, _F.autocontrast, input) elif transform_id == "Equalize": return dispatch(F.equalize_image, _F.equalize, input) @@ -148,7 +148,7 @@ class AutoAugment(_AutoAugmentBase): "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( @@ -158,7 +158,7 @@ class AutoAugment(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "Autocontrast": (lambda num_bins, image_size: None, False), + "AutoContrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), "Invert": (lambda num_bins, image_size: None, False), } @@ -174,7 +174,7 @@ def _get_policies( if policy == AutoAugmentPolicy.IMAGENET: return [ (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("Autocontrast", 0.6, None)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), (("Equalize", 0.8, None), ("Equalize", 0.6, None)), (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), @@ -183,20 +183,20 @@ def _get_policies( (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Saturation", 0.4, 0)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), (("Equalize", 0.0, None), ("Equalize", 0.8, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Saturation", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Saturation", 1.0, 2)), - (("Saturation", 0.8, 8), ("Solarize", 0.8, 7)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Saturation", 0.4, 0), ("Equalize", 
0.6, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("Autocontrast", 0.6, None)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Saturation", 0.6, 4), ("Contrast", 1.0, 8)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), (("Equalize", 0.8, None), ("Equalize", 0.6, None)), ] elif policy == AutoAugmentPolicy.CIFAR10: @@ -205,27 +205,27 @@ def _get_policies( (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("Autocontrast", 0.5, None), ("Equalize", 0.9, None)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Saturation", 0.4, 3), ("Brightness", 0.6, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), (("Equalize", 0.6, None), ("Equalize", 0.5, None)), (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Saturation", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("Autocontrast", 0.4, None)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Saturation", 0.2, 8)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("Autocontrast", 0.6, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Saturation", 0.9, 9), ("Equalize", 0.6, None)), - (("Autocontrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Saturation", 0.7, 0)), - (("Solarize", 0.4, 5), ("Autocontrast", 0.9, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("Autocontrast", 0.9, None), ("Solarize", 0.8, 3)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("Autocontrast", 0.9, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), ] elif policy == AutoAugmentPolicy.SVHN: return [ @@ -234,10 +234,10 @@ def _get_policies( (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), (("Invert", 0.9, None), ("Equalize", 0.6, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Autocontrast", 0.8, None)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), (("ShearY", 0.9, 8), ("Invert", 0.4, None)), (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("Autocontrast", 0.8, None)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), (("ShearY", 0.8, 8), ("Invert", 0.7, None)), @@ -252,7 +252,7 @@ def _get_policies( (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), (("ShearY", 0.8, 4), ("Invert", 0.8, None)), (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("Autocontrast", 0.7, None)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), (("ShearX", 0.7, 2), ("Invert", 0.1, None)), ] else: @@ -294,7 +294,7 @@ class RandAugment(_AutoAugmentBase): "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), 
True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), - "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Posterize": ( @@ -304,7 +304,7 @@ class RandAugment(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "Autocontrast": (lambda num_bins, image_size: None, False), + "AutoContrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), } @@ -345,7 +345,7 @@ class TrivialAugmentWide(_AutoAugmentBase): "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 32.0, num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 135.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), - "Saturation": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), + "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Contrast": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Sharpness": (lambda num_bins, image_size: torch.linspace(0.0, 0.99, num_bins), True), "Posterize": ( @@ -355,7 +355,7 @@ class TrivialAugmentWide(_AutoAugmentBase): False, ), "Solarize": (lambda num_bins, image_size: torch.linspace(255.0, 0.0, num_bins), False), - "Autocontrast": (lambda num_bins, image_size: None, False), + "AutoContrast": (lambda num_bins, image_size: None, False), "Equalize": (lambda num_bins, image_size: None, False), } From ddf28d2047ca294cd636332cf16fc66d62d88299 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Feb 2022 09:59:55 +0100 Subject: [PATCH 18/25] remove ability to pass params to autoaugment --- .../prototype/transforms/_auto_augment.py | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index f2c549b244c..c2c73099773 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -22,18 +22,6 @@ def __init__( self.interpolation = interpolation self.fill = fill - def _get_params(self, sample: Any) -> Dict[str, Any]: - image = query_image(sample) - num_channels = F.get_image_num_channels(image) - - fill = self.fill - if isinstance(fill, (int, float)): - fill = [float(fill)] * num_channels - elif fill is not None: - fill = [float(f) for f in fill] - - return dict(fill=fill) - def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: keys = tuple(dct.keys()) key = keys[int(torch.randint(len(keys), ()))] @@ -51,8 +39,16 @@ def dispatch(image_kernel: Callable, legacy_kernel: Callable, input: Any, *args: else: return input + image = query_image(sample) + num_channels = F.get_image_num_channels(image) + + fill = self.fill + if isinstance(fill, (int, float)): + fill = [float(fill)] * num_channels + elif fill is not None: + fill = [float(f) for f in fill] + interpolation = self.interpolation - fill = self._get_params(sample)["fill"] def transform(input: Any) -> Any: if type(input) in {features.BoundingBox, 
features.SegmentationMask}: @@ -258,7 +254,7 @@ def _get_policies( else: raise ValueError(f"The provided policy {policy} is not recognized.") - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) @@ -314,7 +310,7 @@ def __init__(self, *, num_ops: int = 2, magnitude: int = 9, num_magnitude_bins: self.magnitude = magnitude self.num_magnitude_bins = num_magnitude_bins - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) @@ -363,7 +359,7 @@ def __init__(self, *, num_magnitude_bins: int = 31, **kwargs: Any): super().__init__(**kwargs) self.num_magnitude_bins = num_magnitude_bins - def forward(self, *inputs: Any, params: Optional[Dict[str, Any]] = None) -> Any: + def forward(self, *inputs: Any) -> Any: sample = inputs if len(inputs) > 1 else inputs[0] image = query_image(sample) From 68bbb2b06a332ae89666a2105bc5db964a945f09 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 22 Feb 2022 10:39:20 +0100 Subject: [PATCH 19/25] fix legacy image size extraction --- torchvision/prototype/transforms/functional/_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 5f54247a169..2c84cf2cb55 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,4 +1,4 @@ -from typing import Tuple, cast, Union +from typing import Tuple, Union import PIL.Image import torch @@ -8,7 +8,8 @@ def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): - return cast(Tuple[int, int], tuple(_F.get_image_size(image))) + width, height = _F.get_image_size(image) + return height, width elif type(image) is features.Image: return image.image_size else: From 1587588573760f1871af6459aff99030fe147e6b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 10:23:06 +0100 Subject: [PATCH 20/25] align prototype.transforms.functional with transforms.functional --- test/test_prototype_transforms_functional.py | 6 +- .../prototype/datasets/_builtin/caltech.py | 2 +- torchvision/prototype/transforms/__init__.py | 2 +- torchvision/prototype/transforms/_augment.py | 35 +++-- .../prototype/transforms/_auto_augment.py | 79 ++++++---- torchvision/prototype/transforms/_geometry.py | 64 +++++--- .../prototype/transforms/_meta_conversion.py | 42 +++++- torchvision/prototype/transforms/_misc.py | 17 ++- .../transforms/functional/__init__.py | 79 ++++++---- .../transforms/functional/_augment.py | 16 +- .../prototype/transforms/functional/_color.py | 48 ++++-- .../transforms/functional/_geometry.py | 140 +++++++++++++++--- .../transforms/functional/_meta_conversion.py | 32 +++- .../prototype/transforms/functional/_misc.py | 43 +++++- .../prototype/transforms/functional/_utils.py | 2 + 15 files changed, 441 insertions(+), 166 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index e04add0bd60..c187ccc737e 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -160,7 +160,7 @@ def 
register_kernel_info_from_sample_inputs_fn(sample_inputs_fn): @register_kernel_info_from_sample_inputs_fn -def horizontal_flip_image(): +def horizontal_flip_image_tensor(): for image in make_images(): yield SampleInput(image) @@ -172,7 +172,7 @@ def horizontal_flip_bounding_box(): @register_kernel_info_from_sample_inputs_fn -def resize_image(): +def resize_image_tensor(): for image, interpolation in itertools.product( make_images(), [ @@ -185,7 +185,7 @@ def resize_image(): (height, width), (int(height * 0.75), int(width * 1.25)), ]: - yield SampleInput(image, size=size, interpolation=interpolation) + yield SampleInput(image, size=size, interpolation=interpolation.value) @register_kernel_info_from_sample_inputs_fn diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 1a052860ebf..4c66a1e70d8 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -95,7 +95,7 @@ def _prepare_sample( bounding_box=BoundingBox( ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size ), - contour=_Feature(ann["obj_contour"].T), + contour=_Feature(ann["obj_contour"]._FT), ) def _make_datapipe( diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 3dd3158f3fc..73235720d58 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -8,7 +8,7 @@ from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop -from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertColorSpace +from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, Normalize, ToDtype, Lambda from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval from ._type_conversion import DecodeImage, LabelToOneHot diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index cd6df9aca95..96449d67790 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -3,11 +3,9 @@ import warnings from typing import Any, Dict, Tuple -import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from torchvision.transforms import functional as _F from ._utils import query_image @@ -92,12 +90,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(zip("ijhwv", (i, j, h, w, v))) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is torch.Tensor: - return _F.erase(input, **params) - elif type(input) is features.Image: - return features.Image.new_like(input, F.erase_image(input, **params)) - elif type(input) in {features.BoundingBox, features.SegmentationMask} or isinstance(input, PIL.Image.Image): - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.erase_image_tensor(input, **params) + return features.Image.new_like(input, output) + elif 
isinstance(input, torch.Tensor): + return F.erase_image_tensor(input, **params) else: return input @@ -118,14 +117,14 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: - output = F.mixup_image(input, **params) + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.mixup_image_tensor(input, **params) return features.Image.new_like(input, output) - elif type(input) is features.OneHotLabel: + elif isinstance(input, features.OneHotLabel): output = F.mixup_one_hot_label(input, **params) return features.OneHotLabel.new_like(input, output) - elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") else: return input @@ -160,13 +159,13 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: - output = F.cutmix_image(input, box=params["box"]) + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.cutmix_image_tensor(input, box=params["box"]) return features.Image.new_like(input, output) - elif type(input) is features.OneHotLabel: + elif isinstance(input, features.OneHotLabel): output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) return features.OneHotLabel.new_like(input, output) - elif type(input) in {torch.Tensor, features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") else: return input diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index c2c73099773..9a9bc7f0f29 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -6,7 +6,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F from torchvision.prototype.utils._internal import apply_recursively -from torchvision.transforms import functional as _F +from torchvision.transforms.functional import pil_modes_mapping from ._utils import query_image @@ -28,14 +28,28 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]: return key, dct[key] def _apply_transform(self, sample: Any, transform_id: str, magnitude: float) -> Any: - def dispatch(image_kernel: Callable, legacy_kernel: Callable, input: Any, *args: Any, **kwargs: Any) -> Any: - if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): - return legacy_kernel(input, *args, **kwargs) - elif type(input) is features.Image: - output = image_kernel(input, *args, **kwargs) + def dispatch( + image_tensor_kernel: Callable, + image_pil_kernel: Callable, + input: Any, + *args: Any, + **kwargs: Any, + ) -> Any: + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + if 
"interpolation" in kwargs: + kwargs["interpolation"] = kwargs["interpolation"].value + output = image_tensor_kernel(input, *args, **kwargs) return features.Image.new_like(input, output) - elif type(input) in {features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + elif isinstance(input, torch.Tensor): + if "interpolation" in kwargs: + kwargs["interpolation"] = kwargs["interpolation"].value + return image_tensor_kernel(input, *args, **kwargs) + elif isinstance(input, PIL.Image.Image): + if "interpolation" in kwargs: + kwargs["interpolation"] = pil_modes_mapping[kwargs["interpolation"]] + return image_pil_kernel(input, *args, **kwargs) else: return input @@ -60,8 +74,8 @@ def transform(input: Any) -> Any: return input elif transform_id == "ShearX": return dispatch( - F.affine_image, - _F.affine, + F.affine_image_tensor, + F.affine_image_pil, input, angle=0.0, translate=[0, 0], @@ -72,8 +86,8 @@ def transform(input: Any) -> Any: ) elif transform_id == "ShearY": return dispatch( - F.affine_image, - _F.affine, + F.affine_image_tensor, + F.affine_image_pil, input, angle=0.0, translate=[0, 0], @@ -84,8 +98,8 @@ def transform(input: Any) -> Any: ) elif transform_id == "TranslateX": return dispatch( - F.affine_image, - _F.affine, + F.affine_image_tensor, + F.affine_image_pil, input, angle=0.0, translate=[int(magnitude), 0], @@ -96,8 +110,8 @@ def transform(input: Any) -> Any: ) elif transform_id == "TranslateY": return dispatch( - F.affine_image, - _F.affine, + F.affine_image_tensor, + F.affine_image_pil, input, angle=0.0, translate=[0, int(magnitude)], @@ -107,29 +121,42 @@ def transform(input: Any) -> Any: fill=fill, ) elif transform_id == "Rotate": - return dispatch(F.rotate_image, _F.rotate, input, angle=magnitude) + return dispatch(F.rotate_image_tensor, F.rotate_image_pil, input, angle=magnitude) elif transform_id == "Brightness": return dispatch( - F.adjust_brightness_image, _F.adjust_brightness, input, brightness_factor=1.0 + magnitude + F.adjust_brightness_image_tensor, + F.adjust_brightness_image_pil, + input, + brightness_factor=1.0 + magnitude, ) elif transform_id == "Color": return dispatch( - F.adjust_saturation_image, _F.adjust_saturation, input, saturation_factor=1.0 + magnitude + F.adjust_saturation_image_tensor, + F.adjust_saturation_image_pil, + input, + saturation_factor=1.0 + magnitude, ) elif transform_id == "Contrast": - return dispatch(F.adjust_contrast_image, _F.adjust_contrast, input, contrast_factor=1.0 + magnitude) + return dispatch( + F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, input, contrast_factor=1.0 + magnitude + ) elif transform_id == "Sharpness": - return dispatch(F.adjust_sharpness_image, _F.adjust_sharpness, input, sharpness_factor=1.0 + magnitude) + return dispatch( + F.adjust_sharpness_image_tensor, + F.adjust_sharpness_image_pil, + input, + sharpness_factor=1.0 + magnitude, + ) elif transform_id == "Posterize": - return dispatch(F.posterize_image, _F.posterize, input, bits=int(magnitude)) + return dispatch(F.posterize_image_tensor, F.posterize_image_pil, input, bits=int(magnitude)) elif transform_id == "Solarize": - return dispatch(F.solarize_image, _F.solarize, input, threshold=magnitude) + return dispatch(F.solarize_image_tensor, F.solarize_image_pil, input, threshold=magnitude) elif transform_id == "AutoContrast": - return dispatch(F.autocontrast_image, _F.autocontrast, input) + return dispatch(F.autocontrast_image_tensor, F.autocontrast_image_pil, input) elif 
transform_id == "Equalize": - return dispatch(F.equalize_image, _F.equalize, input) + return dispatch(F.equalize_image_tensor, F.equalize_image_pil, input) elif transform_id == "Invert": - return dispatch(F.invert_image, _F.invert, input) + return dispatch(F.invert_image_tensor, F.invert_image_pil, input) else: raise ValueError(f"No transform available for {transform_id}") diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index e78bab22251..735add49fc1 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -6,7 +6,7 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F -from torchvision.transforms import functional as _F +from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int from ._utils import query_image @@ -14,14 +14,16 @@ class HorizontalFlip(Transform): def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): - return _F.hflip(input) - elif type(input) is features.Image: - output = F.horizontal_flip_image(input) + if isinstance(input, features.Image): + output = F.horizontal_flip_image_tensor(input) return features.Image.new_like(input, output) - elif type(input) is features.BoundingBox: + elif isinstance(input, features.BoundingBox): output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) return features.BoundingBox.new_like(input, output) + elif isinstance(input, PIL.Image.Image): + return F.horizontal_flip_image_pil(input) + elif isinstance(input, torch.Tensor): + return F.horizontal_flip_image_tensor(input) else: return input @@ -37,17 +39,19 @@ def __init__( self.interpolation = interpolation def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is torch.Tensor or isinstance(input, PIL.Image.Image): - return _F.resize(input, size=self.size, interpolation=self.interpolation) - elif type(input) is features.Image: - output = F.resize_image(input, size=self.size, interpolation=self.interpolation) + if isinstance(input, features.Image): + output = F.resize_image_tensor(input, self.size, interpolation=self.interpolation.value) return features.Image.new_like(input, output) - elif type(input) is features.SegmentationMask: - output = F.resize_segmentation_mask(input, size=self.size) + elif isinstance(input, features.SegmentationMask): + output = F.resize_segmentation_mask(input, self.size) return features.SegmentationMask.new_like(input, output) - elif type(input) is features.BoundingBox: - output = F.resize_bounding_box(input, size=self.size, image_size=input.image_size) + elif isinstance(input, features.BoundingBox): + output = F.resize_bounding_box(input, self.size, image_size=input.image_size) return features.BoundingBox.new_like(input, output, image_size=self.size) + elif isinstance(input, PIL.Image.Image): + return F.resize_image_pil(input, self.size, interpolation=pil_modes_mapping[self.interpolation]) + elif isinstance(input, torch.Tensor): + return F.resize_image_tensor(input, self.size, interpolation=self.interpolation.value) else: return input @@ -58,13 +62,15 @@ def __init__(self, output_size: List[int]): self.output_size = output_size def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is 
torch.Tensor or isinstance(input, PIL.Image.Image): - return _F.center_crop(input, output_size=self.output_size) - elif type(input) is features.Image: - output = F.center_crop_image(input, **params) + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.center_crop_image_tensor(input, self.output_size) return features.Image.new_like(input, output) - elif type(input) in {features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + elif isinstance(input, torch.Tensor): + return F.center_crop_image_tensor(input, self.output_size) + elif isinstance(input, PIL.Image.Image): + return F.center_crop_image_pil(input, self.output_size) else: return input @@ -142,10 +148,20 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: - output = F.resized_crop_image(input, size=self.size, interpolation=self.interpolation, **params) + if isinstance(input, (features.BoundingBox, features.SegmentationMask)): + raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()") + elif isinstance(input, features.Image): + output = F.resized_crop_image_tensor( + input, **params, size=list(self.size), interpolation=self.interpolation.value + ) return features.Image.new_like(input, output) - elif type(input) in {features.BoundingBox, features.SegmentationMask}: - raise TypeError(f"{type(input)} is not supported by {type(self).__name__}()") + elif isinstance(input, torch.Tensor): + return F.resized_crop_image_tensor( + input, **params, size=list(self.size), interpolation=self.interpolation.value + ) + elif isinstance(input, PIL.Image.Image): + return F.resized_crop_image_pil( + input, **params, size=list(self.size), interpolation=pil_modes_mapping[self.interpolation] + ) else: return input diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py index e33df450ce3..edc97256c4e 100644 --- a/torchvision/prototype/transforms/_meta_conversion.py +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -1,9 +1,10 @@ -from typing import Union, Any, Dict +from typing import Union, Any, Dict, Optional +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from torchvision.transforms import functional as _F +from torchvision.transforms.functional import convert_image_dtype class ConvertBoundingBoxFormat(Transform): @@ -28,22 +29,49 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None: def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if type(input) is features.Image: - output = _F.convert_image_dtype(input, dtype=self.dtype) + output = convert_image_dtype(input, dtype=self.dtype) return features.Image.new_like(input, output, dtype=self.dtype) else: return input -class ConvertColorSpace(Transform): - def __init__(self, color_space: Union[str, features.ColorSpace]) -> None: +class ConvertImageColorSpace(Transform): + def __init__( + self, + color_space: Union[str, features.ColorSpace], + old_color_space: Optional[Union[str, features.ColorSpace]] = None, + ) -> None: super().__init__() + if isinstance(color_space, str): color_space = 
features.ColorSpace[color_space] self.color_space = color_space + if isinstance(old_color_space, str): + old_color_space = features.ColorSpace[old_color_space] + self.old_color_space = old_color_space + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is features.Image: - output = F.convert_color_space(input, old_color_space=input.color_space, new_color_space=self.color_space) + if isinstance(input, features.Image): + output = F.convert_image_color_space_tensor( + input, old_color_space=input.color_space, new_color_space=self.color_space + ) return features.Image.new_like(input, output, color_space=self.color_space) + if isinstance(input, torch.Tensor): + if self.old_color_space is None: + raise RuntimeError("") + + return F.convert_image_color_space_tensor( + input, old_color_space=self.old_color_space, new_color_space=self.color_space + ) + if isinstance(input, PIL.Image.Image): + old_color_space = { + "L": features.ColorSpace.GRAYSCALE, + "RGB": features.ColorSpace.RGB, + }.get(input.mode, features.ColorSpace.OTHER) + + return F.convert_image_color_space_pil( + input, old_color_space=old_color_space, new_color_space=self.color_space + ) else: return input diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index b1b57d4d63a..5f8c5fc0336 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -2,9 +2,7 @@ from typing import Any, List, Type, Callable, Dict import torch -from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F -from torchvision.transforms import functional as _F class Identity(Transform): @@ -40,11 +38,12 @@ def __init__(self, mean: List[float], std: List[float]): self.std = std def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if type(input) is torch.Tensor: - return _F.normalize(input, mean=self.mean, std=self.std) - if type(input) is features.Image: - output = F.normalize_image(input, mean=self.mean, std=self.std) - return features.Image.new_like(input, output) + if isinstance(input, torch.Tensor): + # We don't need to differentiate between vanilla tensors and features.Image's here, since the result of the + # normalization transform is no longer a features.Image + return F.normalize_image_tensor(input, mean=self.mean, std=self.std) + else: + return input class ToDtype(Lambda): @@ -54,3 +53,7 @@ def __init__(self, dtype: torch.dtype, *types: Type) -> None: def extra_repr(self) -> str: return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"]) + + +class GaussianBlur(Transform): + pass diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index fb29703118a..c487aba7fa2 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,43 +1,66 @@ from torchvision.transforms import InterpolationMode # usort: skip from ._utils import get_image_size, get_image_num_channels # usort: skip -from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip +from ._meta_conversion import ( + convert_bounding_box_format, + convert_image_color_space_tensor, + convert_image_color_space_pil, +) # usort: skip from ._augment import ( - erase_image, - mixup_image, + erase_image_tensor, + mixup_image_tensor, mixup_one_hot_label, - cutmix_image, + cutmix_image_tensor, 
cutmix_one_hot_label, ) from ._color import ( - adjust_brightness_image, - adjust_contrast_image, - adjust_saturation_image, - adjust_sharpness_image, - posterize_image, - solarize_image, - autocontrast_image, - equalize_image, - invert_image, - adjust_hue_image, - adjust_gamma_image, + adjust_brightness_image_tensor, + adjust_brightness_image_pil, + adjust_contrast_image_tensor, + adjust_contrast_image_pil, + adjust_saturation_image_tensor, + adjust_saturation_image_pil, + adjust_sharpness_image_tensor, + adjust_sharpness_image_pil, + posterize_image_tensor, + posterize_image_pil, + solarize_image_tensor, + solarize_image_pil, + autocontrast_image_tensor, + autocontrast_image_pil, + equalize_image_tensor, + equalize_image_pil, + invert_image_tensor, + invert_image_pil, + adjust_hue_image_tensor, + adjust_hue_image_pil, + adjust_gamma_image_tensor, + adjust_gamma_image_pil, ) from ._geometry import ( horizontal_flip_bounding_box, - horizontal_flip_image, + horizontal_flip_image_tensor, + horizontal_flip_image_pil, resize_bounding_box, - resize_image, + resize_image_tensor, + resize_image_pil, resize_segmentation_mask, - center_crop_image, - resized_crop_image, - affine_image, - rotate_image, - pad_image, - crop_image, - perspective_image, - vertical_flip_image, - five_crop_image, - ten_crop_image, + center_crop_image_tensor, + center_crop_image_pil, + resized_crop_image_tensor, + resized_crop_image_pil, + affine_image_tensor, + affine_image_pil, + rotate_image_tensor, + rotate_image_pil, + pad_image_tensor, + pad_image_pil, + crop_image_tensor, + crop_image_pil, + perspective_image_tensor, + perspective_image_pil, + vertical_flip_image_tensor, + vertical_flip_image_pil, ) -from ._misc import normalize_image, gaussian_blur_image +from ._misc import normalize_image_tensor, gaussian_blur_image_tensor from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 526ed85ffd8..5004ac550dd 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,32 +1,32 @@ from typing import Tuple import torch -from torchvision.transforms import functional as _F +from torchvision.transforms import functional_tensor as _FT -erase_image = _F.erase +erase_image_tensor = _FT.erase -def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: +def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: input = input.clone() return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) -def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: +def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") - return _mixup(image_batch, -4, lam) + return _mixup_tensor(image_batch, -4, lam) def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") - return _mixup(one_hot_label_batch, -2, lam) + return _mixup_tensor(one_hot_label_batch, -2, lam) -def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: +def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") @@ -42,4 
+42,4 @@ def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: flo if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") - return _mixup(one_hot_label_batch, -2, lam_adjusted) + return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 00ed5cfbfc7..fa632d7df58 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,14 +1,34 @@ -from torchvision.transforms import functional as _F - - -adjust_brightness_image = _F.adjust_brightness -adjust_saturation_image = _F.adjust_saturation -adjust_contrast_image = _F.adjust_contrast -adjust_sharpness_image = _F.adjust_sharpness -posterize_image = _F.posterize -solarize_image = _F.solarize -autocontrast_image = _F.autocontrast -equalize_image = _F.equalize -invert_image = _F.invert -adjust_hue_image = _F.adjust_hue -adjust_gamma_image = _F.adjust_gamma +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + +adjust_brightness_image_tensor = _FT.adjust_brightness +adjust_brightness_image_pil = _FP.adjust_brightness + +adjust_saturation_image_tensor = _FT.adjust_saturation +adjust_saturation_image_pil = _FP.adjust_saturation + +adjust_contrast_image_tensor = _FT.adjust_contrast +adjust_contrast_image_pil = _FP.adjust_contrast + +adjust_sharpness_image_tensor = _FT.adjust_sharpness +adjust_sharpness_image_pil = _FP.adjust_sharpness + +posterize_image_tensor = _FT.posterize +posterize_image_pil = _FP.posterize + +solarize_image_tensor = _FT.solarize +solarize_image_pil = _FP.solarize + +autocontrast_image_tensor = _FT.autocontrast +autocontrast_image_pil = _FP.autocontrast + +equalize_image_tensor = _FT.equalize +equalize_image_pil = _FP.equalize + +invert_image_tensor = _FT.invert +invert_image_pil = _FP.invert + +adjust_hue_image_tensor = _FT.adjust_hue +adjust_hue_image_pil = _FP.adjust_hue + +adjust_gamma_image_tensor = _FT.adjust_gamma +adjust_gamma_image_pil = _FP.adjust_gamma diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 9880ca6a685..b8978a1bb52 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,17 +1,20 @@ from typing import Tuple, List, Optional +import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional as _F, InterpolationMode +from torchvision.prototype.transforms.functional import get_image_size +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP from ._meta_conversion import convert_bounding_box_format -horizontal_flip_image = _F.hflip +horizontal_flip_image_tensor = _FT.hflip +horizontal_flip_image_pil = _FP.hflip def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] ) -> torch.Tensor: shape = bounding_box.shape @@ -26,17 +29,17 @@ def horizontal_flip_bounding_box( ).view(shape) -def resize_image( +def resize_image_tensor( image: torch.Tensor, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: str = "nearest", max_size: Optional[int] = None, antialias: Optional[bool] = None, ) 
-> torch.Tensor: new_height, new_width = size num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] - return _F.resize( + return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation, @@ -45,29 +48,124 @@ def resize_image( ).reshape(batch_shape + (num_channels, new_height, new_width)) +resize_image_pil = _FP.resize + + def resize_segmentation_mask( - segmentation_mask: torch.Tensor, - size: List[int], - max_size: Optional[int] = None, + segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None ) -> torch.Tensor: - return resize_image(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) + return resize_image_tensor(segmentation_mask, size=size, interpolation="nearest", max_size=max_size) # TODO: handle max_size -def resize_bounding_box(bounding_box: torch.Tensor, *, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor: +def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor: old_height, old_width = image_size new_height, new_width = size ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -center_crop_image = _F.center_crop -resized_crop_image = _F.resized_crop -affine_image = _F.affine -rotate_image = _F.rotate -pad_image = _F.pad -crop_image = _F.crop -perspective_image = _F.perspective -vertical_flip_image = _F.vflip -five_crop_image = _F.five_crop -ten_crop_image = _F.ten_crop +vertical_flip_image_tensor = _FT.vflip +vertical_flip_image_pil = _FP.vflip + +affine_image_tensor = _FT.affine +affine_image_pil = _FP.affine + +rotate_image_tensor = _FT.rotate +rotate_image_pil = _FP.rotate + +pad_image_tensor = _FT.pad +pad_image_pil = _FP.pad + +crop_image_tensor = _FT.crop +crop_image_pil = _FP.crop + +perspective_image_tensor = _FT.perspective +perspective_image_pil = _FP.perspective + + +import numbers + + +def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: + if isinstance(output_size, numbers.Number): + return [int(output_size), int(output_size)] + elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: + return [output_size[0], output_size[0]] + else: + return list(output_size) + + +def _center_crop_compute_padding(crop_height: int, crop_width: int, image_height: int, image_width: int) -> List[int]: + return [ + (crop_width - image_width) // 2 if crop_width > image_width else 0, + (crop_height - image_height) // 2 if crop_height > image_height else 0, + (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, + (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, + ] + + +def _center_crop_compute_crop_anchor( + crop_height: int, crop_width: int, image_height: int, image_width: int +) -> Tuple[int, int]: + crop_top = int(round((image_height - crop_height) / 2.0)) + crop_left = int(round((image_width - crop_width) / 2.0)) + return crop_top, crop_left + + +def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_height, image_width = get_image_size(img) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + img = pad_image_tensor(img, padding_ltrb, fill=0) + + 
image_height, image_width = get_image_size(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return crop_image_tensor(img, crop_top, crop_left, crop_height, crop_width) + + +def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: + crop_height, crop_width = _center_crop_parse_output_size(output_size) + image_height, image_width = get_image_size(img) + + if crop_height > image_height or crop_width > image_width: + padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) + img = pad_image_pil(img, padding_ltrb, fill=0) + + image_height, image_width = get_image_size(img) + if crop_width == image_width and crop_height == image_height: + return img + + crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, image_height, image_width) + return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width) + + +def resized_crop_image_tensor( + img: torch.Tensor, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: str = "bilinear", +) -> torch.Tensor: + img = crop_image_tensor(img, top, left, height, width) + return resize_image_tensor(img, size, interpolation) + + +def resized_crop_image_pil( + img: PIL.Image.Image, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: int = PIL.Image.BILINEAR, +) -> PIL.Image.Image: + img = crop_image_pil(img, top, left, height, width) + return resize_image_pil(img, size, interpolation) diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py index 4acaf9fe9e4..b260beaa361 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/functional/_meta_conversion.py @@ -1,6 +1,7 @@ +import PIL.Image import torch from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms.functional_tensor import rgb_to_grayscale as _rgb_to_grayscale +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor: @@ -52,18 +53,39 @@ def convert_bounding_box_format( return bounding_box -def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: +def _grayscale_to_rgb_tensor(grayscale: torch.Tensor) -> torch.Tensor: return grayscale.expand(3, 1, 1) -def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor: +def convert_image_color_space_tensor( + image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace +) -> torch.Tensor: if new_color_space == old_color_space: return image.clone() if old_color_space == ColorSpace.GRAYSCALE: - image = _grayscale_to_rgb(image) + image = _grayscale_to_rgb_tensor(image) + + if new_color_space == ColorSpace.GRAYSCALE: + image = _FT.rgb_to_grayscale(image) + + return image + + +def _grayscale_to_rgb_pil(grayscale: PIL.Image.Image) -> PIL.Image.Image: + return grayscale.convert("RGB") + + +def convert_image_color_space_pil( + image: PIL.Image.Image, old_color_space: ColorSpace, new_color_space: ColorSpace +) -> PIL.Image.Image: + if new_color_space == old_color_space: + return image.copy() + + if old_color_space == ColorSpace.GRAYSCALE: + image = _grayscale_to_rgb_pil(image) if new_color_space == 
ColorSpace.GRAYSCALE: - image = _rgb_to_grayscale(image) + image = _FP.to_grayscale(image) return image diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index f4e2c69c7ee..fd0507cca4d 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,5 +1,42 @@ -from torchvision.transforms import functional as _F +from typing import Optional, List +import PIL.Image +import torch +from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms.functional import to_tensor, to_pil_image -normalize_image = _F.normalize -gaussian_blur_image = _F.gaussian_blur + +normalize_image_tensor = _FT.normalize + + +def gaussian_blur_image_tensor( + img: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None +) -> torch.Tensor: + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + if len(kernel_size) != 2: + raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") + for ksize in kernel_size: + if ksize % 2 == 0 or ksize < 0: + raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}") + + if sigma is None: + sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] + + if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") + if isinstance(sigma, (int, float)): + sigma = [float(sigma), float(sigma)] + if isinstance(sigma, (list, tuple)) and len(sigma) == 1: + sigma = [sigma[0], sigma[0]] + if len(sigma) != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + for s in sigma: + if s <= 0.0: + raise ValueError(f"sigma should have positive values. 
Got {sigma}") + + return _FT.gaussian_blur(img, kernel_size, sigma) + + +def gaussian_blur_image_pil(img: PIL.Image, kernel_size: List[int], sigma: Optional[List[float]] = None) -> PIL.Image: + return to_pil_image(gaussian_blur_image_tensor(to_tensor(img), kernel_size=kernel_size, sigma=sigma)) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 2c84cf2cb55..282745f52b9 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -5,6 +5,8 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F +# FIXME + def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): From 7826ab35df6b276a96676f75cf5bc860c79b4749 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Feb 2022 10:53:04 +0100 Subject: [PATCH 21/25] cleanup --- torchvision/prototype/datasets/_builtin/caltech.py | 2 +- torchvision/prototype/transforms/_meta_conversion.py | 4 ++-- torchvision/prototype/transforms/_misc.py | 4 ---- 3 files changed, 3 insertions(+), 7 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 4c66a1e70d8..1a052860ebf 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -95,7 +95,7 @@ def _prepare_sample( bounding_box=BoundingBox( ann["box_coord"].astype(np.int64).squeeze()[[2, 0, 3, 1]], format="xyxy", image_size=image.image_size ), - contour=_Feature(ann["obj_contour"]._FT), + contour=_Feature(ann["obj_contour"].T), ) def _make_datapipe( diff --git a/torchvision/prototype/transforms/_meta_conversion.py b/torchvision/prototype/transforms/_meta_conversion.py index edc97256c4e..3675e1d8ada 100644 --- a/torchvision/prototype/transforms/_meta_conversion.py +++ b/torchvision/prototype/transforms/_meta_conversion.py @@ -57,14 +57,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: input, old_color_space=input.color_space, new_color_space=self.color_space ) return features.Image.new_like(input, output, color_space=self.color_space) - if isinstance(input, torch.Tensor): + elif isinstance(input, torch.Tensor): if self.old_color_space is None: raise RuntimeError("") return F.convert_image_color_space_tensor( input, old_color_space=self.old_color_space, new_color_space=self.color_space ) - if isinstance(input, PIL.Image.Image): + elif isinstance(input, PIL.Image.Image): old_color_space = { "L": features.ColorSpace.GRAYSCALE, "RGB": features.ColorSpace.RGB, diff --git a/torchvision/prototype/transforms/_misc.py b/torchvision/prototype/transforms/_misc.py index 5f8c5fc0336..54440ee05a5 100644 --- a/torchvision/prototype/transforms/_misc.py +++ b/torchvision/prototype/transforms/_misc.py @@ -53,7 +53,3 @@ def __init__(self, dtype: torch.dtype, *types: Type) -> None: def extra_repr(self) -> str: return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"]) - - -class GaussianBlur(Transform): - pass From ced8bcf69a0109918e7eab2524ed19e42db76d64 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Feb 2022 11:01:31 +0100 Subject: [PATCH 22/25] fix image size and channels extraction --- torchvision/prototype/transforms/_augment.py | 3 +-- .../transforms/functional/_geometry.py | 3 ++- 
.../prototype/transforms/functional/_utils.py | 25 +++++++++++-------- 3 files changed, 17 insertions(+), 14 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 96449d67790..dd3ecd4ad47 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -55,8 +55,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if value is not None and not (len(value) in (1, img_c)): raise ValueError( - "If value is a sequence, it should have either a single value or " - f"{image.shape[-3]} (number of input channels)" + "If value is a sequence, it should have either a single value or " f"{img_c} (number of input channels)" ) area = img_h * img_w diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index a8aa57fe94e..4a8eb2b0b53 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -40,7 +40,8 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - num_channels, old_height, old_width = image.shape[-3:] + old_height, old_width = _FT.get_image_size(image) + num_channels = _FT.get_image_num_channels(image) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 282745f52b9..25fa0cd0d52 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,27 +1,30 @@ -from typing import Tuple, Union +from typing import Tuple, Union, cast import PIL.Image import torch from torchvision.prototype import features -from torchvision.transforms import functional as _F - -# FIXME +from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: - if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): - width, height = _F.get_image_size(image) - return height, width - elif type(image) is features.Image: + if isinstance(image, features.Image): return image.image_size + elif isinstance(image, torch.Tensor): + width, height = _FT.get_image_size(image) + return height, width + if isinstance(image, PIL.Image.Image): + width, height = _FP.get_image_size(image) + return height, width else: raise TypeError(f"unable to get image size from object of type {type(image).__name__}") def get_image_num_channels(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> int: - if type(image) is torch.Tensor or isinstance(image, PIL.Image.Image): - return _F.get_image_num_channels(image) - elif type(image) is features.Image: + if isinstance(image, features.Image): return image.num_channels + elif isinstance(image, torch.Tensor): + return _FT.get_image_num_channels(image) + if isinstance(image, PIL.Image.Image): + return cast(int, _FP.get_image_num_channels(image)) else: raise TypeError(f"unable to get num channels from object of type {type(image).__name__}") From 0017807939bc5f80db5d5963fd151bab7e360086 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Feb 2022 11:17:18 +0100 Subject: [PATCH 23/25] fix affine and rotate --- .../transforms/functional/_geometry.py | 104 +++++++++++++++++- 1 file changed, 98 insertions(+), 6 
deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 4a8eb2b0b53..ec3cec73165 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -7,7 +7,7 @@ from torchvision.prototype.transforms import InterpolationMode from torchvision.prototype.transforms.functional import get_image_size from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -from torchvision.transforms.functional import pil_modes_mapping +from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix from ._meta_conversion import convert_bounding_box_format @@ -79,31 +79,120 @@ def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: vertical_flip_image_pil = _FP.vflip +def _affine_parse_args( + angle: float, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + center: Optional[List[float]] = None, +) -> Tuple[float, List[float], List[float], Optional[List[float]]]: + if not isinstance(angle, (int, float)): + raise TypeError("Argument angle should be int or float") + + if not isinstance(translate, (list, tuple)): + raise TypeError("Argument translate should be a sequence") + + if len(translate) != 2: + raise ValueError("Argument translate should be a sequence of length 2") + + if scale <= 0.0: + raise ValueError("Argument scale should be positive") + + if not isinstance(shear, (numbers.Number, (list, tuple))): + raise TypeError("Shear should be either a single value or a sequence of two values") + + if not isinstance(interpolation, InterpolationMode): + raise TypeError("Argument interpolation should be a InterpolationMode") + + if isinstance(angle, int): + angle = float(angle) + + if isinstance(translate, tuple): + translate = list(translate) + + if isinstance(shear, numbers.Number): + shear = [shear, 0.0] + + if isinstance(shear, tuple): + shear = list(shear) + + if len(shear) == 1: + shear = [shear[0], shear[0]] + + if len(shear) != 2: + raise ValueError(f"Shear should be a sequence containing two values. Got {shear}") + + if center is not None and not isinstance(center, (list, tuple)): + raise TypeError("Argument center should be a sequence") + + return angle, translate, shear, center + + def affine_image_tensor( img: torch.Tensor, - matrix: List[float], + angle: float, + translate: List[float], + scale: float, + shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, ) -> torch.Tensor: + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + center_f = [0.0, 0.0] + if center is not None: + height, width = get_image_size(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
+ center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + + translate_f = [1.0 * t for t in translate] + matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear) + return _FT.affine(img, matrix, interpolation=interpolation.value, fill=fill) def affine_image_pil( img: PIL.Image.Image, - matrix: List[float], + angle: float, + translate: List[float], + scale: float, + shear: List[float], interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, ) -> PIL.Image.Image: - return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill) + angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center) + + # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5) + # it is visually better to estimate the center without 0.5 offset + # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine + if center is None: + height, width = get_image_size(img) + center = [width * 0.5, height * 0.5] + matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) + + return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill, center=center) def rotate_image_tensor( img: torch.Tensor, - matrix: List[float], + angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, ) -> torch.Tensor: + center_f = [0.0, 0.0] + if center is not None: + height, width = get_image_size(img) + # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. + center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] + + # due to current incoherence of rotation angle direction between affine and rotate implementations + # we need to set -angle. 
+ matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) return _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill) @@ -113,8 +202,11 @@ def rotate_image_pil( interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, ) -> PIL.Image.Image: - return _FP.rotate(img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill) + return _FP.rotate( + img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center + ) pad_image_tensor = _FT.pad From 71e4c56bd09cf6cf6ba99190b2a26a3856f656e3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Feb 2022 11:55:16 +0100 Subject: [PATCH 24/25] revert image size to (width, height) --- torchvision/prototype/transforms/_augment.py | 4 ++-- .../prototype/transforms/_auto_augment.py | 8 ++++---- torchvision/prototype/transforms/_geometry.py | 2 +- .../prototype/transforms/functional/_geometry.py | 16 ++++++++-------- .../prototype/transforms/functional/_utils.py | 9 ++++----- 5 files changed, 19 insertions(+), 20 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index dd3ecd4ad47..a0a0062ddf4 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -42,7 +42,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) img_c = F.get_image_num_channels(image) - img_h, img_w = F.get_image_size(image) + img_w, img_h = F.get_image_size(image) if isinstance(self.value, (int, float)): value = [self.value] @@ -138,7 +138,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) image = query_image(sample) - H, W = F.get_image_size(image) + W, H = F.get_image_size(image) r_x = torch.randint(W, ()) r_y = torch.randint(H, ()) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 32e862d34b0..7eae25a681e 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase): _AUGMENTATION_SPACE = { "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), @@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase): "Identity": (lambda num_bins, image_size: None, False), "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True), - "TranslateX": (lambda num_bins, image_size: 
torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), - "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True), "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True), diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 6115a4a8ebd..9032f408956 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -109,7 +109,7 @@ def __init__( def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) - height, width = F.get_image_size(image) + width, height = F.get_image_size(image) area = height * width log_ratio = torch.log(torch.tensor(self.ratio)) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ec3cec73165..d4214f791b3 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -40,7 +40,7 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: new_height, new_width = size - old_height, old_width = _FT.get_image_size(image) + old_width, old_height = _FT.get_image_size(image) num_channels = _FT.get_image_num_channels(image) batch_shape = image.shape[:-3] return _FT.resize( @@ -143,7 +143,7 @@ def affine_image_tensor( center_f = [0.0, 0.0] if center is not None: - height, width = get_image_size(img) + width, height = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -169,7 +169,7 @@ def affine_image_pil( # it is visually better to estimate the center without 0.5 offset # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine if center is None: - height, width = get_image_size(img) + width, height = get_image_size(img) center = [width * 0.5, height * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) @@ -186,7 +186,7 @@ def rotate_image_tensor( ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: - height, width = get_image_size(img) + width, height = get_image_size(img) # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. 
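+        # e.g. center = [width * 0.5, height * 0.5] maps to center_f = [0.0, 0.0].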
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))] @@ -262,13 +262,13 @@ def _center_crop_compute_crop_anchor( def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_height, image_width = get_image_size(img) + image_width, image_height = get_image_size(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_tensor(img, padding_ltrb, fill=0) - image_height, image_width = get_image_size(img) + image_width, image_height = get_image_size(img) if crop_width == image_width and crop_height == image_height: return img @@ -278,13 +278,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image: crop_height, crop_width = _center_crop_parse_output_size(output_size) - image_height, image_width = get_image_size(img) + image_width, image_height = get_image_size(img) if crop_height > image_height or crop_width > image_width: padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width) img = pad_image_pil(img, padding_ltrb, fill=0) - image_height, image_width = get_image_size(img) + image_width, image_height = get_image_size(img) if crop_width == image_width and crop_height == image_height: return img diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 25fa0cd0d52..07235d63716 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -8,13 +8,12 @@ def get_image_size(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int]: if isinstance(image, features.Image): - return image.image_size + height, width = image.image_size + return width, height elif isinstance(image, torch.Tensor): - width, height = _FT.get_image_size(image) - return height, width + return cast(Tuple[int, int], tuple(_FT.get_image_size(image))) if isinstance(image, PIL.Image.Image): - width, height = _FP.get_image_size(image) - return height, width + return cast(Tuple[int, int], tuple(_FP.get_image_size(image))) else: raise TypeError(f"unable to get image size from object of type {type(image).__name__}") From 0943de0e2b1dee75b4192468475f8b15a14f2cbd Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 25 Feb 2022 12:06:35 +0000 Subject: [PATCH 25/25] Minor corrections --- torchvision/prototype/transforms/_augment.py | 2 +- torchvision/prototype/transforms/_geometry.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index a0a0062ddf4..ce198d39b33 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -55,7 +55,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if value is not None and not (len(value) in (1, img_c)): raise ValueError( - "If value is a sequence, it should have either a single value or " f"{img_c} (number of input channels)" + f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)" ) area = img_h * img_w diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 
9032f408956..4c9d9192ac8 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -34,7 +34,7 @@ def __init__( interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self.size = [size, size] if isinstance(size, int) else list(size) + self.size = [size] if isinstance(size, int) else list(size) self.interpolation = interpolation def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
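
For reference, a minimal standalone sketch of the (width, height) convention that the last two patches settle on. The image_size helper below is a hypothetical stand-in for the prototype get_image_size (it omits the features.Image branch), and the demo mirrors the offset-free center computation used by the affine/rotate kernels above; it assumes only torch and PIL are available and is not part of the patched code itself.

    from typing import Tuple, Union

    import PIL.Image
    import torch


    def image_size(image: Union[PIL.Image.Image, torch.Tensor]) -> Tuple[int, int]:
        """Return (width, height), matching the convention of PIL.Image.size."""
        if isinstance(image, PIL.Image.Image):
            return image.size
        elif isinstance(image, torch.Tensor):
            # Tensor images are laid out as (..., C, H, W).
            height, width = image.shape[-2:]
            return width, height
        else:
            raise TypeError(f"unable to get image size from object of type {type(image).__name__}")


    if __name__ == "__main__":
        img = torch.rand(3, 32, 48)  # C=3, H=32, W=48
        width, height = image_size(img)       # -> (48, 32)
        center = [width * 0.5, height * 0.5]  # offset-free center, as in affine_image_pil
        # translated so that (0, 0) corresponds to the image center, as in rotate_image_tensor
        center_f = [c - s * 0.5 for c, s in zip(center, (width, height))]
        print(width, height, center, center_f)  # 48 32 [24.0, 16.0] [0.0, 0.0]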