diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_kernels.py similarity index 96% rename from test/test_prototype_transforms_functional.py rename to test/test_prototype_transforms_kernels.py index 53776e1a8a4..249ee76e6bc 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_kernels.py @@ -3,7 +3,7 @@ import pytest import torch.testing -import torchvision.prototype.transforms.functional as F +import torchvision.prototype.transforms.kernels as K from torch import jit from torchvision.prototype import features @@ -115,7 +115,7 @@ def __init__(self, *args, **kwargs): class KernelInfo: def __init__(self, name, *, sample_inputs_fn): self.name = name - self.kernel = getattr(F, name) + self.kernel = getattr(K, name) self._sample_inputs_fn = sample_inputs_fn def sample_inputs(self): @@ -146,7 +146,7 @@ def horizontal_flip_image(): @register_kernel_info_from_sample_inputs_fn def horizontal_flip_bounding_box(): for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): - yield SampleInput(bounding_box, image_size=bounding_box.image_size) + yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) @register_kernel_info_from_sample_inputs_fn @@ -154,8 +154,8 @@ def resize_image(): for image, interpolation in itertools.product( make_images(), [ - F.InterpolationMode.BILINEAR, - F.InterpolationMode.NEAREST, + K.InterpolationMode.BILINEAR, + K.InterpolationMode.NEAREST, ], ): height, width = image.shape[-2:] diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index fbf4522be93..1ffd1fb84dc 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -40,7 +40,7 @@ def __new__( def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.functional import convert_bounding_box_format + from torchvision.prototype.transforms.kernels import convert_bounding_box_format if isinstance(format, str): format = BoundingBoxFormat[format] diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index ed2ede62921..ea8bdeae32e 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -40,7 +40,7 @@ def image_size(self) -> Tuple[int, int]: def decode(self) -> Image: # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.functional import decode_image_with_pil + from torchvision.prototype.transforms.kernels import decode_image_with_pil return Image(decode_image_with_pil(self)) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index c44a5fdb4b6..d6d4df8486e 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,7 +1,7 @@ -from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable +from typing import Any, cast, Dict, Set, TypeVar, Union, Optional, Type, Callable, Tuple, Sequence, Mapping import torch -from torch._C import _TensorBase +from torch._C import _TensorBase, DisableTorchFunction F = TypeVar("F", bound="Feature") @@ -76,5 +76,45 @@ def new_like( _metadata.update(metadata) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **_metadata) + @classmethod + def __torch_function__( + cls, + func: 
Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + """For general information about how the __torch_function__ protocol works, + see https://pytorch.org/docs/stable/notes/extending.html#extending-torch + + TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the + ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the + ``args`` and ``kwargs`` of the original call. + + The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature` + use case, this has two downsides: + + 1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. + ``return cls(func(*args, **kwargs))``, will fail for them. + 2. For most operations, there is no way of knowing if the input type is still valid for the output. + + For these reasons, the automatic output wrapping is turned off for most operators. + + Exceptions to this are: + + - :func:`torch.clone` + - :meth:`torch.Tensor.to` + """ + kwargs = kwargs or dict() + with DisableTorchFunction(): + output = func(*args, **kwargs) + + if func is torch.Tensor.clone: + return cls.new_like(args[0], output) + elif func is torch.Tensor.to: + return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) + else: + return output + def __repr__(self) -> str: return cast(str, torch.Tensor.__repr__(self)).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 1fe3d010b28..c9988be1930 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,4 +1,5 @@ -from . import functional -from .functional import InterpolationMode # usort: skip +from . import kernels # usort: skip +from . 
import functional # usort: skip +from .kernels import InterpolationMode # usort: skip from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 087f2fb2ac0..9f05f16df2d 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,27 +1,14 @@ -from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label +from ._augment import erase, mixup, cutmix from ._color import ( - adjust_brightness_image, - adjust_contrast_image, - adjust_saturation_image, - adjust_sharpness_image, - posterize_image, - solarize_image, - autocontrast_image, - equalize_image, - invert_image, + adjust_brightness, + adjust_contrast, + adjust_saturation, + adjust_sharpness, + posterize, + solarize, + autocontrast, + equalize, + invert, ) -from ._geometry import ( - horizontal_flip_bounding_box, - horizontal_flip_image, - resize_bounding_box, - resize_image, - resize_segmentation_mask, - center_crop_image, - resized_crop_image, - InterpolationMode, - affine_image, - rotate_image, -) -from ._meta_conversion import convert_color_space, convert_bounding_box_format -from ._misc import normalize_image -from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot +from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate +from ._misc import normalize diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 842ff0cd5d6..bbae796c1c9 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,52 +1,57 @@ -from typing import Tuple +from typing import TypeVar, Any import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F +from ._utils import dispatch -erase_image = _F.erase +T = TypeVar("T", bound=features.Feature) -def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: - if not inplace: - input = input.clone() +@dispatch( + { + torch.Tensor: _F.erase, + features.Image: K.erase_image, + } +) +def erase(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... - input_rolled = input.roll(1, batch_dim) - return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) +@dispatch( + { + features.Image: K.mixup_image, + features.OneHotLabel: K.mixup_one_hot_label, + } +) +def mixup(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - return _mixup(image_batch, -4, lam, inplace) +@dispatch( + { + features.Image: K.cutmix_image, + features.OneHotLabel: K.cutmix_one_hot_label, + } +) +def cutmix(input: T, *args: Any, **kwargs: Any) -> T: + """Perform the CutMix operation as introduced in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. 
+ Dispatch to the corresponding kernels happens according to this table: -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") + .. table:: + :widths: 30 70 - return _mixup(one_hot_label_batch, -2, lam, inplace) + ==================================================== ================================================================ + :class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image` + :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label` + ==================================================== ================================================================ - -def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - if not inplace: - image_batch = image_batch.clone() - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label( - one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False -) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) + Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. + """ + ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f2529166d9a..479b55a1b03 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,20 +1,119 @@ +from typing import TypeVar, Any + +import PIL.Image +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F +from ._utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +@dispatch( + { + torch.Tensor: _F.adjust_brightness, + PIL.Image.Image: _F.adjust_brightness, + features.Image: K.adjust_brightness_image, + } +) +def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_saturation, + PIL.Image.Image: _F.adjust_saturation, + features.Image: K.adjust_saturation_image, + } +) +def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_contrast, + PIL.Image.Image: _F.adjust_contrast, + features.Image: K.adjust_contrast_image, + } +) +def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + + +@dispatch( + { + torch.Tensor: _F.adjust_sharpness, + PIL.Image.Image: _F.adjust_sharpness, + features.Image: K.adjust_sharpness_image, + } +) +def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... + -adjust_brightness_image = _F.adjust_brightness +@dispatch( + { + torch.Tensor: _F.posterize, + PIL.Image.Image: _F.posterize, + features.Image: K.posterize_image, + } +) +def posterize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... 
-adjust_saturation_image = _F.adjust_saturation -adjust_contrast_image = _F.adjust_contrast +@dispatch( + { + torch.Tensor: _F.solarize, + PIL.Image.Image: _F.solarize, + features.Image: K.solarize_image, + } +) +def solarize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -adjust_sharpness_image = _F.adjust_sharpness -posterize_image = _F.posterize +@dispatch( + { + torch.Tensor: _F.autocontrast, + PIL.Image.Image: _F.autocontrast, + features.Image: K.autocontrast_image, + } +) +def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -solarize_image = _F.solarize -autocontrast_image = _F.autocontrast +@dispatch( + { + torch.Tensor: _F.equalize, + PIL.Image.Image: _F.equalize, + features.Image: K.equalize_image, + } +) +def equalize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -equalize_image = _F.equalize -invert_image = _F.invert +@dispatch( + { + torch.Tensor: _F.invert, + PIL.Image.Image: _F.invert, + features.Image: K.invert_image, + } +) +def invert(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index b9d396c2058..2f9f0f76e39 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,72 +1,95 @@ -from typing import Tuple, List, Optional +from typing import TypeVar, Any, cast +import PIL.Image import torch -from torchvision.transforms import ( # noqa: F401 - functional as _F, - InterpolationMode, -) - -horizontal_flip_image = _F.hflip - - -def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor: - shape = bounding_box.shape - bounding_box = bounding_box.view(-1, 4) - bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [0, 2]] - return bounding_box.view(shape) - - -_resize_image = _F.resize - +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F -def resize_image( - image: torch.Tensor, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> torch.Tensor: - new_height, new_width = size - num_channels, old_height, old_width = image.shape[-3:] - batch_shape = image.shape[:-3] - return _resize_image( - image.reshape((-1, num_channels, old_height, old_width)), - size=size, - interpolation=interpolation, - max_size=max_size, - antialias=antialias, - ).reshape(batch_shape + (num_channels, new_height, new_width)) +from ._utils import dispatch +T = TypeVar("T", bound=features.Feature) -def resize_segmentation_mask( - segmentation_mask: torch.Tensor, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> torch.Tensor: - return resize_image( - segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias - ) +@dispatch( + { + torch.Tensor: _F.hflip, + PIL.Image.Image: _F.hflip, + features.Image: K.horizontal_flip_image, + features.BoundingBox: None, + }, +) +def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + if isinstance(input, features.BoundingBox): + output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + return cast(T, 
features.BoundingBox.new_like(input, output)) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: _F.resize, + PIL.Image.Image: _F.resize, + features.Image: K.resize_image, + features.SegmentationMask: K.resize_segmentation_mask, + features.BoundingBox: None, + } +) +def resize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + if isinstance(input, features.BoundingBox): + size = kwargs.pop("size") + output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) + return cast(T, features.BoundingBox.new_like(input, output, image_size=size)) + + raise RuntimeError + + +@dispatch( + { + torch.Tensor: _F.center_crop, + PIL.Image.Image: _F.center_crop, + features.Image: K.center_crop_image, + } +) +def center_crop(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -# TODO: handle max_size -def resize_bounding_box( - bounding_box: torch.Tensor, - *, - old_image_size: List[int], - new_image_size: List[int], -) -> torch.Tensor: - old_height, old_width = old_image_size - new_height, new_width = new_image_size - ratios = torch.tensor((new_width / old_width, new_height / old_height)) - return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) +@dispatch( + { + torch.Tensor: _F.resized_crop, + PIL.Image.Image: _F.resized_crop, + features.Image: K.resized_crop_image, + } +) +def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -center_crop_image = _F.center_crop -resized_crop_image = _F.resized_crop +@dispatch( + { + torch.Tensor: _F.affine, + PIL.Image.Image: _F.affine, + features.Image: K.affine_image, + } +) +def affine(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... -affine_image = _F.affine -rotate_image = _F.rotate +@dispatch( + { + torch.Tensor: _F.rotate, + PIL.Image.Image: _F.rotate, + features.Image: K.rotate_image, + } +) +def rotate(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index de148ab194a..40fc5894f03 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,4 +1,21 @@ +from typing import TypeVar, Any + +import torch +from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F +from ._utils import dispatch + +T = TypeVar("T", bound=features.Feature) + -normalize_image = _F.normalize +@dispatch( + { + torch.Tensor: _F.normalize, + features.Image: K.normalize_image, + } +) +def normalize(input: T, *args: Any, **kwargs: Any) -> T: + """ADDME""" + ... diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py new file mode 100644 index 00000000000..eb44b3421bf --- /dev/null +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -0,0 +1,89 @@ +import functools +import inspect +from typing import Any, Optional, Callable, TypeVar, Dict + +import torch +import torch.overrides +from torchvision.prototype import features + +F = TypeVar("F", bound=features.Feature) + + +def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]: + """Decorates a function to automatically dispatch to registered kernels based on the call arguments. + + The dispatch function should have this signature + + .. code:: python + + @dispatch( + ... 
+ ) + def dispatch_fn(input, *args, **kwargs): + ... + + where ``input`` is used to determine which kernel to dispatch to. + + Args: + kernels: Dictionary with types as keys that maps to a kernel to call. The resolution order is checking for + exact type matches first and if none is found falls back to checking for subclasses. If a value is + ``None``, the decorated function is called. + + Raises: + TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``. + TypeError: If the decorated function is called with an input that cannot be dispatched. + """ + + def check_kernel(kernel: Any) -> bool: + if kernel is None: + return True + + if not callable(kernel): + return False + + params = list(inspect.signature(kernel).parameters.values()) + if not params: + return False + + return params[0].kind != inspect.Parameter.KEYWORD_ONLY + + for feature_type, kernel in kernels.items(): + if not check_kernel(kernel): + raise TypeError( + f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)." + ) + + def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: + @functools.wraps(dispatch_fn) + def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: + feature_type = type(input) + try: + kernel = kernels[feature_type] + except KeyError: + try: + feature_type, kernel = next( + (feature_type, kernel) + for feature_type, kernel in kernels.items() + if isinstance(input, feature_type) + ) + except StopIteration: + raise TypeError(f"No support for {type(input).__name__}") from None + + if kernel is None: + output = dispatch_fn(input, *args, **kwargs) + if output is None: + raise RuntimeError( + f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} " + f"although it was configured to do so." 
+ ) + else: + output = kernel(input, *args, **kwargs) + + if issubclass(feature_type, features.Feature) and type(output) is torch.Tensor: + output = feature_type.new_like(input, output) + + return output + + return inner_wrapper + + return outer_wrapper diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/kernels/__init__.py new file mode 100644 index 00000000000..6f74f6af0e9 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/__init__.py @@ -0,0 +1,34 @@ +from torchvision.transforms import InterpolationMode # usort: skip +from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip + +from ._augment import ( + erase_image, + mixup_image, + mixup_one_hot_label, + cutmix_image, + cutmix_one_hot_label, +) +from ._color import ( + adjust_brightness_image, + adjust_contrast_image, + adjust_saturation_image, + adjust_sharpness_image, + posterize_image, + solarize_image, + autocontrast_image, + equalize_image, + invert_image, +) +from ._geometry import ( + horizontal_flip_bounding_box, + horizontal_flip_image, + resize_bounding_box, + resize_image, + resize_segmentation_mask, + center_crop_image, + resized_crop_image, + affine_image, + rotate_image, +) +from ._misc import normalize_image +from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py new file mode 100644 index 00000000000..526ed85ffd8 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_augment.py @@ -0,0 +1,45 @@ +from typing import Tuple + +import torch +from torchvision.transforms import functional as _F + + +erase_image = _F.erase + + +def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: + input = input.clone() + return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) + + +def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + return _mixup(image_batch, -4, lam) + + +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam) + + +def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + x1, y1, x2, y2 = box + image_rolled = image_batch.roll(1, -4) + + image_batch = image_batch.clone() + image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image_batch + + +def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/kernels/_color.py new file mode 100644 index 00000000000..0d828e6d169 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_color.py @@ -0,0 +1,12 @@ +from torchvision.transforms import functional as _F + + +adjust_brightness_image = _F.adjust_brightness +adjust_saturation_image = _F.adjust_saturation +adjust_contrast_image = _F.adjust_contrast +adjust_sharpness_image = _F.adjust_sharpness +posterize_image = _F.posterize +solarize_image = _F.solarize 
+autocontrast_image = _F.autocontrast +equalize_image = _F.equalize +invert_image = _F.invert diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py new file mode 100644 index 00000000000..c3cbbb34b02 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -0,0 +1,76 @@ +from typing import Tuple, List, Optional, TypeVar + +import torch +from torchvision.prototype import features +from torchvision.transforms import functional as _F, InterpolationMode + +from ._meta_conversion import convert_bounding_box_format + + +T = TypeVar("T", bound=features.Feature) + + +horizontal_flip_image = _F.hflip + + +def horizontal_flip_bounding_box( + bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]] + + return convert_bounding_box_format( + bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format + ).view(shape) + + +def resize_image( + image: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + new_height, new_width = size + num_channels, old_height, old_width = image.shape[-3:] + batch_shape = image.shape[:-3] + return _F.resize( + image.reshape((-1, num_channels, old_height, old_width)), + size=size, + interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ).reshape(batch_shape + (num_channels, new_height, new_width)) + + +def resize_segmentation_mask( + segmentation_mask: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + return resize_image( + segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + + +# TODO: handle max_size +def resize_bounding_box( + bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int] +) -> torch.Tensor: + old_height, old_width = old_image_size + new_height, new_width = new_image_size + ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) + return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) + + +center_crop_image = _F.center_crop +resized_crop_image = _F.resized_crop +affine_image = _F.affine +rotate_image = _F.rotate diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/kernels/_meta_conversion.py similarity index 96% rename from torchvision/prototype/transforms/functional/_meta_conversion.py rename to torchvision/prototype/transforms/kernels/_meta_conversion.py index 484066a39ee..4acaf9fe9e4 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/kernels/_meta_conversion.py @@ -37,7 +37,7 @@ def convert_bounding_box_format( bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat ) -> torch.Tensor: if new_format == old_format: - return bounding_box + return bounding_box.clone() if old_format == BoundingBoxFormat.XYWH: bounding_box = _xywh_to_xyxy(bounding_box) @@ -58,7 +58,7 @@ def 
_grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor: if new_color_space == old_color_space: - return image + return image.clone() if old_color_space == ColorSpace.GRAYSCALE: image = _grayscale_to_rgb(image) diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/kernels/_misc.py new file mode 100644 index 00000000000..de148ab194a --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_misc.py @@ -0,0 +1,4 @@ +from torchvision.transforms import functional as _F + + +normalize_image = _F.normalize diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/kernels/_type_conversion.py similarity index 100% rename from torchvision/prototype/transforms/functional/_type_conversion.py rename to torchvision/prototype/transforms/kernels/_type_conversion.py
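
The new two-layer API introduced above (low-level kernels in ``torchvision.prototype.transforms.kernels``, thin type-dispatching wrappers in ``torchvision.prototype.transforms.functional``) can be exercised roughly as follows. This is a minimal sketch, not part of the diff: the exact ``Image`` and ``BoundingBox`` constructor signatures are assumptions inferred from the metadata the dispatchers access above (``format``, ``image_size``).

.. code:: python

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F, kernels as K

    image = features.Image(torch.rand(3, 32, 32))
    # Constructor kwargs are assumed from the metadata used by the dispatchers above.
    bounding_box = features.BoundingBox(
        torch.tensor([[2.0, 4.0, 10.0, 20.0]]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(32, 32),
    )

    # The dispatcher looks up the kernel by input type and re-wraps plain tensor
    # outputs via Feature.new_like, so the feature type survives the call.
    flipped_image = F.horizontal_flip(image)       # routed to K.horizontal_flip_image
    flipped_box = F.horizontal_flip(bounding_box)  # kernel entry is None, handled in the dispatcher body

    # Kernels remain directly usable for plain-tensor code paths.
    flipped_plain = K.horizontal_flip_image(torch.rand(3, 32, 32))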
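
The ``mixup`` and ``cutmix`` dispatchers only register the new feature types (no plain-tensor or PIL entries), and the underlying kernels expect batched inputs. A hedged sketch along the same lines, again assuming the feature classes accept a plain tensor directly:

.. code:: python

    image_batch = features.Image(torch.rand(4, 3, 32, 32))
    one_hot_batch = features.OneHotLabel(torch.eye(10)[torch.randint(10, (4,))])

    # Each modality is dispatched to its own kernel; the parameters are the
    # keyword-only arguments of the respective kernels.
    mixed_images = F.mixup(image_batch, lam=0.7)            # K.mixup_image
    mixed_labels = F.mixup(one_hot_batch, lam=0.7)          # K.mixup_one_hot_label

    cut_images = F.cutmix(image_batch, box=(4, 4, 20, 20))  # K.cutmix_image
    cut_labels = F.cutmix(one_hot_batch, lam_adjusted=0.6)  # K.cutmix_one_hot_label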
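
Finally, the ``__torch_function__`` override on ``Feature`` only re-wraps the outputs of ``Tensor.clone`` and ``Tensor.to``; for every other operator the automatic wrapping is intentionally turned off, since neither the metadata nor the semantic validity of the output can be inferred. A small sketch of the expected behavior, under the same constructor assumption:

.. code:: python

    image = features.Image(torch.rand(3, 8, 8))

    # clone() and to() are re-wrapped explicitly, so the Feature type is kept.
    print(type(image.clone()))            # expected: features.Image
    print(type(image.to(torch.float64)))  # expected: features.Image

    # All other operators fall back to a plain tensor.
    print(type(image + 1))                # expected: torch.Tensor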