From 3f6982e0048a18293c3eb4b20f4475586137e105 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 25 Jan 2022 17:01:24 +0100 Subject: [PATCH 01/32] add auto dispatch --- scripts/regenerate_transform_dispatch.py | 578 ++++++++++++++++++ torchvision/prototype/features/_feature.py | 48 +- .../transforms/functional/__init__.py | 4 + .../transforms/functional/_dispatch.py | 448 ++++++++++++++ .../transforms/functional/_geometry.py | 5 +- .../transforms/functional/dispatch.yaml | 164 +++++ .../prototype/transforms/functional/utils.py | 72 +++ 7 files changed, 1313 insertions(+), 6 deletions(-) create mode 100644 scripts/regenerate_transform_dispatch.py create mode 100644 torchvision/prototype/transforms/functional/_dispatch.py create mode 100644 torchvision/prototype/transforms/functional/dispatch.yaml create mode 100644 torchvision/prototype/transforms/functional/utils.py diff --git a/scripts/regenerate_transform_dispatch.py b/scripts/regenerate_transform_dispatch.py new file mode 100644 index 00000000000..edc3bcf01b0 --- /dev/null +++ b/scripts/regenerate_transform_dispatch.py @@ -0,0 +1,578 @@ +import contextlib +import enum +import importlib +import inspect +import pathlib +import re +import sys +import typing +import warnings +from copy import copy +from typing import Any + +import torchvision.prototype.transforms.functional as F +from torchvision import transforms +from torchvision.prototype import features + +try: + import yaml +except ModuleNotFoundError: + raise ModuleNotFoundError() + + +ENUMS = [ + (features, ["BoundingBoxFormat", "ColorSpace"]), + (transforms, ["InterpolationMode"]), +] + +ENUMS_MAP = {name: getattr(module, name) for module, names in ENUMS for name in names} + +META_CONVERTER_MAP = { + (features.Image, "color_space"): F.convert_color_space, + (features.BoundingBox, "format"): F.convert_bounding_box_format, +} + + +class ManualAnnotation: + def __init__(self, repr): + self.repr = repr + + def __repr__(self): + return self.repr + + def __eq__(self, other): + if not isinstance(other, ManualAnnotation): + return NotImplemented + + return self.repr == other.repr + + +# TODO: typing module +FEATURE_SPECIFIC_PARAM = ManualAnnotation("Dispatcher.FEATURE_SPECIFIC_PARAM") +FEATURE_SPECIFIC_DEFAULT = ManualAnnotation("FEATURE_SPECIFIC_DEFAULT") +GENERIC_FEATURE_TYPE = ManualAnnotation("T") + + +def main(dispatch_config): + functions = [] + for dispatcher_name, feature_type_configs in dispatch_config.items(): + try: + feature_type_configs = validate_feature_type_configs(feature_type_configs) + kernel_params, implementer_params = make_kernel_and_implementer_params(feature_type_configs) + dispatcher_params = make_dispatcher_params(implementer_params) + except Exception as error: + raise RuntimeError( + f"while working on dispatcher '{dispatcher_name}' the following error was raised:\n\n" + f"{type(error).__name__}: {error}" + ) from None + + functions.append(DispatcherFunction(name=dispatcher_name, params=dispatcher_params)) + functions.extend( + [ + ImplementerFunction( + dispatcher_name=dispatcher_name, + feature_type=feature_type, + params=implementer_params[feature_type], + pil_kernel=config.get("pil_kernel"), + kernel=config["kernel"], + kernel_params=kernel_params[feature_type], + conversion_map=config["meta_conversion"], + kernel_param_name_map=config["kwargs_overwrite"], + meta_overwrite=config["meta_overwrite"], + ) + for feature_type, config in feature_type_configs.items() + ] + ) + + return ufmt_format(make_file_content(functions)) + + +def 
validate_feature_type_configs(feature_type_configs): + try: + feature_type_configs = { + getattr(features, feature_type_name): config for feature_type_name, config in feature_type_configs.items() + } + except AttributeError: + # unknown feature type + raise TypeError() from None + + for feature_type, config in tuple(feature_type_configs.items()): + if not isinstance(config, dict): + feature_type_configs[feature_type] = config = dict(kernel=config) + + unknown_keys = config.keys() - { + "kernel", + "pil_kernel", + "meta_conversion", + "kwargs_overwrite", + "meta_overwrite", + } + if unknown_keys: + raise KeyError(unknown_keys) + + try: + config["kernel"] = getattr(F, config["kernel"]) + except KeyError: + # no kernel provided + raise + except AttributeError: + # kernel not accessible + raise + + if "pil_kernel" in config and feature_type is not features.Image: + raise TypeError + + for key in ["meta_conversion", "kwargs_overwrite", "meta_overwrite"]: + if key not in config: + config[key] = dict() + continue + + for meta_attr, value in tuple(config[key].items()): + # if meta_attr not in feature_type._META_ATTRS: + # raise KeyError(meta_attr) + + config[key][meta_attr] = maybe_convert_to_enum(value) + + # TODO: bunchify the individual configs + return feature_type_configs + + +def make_kernel_and_implementer_params(feature_type_configs): + kernel_params = {} + implementer_params = {} + for feature_type, config in feature_type_configs.items(): + kernel_params[feature_type] = [ + Parameter.from_regular(param) for param in list(inspect.signature(config["kernel"]).parameters.values())[1:] + ] + implementer_params[feature_type] = [ + Parameter( + name=config["kwargs_overwrite"].get(kernel_param.name, kernel_param.name), + kind=inspect.Parameter.KEYWORD_ONLY, + default=kernel_param.default, + annotation=kernel_param.annotation, + ) + for kernel_param in kernel_params[feature_type] + if not config["kwargs_overwrite"].get(kernel_param.name, "").startswith(".") + ] + return kernel_params, implementer_params + + +def make_dispatcher_params(implementer_params): + # not using a set here to keep the order + dispatcher_param_names = [] + for params in implementer_params.values(): + dispatcher_param_names.extend([param.name for param in params]) + dispatcher_param_names = unique(dispatcher_param_names) + + dispatcher_params = [] + need_kwargs_ignore = set() + for name in dispatcher_param_names: + dispatcher_param_candidates = {} + for feature_type, params in implementer_params.items(): + params = {param.name: param for param in params} + if name not in params: + need_kwargs_ignore.add(feature_type) + continue + else: + dispatcher_param_candidates[feature_type] = params[name] + + if len(dispatcher_param_candidates) == 1: + param = next(iter(dispatcher_param_candidates.values())) + if len(implementer_params) == 1: + dispatcher_params.append(copy(param)) + else: + param._default = FEATURE_SPECIFIC_PARAM + dispatcher_params.append( + Parameter( + name=name, + kind=Parameter.KEYWORD_ONLY, + default=FEATURE_SPECIFIC_PARAM, + annotation=param.annotation, + ) + ) + continue + + annotations = {param.annotation for param in dispatcher_param_candidates.values()} + if len(annotations) > 1: + raise TypeError( + f"Found multiple annotations for parameter `{name}`: " + f"{', '.join([str(annotation) for annotation in annotations])}" + ) + + defaults = {param.default for param in dispatcher_param_candidates.values()} + default = FEATURE_SPECIFIC_DEFAULT if len(defaults) > 1 else defaults.pop() + + dispatcher_params.append( 
+ Parameter( + name=name, + kind=Parameter.KEYWORD_ONLY, + default=default, + annotation=annotations.pop(), + ) + ) + + without_default = [] + with_default = [] + for param in dispatcher_params: + (without_default if param.default in (Parameter.empty, FEATURE_SPECIFIC_PARAM) else with_default).append(param) + dispatcher_params = [*without_default, *with_default] + + for feature_type in need_kwargs_ignore: + implementer_params[feature_type].append(Parameter(name="_", kind=Parameter.VAR_KEYWORD, annotation=Any)) + + return dispatcher_params + + +def make_file_content(functions): + enums = "\n".join(f"from {module.__package__} import {', '.join(names)}" for module, names in ENUMS) + + header = f""" +# THIS FILE IS AUTOGENERATED +# +# FROM torchvision/prototype/transforms/functional/dispatch.yaml +# WITH scripts/regenerate_transforms_dispatch.py +# +# DO NOT CHANGE MANUALLY! + +from typing import Any, TypeVar, List, Optional, Tuple + +import torch +import torchvision.transforms.functional as _F +import torchvision.prototype.transforms.functional as F +from torchvision.prototype import features +{enums} + +Dispatcher = F.utils.Dispatcher + +# This is just a sentinel to have a default argument for a dispatcher if the feature specific implementations use +# different defaults. The actual value is never used. +{FEATURE_SPECIFIC_DEFAULT} = object() + +{GENERIC_FEATURE_TYPE} = TypeVar("{GENERIC_FEATURE_TYPE}", bound=features.Feature) +""" + header = "\n".join(line.strip() for line in header.splitlines()) + + __all__ = "\n".join( + ( + "__all__ = [", + *[ + indent(f"{format_value(function.name)},") + for function in functions + if isinstance(function, DispatcherFunction) + ], + "]", + ) + ) + return ( + "\n\n\n".join( + ( + header, + __all__, + *[str(function) for function in functions], + ) + ) + + "\n" + ) + + +class Parameter(inspect.Parameter): + @classmethod + def from_regular(cls, param): + return cls(param.name, param.kind, default=param.default, annotation=param.annotation) + + def __str__(self): + @contextlib.contextmanager + def tmp_override(**tmp_values): + values = {name: getattr(self, name) for name in tmp_values} + for name, tmp_value in tmp_values.items(): + setattr(self, f"_{name}", tmp_value) + try: + yield + finally: + for name, value in values.items(): + setattr(self, f"_{name}", value) + + tmp_values = dict() + + if isinstance(self.default, enum.Enum): + tmp_values["default"] = ManualAnnotation(format_value(self.default)) + + # OPtional only has one + # check docs ther ewas something about checking in the patch notes maybe? 
+ if ( + hasattr(self.annotation, "__origin__") + and self.annotation.__origin__ is typing.Union + and type(None) in self.annotation.__args__ + ): + annotations = [ + inspect.formatannotation(arg) for arg in self.annotation.__args__ if arg is not type(None) # noqa: E721 + ] + tmp_values["annotation"] = ManualAnnotation(f"Optional[{', '.join(annotations)}]") + elif isinstance(self.annotation, enum.EnumMeta): + tmp_values["annotation"] = ManualAnnotation(self.annotation.__name__) + + with tmp_override(**tmp_values): + return super().__str__() + + +class Signature(inspect.Signature): + def __str__(self): + text = super().__str__() + for separator in [FEATURE_SPECIFIC_PARAM, FEATURE_SPECIFIC_DEFAULT]: + parts = text.split(repr(separator)) + text = f"{separator}, # type: ignore[assignment]\n".join( + [ + parts[0], + *[part.lstrip(",") for part in parts[1:]], + ] + ) + return text + + +class Function: + def __init__(self, *, decorator=None, name, signature, docstring=None, body=("pass",)): + self.decorator = decorator + self.name = name + self.signature = signature + self.docstring = docstring + self.body = body + + def __str__(self): + lines = [] + if self.decorator: + lines.append(f"@{self.decorator}") + lines.append(f"def {self.name}{self.signature}:") + if self.docstring: + lines.append(indent('"""' + self.docstring + '"""')) + lines.extend([indent(line) for line in self.body]) + return "\n".join(lines) + + +class DispatcherFunction(Function): + def __init__(self, *, name, params, input_name="input"): + for param in params: + param._kind = Parameter.KEYWORD_ONLY + signature = Signature( + parameters=[ + Parameter( + name=input_name, + kind=Parameter.POSITIONAL_OR_KEYWORD, + annotation=GENERIC_FEATURE_TYPE, + ), + *params, + ], + return_annotation=GENERIC_FEATURE_TYPE, + ) + super().__init__( + decorator="Dispatcher", + name=name, + signature=signature, + docstring="ADDME", + ) + + +class ImplementerFunction(Function): + def __init__( + self, + *, + dispatcher_name, + feature_type, + params, + pil_kernel, + kernel, + kernel_params, + conversion_map, + kernel_param_name_map, + meta_overwrite, + input_name="input", + ): + feature_type_usage = ManualAnnotation(f"features.{feature_type.__name__}") + + body = [] + + feature_specific_params = [] + for param in params: + if param.default is FEATURE_SPECIFIC_PARAM: + feature_specific_params.append(param.name) + param._default = Parameter.empty + + output_conversions = [] + for idx, (attr, intermediate_value) in enumerate(conversion_map.items()): + + converter = META_CONVERTER_MAP[(feature_type, attr)] + + def make_conversion_call(input, old, new): + return f"F.{converter.__name__}({input}, old_{attr}={old}, new_{attr}={new})" + + input_attr = f"input.{attr}" + intermediate_name = f"intermediate_{attr}" + body.extend( + [ + f"{intermediate_name} = {format_value(intermediate_value)}", + f"converted_input = {make_conversion_call(input_name, input_attr, intermediate_name)}", + "", + ] + ) + if idx == 0: + input_name = "converted_input" + + output_conversions = [f"output = {make_conversion_call('output', intermediate_name, input_attr)}"] + + kernel_call = self._make_kernel_call( + input_name=input_name, + kernel=kernel, + kernel_params=kernel_params, + kernel_param_name_map=kernel_param_name_map, + ) + body.extend( + [ + f"output = {kernel_call}", + *reversed(output_conversions), + "", + ] + ) + + feature_type_wrapper = self._make_feature_type_wrapper( + feature_type_usage=feature_type_usage, + meta_overwrite=meta_overwrite, + ) + body.append(f"return 
{feature_type_wrapper}") + + super().__init__( + decorator=self._make_decorator( + dispatcher_name=dispatcher_name, + feature_type_usage=feature_type_usage, + feature_specific_params=feature_specific_params, + pil_kernel=pil_kernel, + ), + name=f"_{dispatcher_name}_{camel_to_snake_case(feature_type.__name__)}", + signature=Signature( + parameters=[ + Parameter( + name="input", + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=feature_type_usage, + ), + *params, + ], + return_annotation=feature_type_usage, + ), + body=body, + ) + + def _make_decorator(self, *, dispatcher_name, feature_type_usage, feature_specific_params, pil_kernel): + decorator = f"{dispatcher_name}.implements({feature_type_usage}" + if feature_specific_params: + decorator += f", feature_specific_params={tuple(feature_specific_params)}" + if pil_kernel: + decorator += f", pil_kernel=_F.{pil_kernel}" + return f"{decorator})" + + def _make_kernel_call( + self, + *, + kernel, + input_name, + kernel_params, + kernel_param_name_map, + ): + call_args = [input_name] + for param in kernel_params: + dispatcher_param_name = kernel_param_name_map.get(param.name, param.name) + if dispatcher_param_name.startswith("."): + dispatcher_param_name = f"input{dispatcher_param_name}" + call_args.append(f"{param.name}={dispatcher_param_name}") + return f"F.{kernel.__name__}({', '.join(call_args)})" + + def _make_feature_type_wrapper(self, *, feature_type_usage, meta_overwrite): + call_args = ["input", "output"] + call_args.extend( + f"{meta_name}={format_value(dispatcher_param_name)}" + for meta_name, dispatcher_param_name in meta_overwrite.items() + ) + return f"{feature_type_usage}.new_like({', '.join(call_args)})" + + +def ufmt_format(content): + try: + import ufmt + except ModuleNotFoundError: + return content + + HERE = pathlib.Path(__file__).parent + + with open(HERE.parent / ".pre-commit-config.yaml") as file: + repo = next( + repo for repo in yaml.load(file, yaml.Loader)["repos"] for hook in repo["hooks"] if hook["id"] == "ufmt" + ) + + expected_versions = {ufmt: repo["rev"].replace("v", "")} + for dependency in repo["hooks"][0]["additional_dependencies"]: + name, version = [item.strip() for item in dependency.split("==")] + expected_versions[importlib.import_module(name)] = version + + for module, expected_version in expected_versions.items(): + if module.__version__ != expected_version: + warnings.warn("foo") + + from ufmt.core import make_black_config + from usort.config import Config as UsortConfig + + black_config = make_black_config(HERE) + usort_config = UsortConfig.find(HERE) + + return ufmt.ufmt_string(path=HERE, content=content, usort_config=usort_config, black_config=black_config) + + +def maybe_convert_to_enum(value): + if not isinstance(value, str): + return value + + parts = value.split(".") + if len(parts) != 2: + return value + + enum, member = parts + + try: + return ENUMS_MAP[enum][member] + except KeyError: + return value + + +def indent(text, level=1): + return "\n".join(" " * (level * 4) + line for line in text.splitlines()) + + +def camel_to_snake_case(camel_case: str) -> str: + return "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", camel_case)]) + + +def format_value(value): + if isinstance(value, str): + return f'"{value}"' + elif isinstance(value, enum.Enum): + return f"{type(value).__name__}.{value.name}" + else: + return repr(value) + + +def unique(seq): + unique_seq = [] + for item in seq: + if item not in unique_seq: + unique_seq.append(item) + return unique_seq + + +if __name__ == 
"__main__": + try: + with open(pathlib.Path(F.__path__[0]) / "dispatch.yaml") as file: + dispatch_config = yaml.load(file, yaml.Loader) + content = main(dispatch_config) + with open(pathlib.Path(F.__path__[0]) / "_dispatch.py", "w") as file: + file.write(content) + except Exception as error: + msg = str(error) + print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) + sys.exit(1) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 38fff2da04a..bc059bf7142 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -1,7 +1,7 @@ -from typing import Any, cast, Dict, Set, TypeVar +from typing import Any, Callable, cast, Dict, Mapping, Optional, Sequence, Set, Tuple, Type, TypeVar import torch -from torch._C import _TensorBase +from torch._C import _TensorBase, DisableTorchFunction F = TypeVar("F", bound="Feature") @@ -10,6 +10,7 @@ class Feature(torch.Tensor): _META_ATTRS: Set[str] = set() _metadata: Dict[str, Any] + _KERNELS: Dict[Callable, Callable] def __init_subclass__(cls): # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes @@ -37,6 +38,8 @@ def __init_subclass__(cls): for name in meta_attrs: setattr(cls, name, property(lambda self, name=name: self._metadata[name])) + cls._KERNELS = {} + def __new__(cls, data, *, dtype=None, device=None): feature = torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -57,5 +60,46 @@ def new_like(cls, other, data, *, dtype=None, device=None, **metadata): metadata.setdefault(name, getattr(other, name)) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata) + _TORCH_FUNCTION_ALLOW_MAP = { + torch.Tensor.clone: (0,), + torch.stack: (0, 0), + torch.Tensor.to: (0,), + } + + _DTYPE_CONVERTERS = { + torch.Tensor.to, + } + + _DEVICE_CONVERTERS = { + torch.Tensor.to, + } + + @classmethod + def __torch_function__( + cls, + func: Callable[..., torch.Tensor], + types: Tuple[Type[torch.Tensor], ...], + args: Sequence[Any] = (), + kwargs: Optional[Mapping[str, Any]] = None, + ) -> torch.Tensor: + kwargs = kwargs or dict() + if cls is not Feature and func in cls._KERNELS: + return cls._KERNELS[func](*args, **kwargs) + + with DisableTorchFunction(): + output = func(*args, **kwargs) + + if func not in cls._TORCH_FUNCTION_ALLOW_MAP: + return output + + other = args + for item in cls._TORCH_FUNCTION_ALLOW_MAP[func]: + other = other[item] + + dtype = output.dtype if func in cls._DTYPE_CONVERTERS else None + device = output.device if func in cls._DTYPE_CONVERTERS else None + + return cls.new_like(other, output, dtype=dtype, device=device) + def __repr__(self): return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 087f2fb2ac0..37a8096ff7a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,3 +1,5 @@ +from . 
import utils # usort: skip + from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label from ._color import ( adjust_brightness_image, @@ -25,3 +27,5 @@ from ._meta_conversion import convert_color_space, convert_bounding_box_format from ._misc import normalize_image from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot + +from ._dispatch import * # usort: skip diff --git a/torchvision/prototype/transforms/functional/_dispatch.py b/torchvision/prototype/transforms/functional/_dispatch.py new file mode 100644 index 00000000000..7aa5bf5d752 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_dispatch.py @@ -0,0 +1,448 @@ +# THIS FILE IS AUTOGENERATED +# +# FROM torchvision/prototype/transforms/functional/dispatch.yaml +# WITH scripts/regenerate_transforms_dispatch.py +# +# DO NOT CHANGE MANUALLY! + +from typing import Any, TypeVar, List, Optional, Tuple + +import torch +import torchvision.prototype.transforms.functional as F +import torchvision.transforms.functional as _F +from torchvision.prototype import features +from torchvision.prototype.features import BoundingBoxFormat, ColorSpace +from torchvision.transforms import InterpolationMode + +Dispatcher = F.utils.Dispatcher + +# This is just a sentinel to have a default argument for a dispatcher if the feature specific implementations use +# different defaults. The actual value is never used. +FEATURE_SPECIFIC_DEFAULT = object() + +T = TypeVar("T", bound=features.Feature) + + +__all__ = [ + "horizontal_flip", + "resize", + "center_crop", + "normalize", + "resized_crop", + "erase", + "mixup", + "cutmix", + "affine", + "rotate", + "adjust_brightness", + "adjust_saturation", + "adjust_contrast", + "adjust_sharpness", + "posterize", + "solarize", + "autocontrast", + "equalize", + "invert", +] + + +@Dispatcher +def horizontal_flip(input: T) -> T: + """ADDME""" + pass + + +@horizontal_flip.implements(features.Image) +def _horizontal_flip_image(input: features.Image) -> features.Image: + output = F.horizontal_flip_image(input) + + return features.Image.new_like(input, output) + + +@horizontal_flip.implements(features.BoundingBox) +def _horizontal_flip_bounding_box(input: features.BoundingBox) -> features.BoundingBox: + intermediate_format = BoundingBoxFormat.XYXY + converted_input = F.convert_bounding_box_format(input, old_format=input.format, new_format=intermediate_format) + + output = F.horizontal_flip_bounding_box(converted_input, image_size=input.image_size) + output = F.convert_bounding_box_format(output, old_format=intermediate_format, new_format=input.format) + + return features.BoundingBox.new_like(input, output) + + +@Dispatcher +def resize( + input: T, + *, + size: List[int], + interpolation: InterpolationMode = FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> T: + """ADDME""" + pass + + +@resize.implements(features.Image, pil_kernel=_F.resize) +def _resize_image( + input: features.Image, + *, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> features.Image: + output = F.resize_image(input, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) + + return features.Image.new_like(input, output) + + +@resize.implements(features.BoundingBox) +def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> 
features.BoundingBox: + intermediate_format = BoundingBoxFormat.XYXY + converted_input = F.convert_bounding_box_format(input, old_format=input.format, new_format=intermediate_format) + + output = F.resize_bounding_box(converted_input, old_image_size=input.image_size, new_image_size=size) + output = F.convert_bounding_box_format(output, old_format=intermediate_format, new_format=input.format) + + return features.BoundingBox.new_like(input, output, image_size="size") + + +@resize.implements(features.SegmentationMask) +def _resize_segmentation_mask( + input: features.SegmentationMask, + *, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> features.SegmentationMask: + output = F.resize_segmentation_mask( + input, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + + return features.SegmentationMask.new_like(input, output) + + +@Dispatcher +def center_crop(input: T, *, output_size: List[int]) -> T: + """ADDME""" + pass + + +@center_crop.implements(features.Image, pil_kernel=_F.center_crop) +def _center_crop_image(input: features.Image, *, output_size: List[int]) -> features.Image: + output = F.center_crop(input, output_size=output_size) + + return features.Image.new_like(input, output) + + +@Dispatcher +def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: + """ADDME""" + pass + + +@normalize.implements(features.Image) +def _normalize_image( + input: features.Image, *, mean: List[float], std: List[float], inplace: bool = False +) -> features.Image: + output = F.normalize(input, mean=mean, std=std, inplace=inplace) + + return features.Image.new_like(input, output, color_space=ColorSpace.OTHER) + + +@Dispatcher +def resized_crop( + input: T, + *, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, +) -> T: + """ADDME""" + pass + + +@resized_crop.implements(features.Image, pil_kernel=_F.resized_crop) +def _resized_crop_image( + input: features.Image, + *, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, +) -> features.Image: + output = F.resized_crop( + input, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation + ) + + return features.Image.new_like(input, output) + + +@Dispatcher +def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: + """ADDME""" + pass + + +@erase.implements(features.Image, pil_kernel=_F.erase) +def _erase_image( + input: features.Image, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> features.Image: + output = F.erase(input, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + + return features.Image.new_like(input, output) + + +@Dispatcher +def mixup(input: T, *, lam: float, inplace: bool = False) -> T: + """ADDME""" + pass + + +@mixup.implements(features.Image) +def _mixup_image(input: features.Image, *, lam: float, inplace: bool = False) -> features.Image: + output = F.mixup_image(input, lam=lam, inplace=inplace) + + return features.Image.new_like(input, output) + + +@mixup.implements(features.OneHotLabel) +def _mixup_one_hot_label(input: features.OneHotLabel, *, lam: float, inplace: bool = False) -> features.OneHotLabel: + output = F.mixup_one_hot_label(input, lam=lam, inplace=inplace) + + return 
features.OneHotLabel.new_like(input, output) + + +@Dispatcher +def cutmix( + input: T, + *, + box: Tuple[int, int, int, int] = Dispatcher.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + lam_adjusted: float = Dispatcher.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + inplace: bool = False, +) -> T: + """ADDME""" + pass + + +@cutmix.implements(features.Image, feature_specific_params=("box",)) +def _cutmix_image( + input: features.Image, *, box: Tuple[int, int, int, int], inplace: bool = False, **_: Any +) -> features.Image: + output = F.cutmix_image(input, box=box, inplace=inplace) + + return features.Image.new_like(input, output) + + +@cutmix.implements(features.OneHotLabel, feature_specific_params=("lam_adjusted",)) +def _cutmix_one_hot_label( + input: features.OneHotLabel, *, lam_adjusted: float, inplace: bool = False, **_: Any +) -> features.OneHotLabel: + output = F.cutmix_one_hot_label(input, lam_adjusted=lam_adjusted, inplace=inplace) + + return features.OneHotLabel.new_like(input, output) + + +@Dispatcher +def affine( + input: T, + *, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, + center: Optional[List[int]] = None, +) -> T: + """ADDME""" + pass + + +@affine.implements(features.Image, pil_kernel=_F.affine) +def _affine_image( + input: features.Image, + *, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, + center: Optional[List[int]] = None, +) -> features.Image: + output = F.affine( + input, + angle=angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + resample=resample, + fillcolor=fillcolor, + center=center, + ) + + return features.Image.new_like(input, output) + + +@Dispatcher +def rotate( + input: T, + *, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, +) -> T: + """ADDME""" + pass + + +@rotate.implements(features.Image, pil_kernel=_F.rotate) +def _rotate_image( + input: features.Image, + *, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, +) -> features.Image: + output = F.rotate( + input, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill, resample=resample + ) + + return features.Image.new_like(input, output) + + +@Dispatcher +def adjust_brightness(input: T, *, brightness_factor: float) -> T: + """ADDME""" + pass + + +@adjust_brightness.implements(features.Image, pil_kernel=_F.adjust_brightness) +def _adjust_brightness_image(input: features.Image, *, brightness_factor: float) -> features.Image: + output = F.adjust_brightness(input, brightness_factor=brightness_factor) + + return features.Image.new_like(input, output) + + +@Dispatcher +def adjust_saturation(input: T, *, saturation_factor: float) -> T: + """ADDME""" + pass + + +@adjust_saturation.implements(features.Image, pil_kernel=_F.adjust_saturation) +def _adjust_saturation_image(input: 
features.Image, *, saturation_factor: float) -> features.Image: + output = F.adjust_saturation(input, saturation_factor=saturation_factor) + + return features.Image.new_like(input, output) + + +@Dispatcher +def adjust_contrast(input: T, *, contrast_factor: float) -> T: + """ADDME""" + pass + + +@adjust_contrast.implements(features.Image, pil_kernel=_F.adjust_contrast) +def _adjust_contrast_image(input: features.Image, *, contrast_factor: float) -> features.Image: + output = F.adjust_contrast(input, contrast_factor=contrast_factor) + + return features.Image.new_like(input, output) + + +@Dispatcher +def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: + """ADDME""" + pass + + +@adjust_sharpness.implements(features.Image, pil_kernel=_F.adjust_sharpness) +def _adjust_sharpness_image(input: features.Image, *, sharpness_factor: float) -> features.Image: + output = F.adjust_sharpness(input, sharpness_factor=sharpness_factor) + + return features.Image.new_like(input, output) + + +@Dispatcher +def posterize(input: T, *, bits: int) -> T: + """ADDME""" + pass + + +@posterize.implements(features.Image, pil_kernel=_F.posterize) +def _posterize_image(input: features.Image, *, bits: int) -> features.Image: + output = F.posterize(input, bits=bits) + + return features.Image.new_like(input, output) + + +@Dispatcher +def solarize(input: T, *, threshold: float) -> T: + """ADDME""" + pass + + +@solarize.implements(features.Image, pil_kernel=_F.solarize) +def _solarize_image(input: features.Image, *, threshold: float) -> features.Image: + output = F.solarize(input, threshold=threshold) + + return features.Image.new_like(input, output) + + +@Dispatcher +def autocontrast(input: T) -> T: + """ADDME""" + pass + + +@autocontrast.implements(features.Image, pil_kernel=_F.autocontrast) +def _autocontrast_image(input: features.Image) -> features.Image: + output = F.autocontrast(input) + + return features.Image.new_like(input, output) + + +@Dispatcher +def equalize(input: T) -> T: + """ADDME""" + pass + + +@equalize.implements(features.Image, pil_kernel=_F.equalize) +def _equalize_image(input: features.Image) -> features.Image: + output = F.equalize(input) + + return features.Image.new_like(input, output) + + +@Dispatcher +def invert(input: T) -> T: + """ADDME""" + pass + + +@invert.implements(features.Image, pil_kernel=_F.invert) +def _invert_image(input: features.Image) -> features.Image: + output = F.invert(input) + + return features.Image.new_like(input, output) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c8142742fa8..e656db6058a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -28,9 +28,6 @@ def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tupl ) -_resize_image = _F.resize - - def resize_image( image: torch.Tensor, size: List[int], @@ -41,7 +38,7 @@ def resize_image( new_height, new_width = size num_channels, old_height, old_width = image.shape[-3:] batch_shape = image.shape[:-3] - return _resize_image( + return _F.resize( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation, diff --git a/torchvision/prototype/transforms/functional/dispatch.yaml b/torchvision/prototype/transforms/functional/dispatch.yaml new file mode 100644 index 00000000000..2d5d41e685e --- /dev/null +++ b/torchvision/prototype/transforms/functional/dispatch.yaml @@ -0,0 +1,164 @@ +# This is 
the configuration file to auto-generate the dispatch of feature type specific kernels of the same transform. +# The auto-generation is general enough to apply to all dispatchers that +# 1. take a feature as first input that is solely used to determine the required kernel and +# 2. are feature type preserving. +# +# After you have implemented new kernels, follow the schema explanation below to configure the auto-generation. +# Finally, run `$ python scripts/regenerate_transform_dispatch.py` to regenerate the dispatch. +# +# This configuration uses the following schema: +# +# $DISPATCHER_NAME: +# $FEATURE_NAME: +# kernel: $KERNEL_NAME +# meta_conversion: $META_CONVERSION +# kwargs_overwrite: $KWARGS_OVERWRITE +# meta_overwrite: $META_OVERWRITE +# ... +# +# KERNEL_NAME: Name of the kernel that can be accessed from `torchvision.prototype.transforms.fuctional`. If the +# canonical naming scheme is followed, this should be $DISPATCHER_NAME followed by $FEATURE_NAME in snake case. +# +# META_CONVERSION: Optional mapping of meta attributes to convert before the kernel call. This is needed if the kernel +# makes assumptions on meta attributes of the input tensor. For example, the `resize_bounding_box` requires the input +# to be in XYXY format. This can be achieved with: +# +# resize: +# BoundingBox: +# kernel: resize_bounding_box +# meta_conversion: +# format: BoundingBoxFormat.XYXY +# +# KWARGS_OVERWRITE: Optional mapping of keyword arguments to apply to the keyword arguments of the dispatcher before +# they get passed to the kernel. If a value is prefixed with a dot, this meta attribute of the input feature is added +# under the given key. This is useful in two scenarios: +# +# 1. The kernel requires meta information that is available as meta attribute on the input feature or +# 2. the kernel parameters names differ from the ones of the dispatcher. +# +# For example, the `resize_bounding_box` kernel takes the `old_image_size` and `new_image_size` parameters, whereas +# the `resize` dispatcher should only take `size`. Thus, we need to take `image_size` stored on the `BoundingBox` +# feature and map it to `old_image_size` as well as mapping the `size` input to `new_image_size`. +# +# resize: +# BoundingBox: +# kernel: resize_bounding_box +# kwargs_overwrite: +# old_image_size: .image_size +# new_image_size: size +# +# META_OVERWRITE: Optional mapping of keyword arguments to overwrite meta attributes when creating the new feature from +# the kernel output. This is needed if the kernel changes a meta attribute. For example, the `resize_bounding_box` +# kernel adapts the values, but the new image size cannot be inferred from them. Thus, we set the `image_size` +# metadata of the new bounding box to be the `size` the dispatcher was called with: +# +# resize: +# BoundingBox: +# kernel: resize_bounding_box +# meta_overwrite: +# image_size: size +# +# For $FEATURE_NAME == Image, there is also the `pil_kernel` key available. This is the name of the legacy `PIL` kernel +# that can be accessed from `torchvision.transforms.fuctional` (note the missing `.prototype` compared to +# $KERNEL_NAME). It will be called with same arguments as the dispatcher. 
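+#
+# For example, the `resize` entry below uses `pil_kernel` to register the legacy `resize` kernel for PIL inputs:
+#
+# resize:
+#   Image:
+#     kernel: resize_image
+#     pil_kernel: resize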
+# +# If no optional configuration is required, the shortcut +# +# $DISPATCHER_NAME: +# $FEATURE_NAME: $KERNEL_NAME +# +# can be used, which is equivalent to +# +# $DISPATCHER_NAME: +# $FEATURE_NAME: +# kernel: $KERNEL_NAME + +horizontal_flip: + Image: horizontal_flip_image + BoundingBox: + kernel: horizontal_flip_bounding_box + meta_conversion: + format: BoundingBoxFormat.XYXY + kwargs_overwrite: + image_size: .image_size +resize: + Image: + kernel: resize_image + pil_kernel: resize + BoundingBox: + kernel: resize_bounding_box + meta_conversion: + format: BoundingBoxFormat.XYXY + kwargs_overwrite: + old_image_size: .image_size + new_image_size: size + meta_overwrite: + image_size: size + SegmentationMask: resize_segmentation_mask +center_crop: + Image: + kernel: center_crop_image + pil_kernel: center_crop +normalize: + Image: + kernel: normalize_image + meta_overwrite: + color_space: ColorSpace.OTHER +resized_crop: + Image: + kernel: resized_crop_image + pil_kernel: resized_crop +erase: + Image: + kernel: erase_image + pil_kernel: erase +mixup: + Image: mixup_image + OneHotLabel: mixup_one_hot_label +cutmix: + Image: cutmix_image + OneHotLabel: cutmix_one_hot_label +affine: + Image: + kernel: affine_image + pil_kernel: affine +rotate: + Image: + kernel: rotate_image + pil_kernel: rotate +adjust_brightness: + Image: + kernel: adjust_brightness_image + pil_kernel: adjust_brightness +adjust_saturation: + Image: + kernel: adjust_saturation_image + pil_kernel: adjust_saturation +adjust_contrast: + Image: + kernel: adjust_contrast_image + pil_kernel: adjust_contrast +adjust_sharpness: + Image: + kernel: adjust_sharpness_image + pil_kernel: adjust_sharpness +posterize: + Image: + kernel: posterize_image + pil_kernel: posterize +solarize: + Image: + kernel: solarize_image + pil_kernel: solarize +autocontrast: + Image: + kernel: autocontrast_image + pil_kernel: autocontrast +equalize: + Image: + kernel: equalize_image + pil_kernel: equalize +invert: + Image: + kernel: invert_image + pil_kernel: invert diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py new file mode 100644 index 00000000000..22bdeefdffd --- /dev/null +++ b/torchvision/prototype/transforms/functional/utils.py @@ -0,0 +1,72 @@ +import functools +from typing import Any, Type, Optional, Callable + +import PIL.Image +import torch +import torch.overrides +from torchvision.prototype import features +from torchvision.prototype.utils._internal import sequence_to_str + + +def is_supported(obj: Any, *types: Type) -> bool: + return (obj if isinstance(obj, type) else type(obj)) in types + + +class Dispatcher: + FEATURE_SPECIFIC_PARAM = object() + + def __init__(self, dispatch_fn): + self._dispatch_fn = dispatch_fn + self._support = set() + self._pil_kernel: Optional[Callable] = None + + def supports(self, obj: Any) -> bool: + return is_supported(obj, *self._support) + + def implements(self, feature_type, *, feature_specific_params=(), pil_kernel=None): + if pil_kernel is not None: + if not issubclass(feature_type, features.Image): + raise TypeError("PIL kernel can only be registered for images") + + self._pil_kernel = pil_kernel + + def outer_wrapper(implement_fn): + feature_type._KERNELS[self._dispatch_fn] = implement_fn + self._support.add(feature_type) + + @functools.wraps(implement_fn) + def inner_wrapper(*args, **kwargs) -> Any: + missing = [ + param + for param in feature_specific_params + if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM 
+ ] + if missing: + raise TypeError( + f"{implement_fn.__name__}() missing {len(missing)} required keyword-only arguments: " + f"{sequence_to_str(missing, separate_last='and ')}" + ) + + return implement_fn(*args, **kwargs) + + return inner_wrapper + + return outer_wrapper + + def __call__(self, input, *args, **kwargs): + feature_type = type(input) + if issubclass(feature_type, PIL.Image.Image): + if self._pil_kernel is None: + raise TypeError("No PIL kernel") + + return self._pil_kernel(input, *args, **kwargs) + elif not isinstance(input, torch.Tensor): + raise TypeError("No tensor") + + if not (issubclass(type(input), features.Feature)): + input = features.Image(input) + + if not self.supports(input): + raise ValueError(f"No support for {type(input).__name__}") + + return torch.overrides.handle_torch_function(self._dispatch_fn, (input,), input, *args, **kwargs) From 587687e7c7f2c8d6e91fa216723d0e2d460b6e27 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 31 Jan 2022 18:39:16 +0100 Subject: [PATCH 02/32] fix missing arguments error message --- torchvision/prototype/transforms/functional/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 22bdeefdffd..85bd8320ce7 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -31,9 +31,6 @@ def implements(self, feature_type, *, feature_specific_params=(), pil_kernel=Non self._pil_kernel = pil_kernel def outer_wrapper(implement_fn): - feature_type._KERNELS[self._dispatch_fn] = implement_fn - self._support.add(feature_type) - @functools.wraps(implement_fn) def inner_wrapper(*args, **kwargs) -> Any: missing = [ @@ -43,12 +40,15 @@ def inner_wrapper(*args, **kwargs) -> Any: ] if missing: raise TypeError( - f"{implement_fn.__name__}() missing {len(missing)} required keyword-only arguments: " - f"{sequence_to_str(missing, separate_last='and ')}" + f"{self._dispatch_fn.__name__}() missing {len(missing)} required keyword-only arguments " + f"for feature type {feature_type.__name__}: {sequence_to_str(missing, separate_last='and ')}" ) return implement_fn(*args, **kwargs) + feature_type._KERNELS[self._dispatch_fn] = inner_wrapper + self._support.add(feature_type) + return inner_wrapper return outer_wrapper From 3a4e53dc66f1dfac0febf2db5b981cce2c0cab13 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Feb 2022 14:01:03 +0100 Subject: [PATCH 03/32] remove pil kernel for erase --- torchvision/prototype/transforms/functional/_dispatch.py | 2 +- torchvision/prototype/transforms/functional/dispatch.yaml | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_dispatch.py b/torchvision/prototype/transforms/functional/_dispatch.py index 7aa5bf5d752..d62196d99bb 100644 --- a/torchvision/prototype/transforms/functional/_dispatch.py +++ b/torchvision/prototype/transforms/functional/_dispatch.py @@ -191,7 +191,7 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: pass -@erase.implements(features.Image, pil_kernel=_F.erase) +@erase.implements(features.Image) def _erase_image( input: features.Image, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False ) -> features.Image: diff --git a/torchvision/prototype/transforms/functional/dispatch.yaml b/torchvision/prototype/transforms/functional/dispatch.yaml index 2d5d41e685e..4df3221e3b7 
100644 --- a/torchvision/prototype/transforms/functional/dispatch.yaml +++ b/torchvision/prototype/transforms/functional/dispatch.yaml @@ -109,9 +109,7 @@ resized_crop: kernel: resized_crop_image pil_kernel: resized_crop erase: - Image: - kernel: erase_image - pil_kernel: erase + Image: erase_image mixup: Image: mixup_image OneHotLabel: mixup_one_hot_label From 35845b5e8a9db047ca6cd311cb53aea6536cf4a9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Feb 2022 14:17:51 +0100 Subject: [PATCH 04/32] automate feature specific parameter detection --- scripts/regenerate_transform_dispatch.py | 12 +----------- .../prototype/transforms/functional/_dispatch.py | 4 ++-- torchvision/prototype/transforms/functional/utils.py | 12 +++++++++++- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/scripts/regenerate_transform_dispatch.py b/scripts/regenerate_transform_dispatch.py index edc3bcf01b0..22c2184a940 100644 --- a/scripts/regenerate_transform_dispatch.py +++ b/scripts/regenerate_transform_dispatch.py @@ -181,7 +181,6 @@ def make_dispatcher_params(implementer_params): if len(implementer_params) == 1: dispatcher_params.append(copy(param)) else: - param._default = FEATURE_SPECIFIC_PARAM dispatcher_params.append( Parameter( name=name, @@ -390,12 +389,6 @@ def __init__( body = [] - feature_specific_params = [] - for param in params: - if param.default is FEATURE_SPECIFIC_PARAM: - feature_specific_params.append(param.name) - param._default = Parameter.empty - output_conversions = [] for idx, (attr, intermediate_value) in enumerate(conversion_map.items()): @@ -442,7 +435,6 @@ def make_conversion_call(input, old, new): decorator=self._make_decorator( dispatcher_name=dispatcher_name, feature_type_usage=feature_type_usage, - feature_specific_params=feature_specific_params, pil_kernel=pil_kernel, ), name=f"_{dispatcher_name}_{camel_to_snake_case(feature_type.__name__)}", @@ -460,10 +452,8 @@ def make_conversion_call(input, old, new): body=body, ) - def _make_decorator(self, *, dispatcher_name, feature_type_usage, feature_specific_params, pil_kernel): + def _make_decorator(self, *, dispatcher_name, feature_type_usage, pil_kernel): decorator = f"{dispatcher_name}.implements({feature_type_usage}" - if feature_specific_params: - decorator += f", feature_specific_params={tuple(feature_specific_params)}" if pil_kernel: decorator += f", pil_kernel=_F.{pil_kernel}" return f"{decorator})" diff --git a/torchvision/prototype/transforms/functional/_dispatch.py b/torchvision/prototype/transforms/functional/_dispatch.py index d62196d99bb..2febe232091 100644 --- a/torchvision/prototype/transforms/functional/_dispatch.py +++ b/torchvision/prototype/transforms/functional/_dispatch.py @@ -232,7 +232,7 @@ def cutmix( pass -@cutmix.implements(features.Image, feature_specific_params=("box",)) +@cutmix.implements(features.Image) def _cutmix_image( input: features.Image, *, box: Tuple[int, int, int, int], inplace: bool = False, **_: Any ) -> features.Image: @@ -241,7 +241,7 @@ def _cutmix_image( return features.Image.new_like(input, output) -@cutmix.implements(features.OneHotLabel, feature_specific_params=("lam_adjusted",)) +@cutmix.implements(features.OneHotLabel) def _cutmix_one_hot_label( input: features.OneHotLabel, *, lam_adjusted: float, inplace: bool = False, **_: Any ) -> features.OneHotLabel: diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 85bd8320ce7..c47d4cb42d1 100644 --- 
a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -1,4 +1,5 @@ import functools +import inspect from typing import Any, Type, Optional, Callable import PIL.Image @@ -23,7 +24,7 @@ def __init__(self, dispatch_fn): def supports(self, obj: Any) -> bool: return is_supported(obj, *self._support) - def implements(self, feature_type, *, feature_specific_params=(), pil_kernel=None): + def implements(self, feature_type, *, pil_kernel=None): if pil_kernel is not None: if not issubclass(feature_type, features.Image): raise TypeError("PIL kernel can only be registered for images") @@ -31,6 +32,15 @@ def implements(self, feature_type, *, feature_specific_params=(), pil_kernel=Non self._pil_kernel = pil_kernel def outer_wrapper(implement_fn): + implement_params = inspect.signature(implement_fn).parameters + feature_specific_params = [ + name + for name, param in inspect.signature(self._dispatch_fn).parameters.items() + if param.default is self.FEATURE_SPECIFIC_PARAM + and name in implement_params + and implement_params[name] is inspect.Parameter.empty + ] + @functools.wraps(implement_fn) def inner_wrapper(*args, **kwargs) -> Any: missing = [ From 7778782bf97b29d35db087563f81d53734b13876 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 1 Feb 2022 14:51:45 +0100 Subject: [PATCH 05/32] fix typos --- references/detection/coco_utils.py | 9 ++++++--- references/detection/train.py | 11 ++++++++++- torchvision/models/detection/keypoint_rcnn.py | 3 +++ .../prototype/transforms/functional/dispatch.yaml | 10 +++++----- 4 files changed, 24 insertions(+), 9 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index b0f193135ee..b9143f47034 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -53,12 +53,13 @@ def __call__(self, image, target): anno = target["annotations"] + # drop all crowd annotations anno = [obj for obj in anno if obj["iscrowd"] == 0] boxes = [obj["bbox"] for obj in anno] # guard against no boxes via resizing boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) - boxes[:, 2:] += boxes[:, :2] + boxes[:, 2:] += boxes[:, :2] # xywh -> xyxy boxes[:, 0::2].clamp_(min=0, max=w) boxes[:, 1::2].clamp_(min=0, max=h) @@ -76,7 +77,7 @@ def __call__(self, image, target): if num_keypoints: keypoints = keypoints.view(num_keypoints, -1, 3) - keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) # valid boxes boxes = boxes[keep] classes = classes[keep] masks = masks[keep] @@ -93,7 +94,9 @@ def __call__(self, image, target): # for conversion to coco api area = torch.tensor([obj["area"] for obj in anno]) - iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + iscrowd = torch.tensor( + [obj["iscrowd"] for obj in anno] + ) # this makes little sense, since we already exlcuded them at the top target["area"] = area target["iscrowd"] = iscrowd diff --git a/references/detection/train.py b/references/detection/train.py index 765f8144364..6a81d7fa332 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -177,6 +177,8 @@ def main(args): dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) + dataset[0] + print("Creating data loaders") if args.distributed: train_sampler = 
torch.utils.data.distributed.DistributedSampler(dataset) @@ -276,5 +278,12 @@ def main(args): if __name__ == "__main__": - args = get_args_parser().parse_args() + args = get_args_parser().parse_args( + [ + "--data-path=/home/philip/datasets/coco", + "--device=cpu", + "--dataset=coco_kp", + "--model=keypointrcnn_resnet50_fpn", + ] + ) main(args) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 93e966bae4b..c58b72217c9 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -253,6 +253,9 @@ def __init__( self.roi_heads.keypoint_head = keypoint_head self.roi_heads.keypoint_predictor = keypoint_predictor + def forward(self, *args, **kwargs): + return super().forward(*args, **kwargs) + class KeypointRCNNHeads(nn.Sequential): def __init__(self, in_channels, layers): diff --git a/torchvision/prototype/transforms/functional/dispatch.yaml b/torchvision/prototype/transforms/functional/dispatch.yaml index 4df3221e3b7..6c963ad840b 100644 --- a/torchvision/prototype/transforms/functional/dispatch.yaml +++ b/torchvision/prototype/transforms/functional/dispatch.yaml @@ -16,12 +16,12 @@ # meta_overwrite: $META_OVERWRITE # ... # -# KERNEL_NAME: Name of the kernel that can be accessed from `torchvision.prototype.transforms.fuctional`. If the +# KERNEL_NAME: Name of the kernel that can be accessed from `torchvision.prototype.transforms.functional`. If the # canonical naming scheme is followed, this should be $DISPATCHER_NAME followed by $FEATURE_NAME in snake case. # # META_CONVERSION: Optional mapping of meta attributes to convert before the kernel call. This is needed if the kernel -# makes assumptions on meta attributes of the input tensor. For example, the `resize_bounding_box` requires the input -# to be in XYXY format. This can be achieved with: +# makes assumptions on meta attributes of the input tensor. For example, the `resize_bounding_box` kernel requires +# the input to be in XYXY format. This can be achieved with: # # resize: # BoundingBox: @@ -50,7 +50,7 @@ # META_OVERWRITE: Optional mapping of keyword arguments to overwrite meta attributes when creating the new feature from # the kernel output. This is needed if the kernel changes a meta attribute. For example, the `resize_bounding_box` # kernel adapts the values, but the new image size cannot be inferred from them. Thus, we set the `image_size` -# metadata of the new bounding box to be the `size` the dispatcher was called with: +# metadata of the new bounding box to be the `size` the dispatcher was called with: # # resize: # BoundingBox: @@ -59,7 +59,7 @@ # image_size: size # # For $FEATURE_NAME == Image, there is also the `pil_kernel` key available. This is the name of the legacy `PIL` kernel -# that can be accessed from `torchvision.transforms.fuctional` (note the missing `.prototype` compared to +# that can be accessed from `torchvision.transforms.functional` (note the missing `.prototype` compared to # $KERNEL_NAME). It will be called with same arguments as the dispatcher. 
# # If no optional configuration is required, the shortcut From 019a0b6e80b5d5a1e2cdc14cc420bb441c40e99f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 3 Feb 2022 15:01:00 +0100 Subject: [PATCH 06/32] cleanup dispatcher call --- torchvision/prototype/transforms/functional/utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index c47d4cb42d1..7466aa9defe 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -64,8 +64,7 @@ def inner_wrapper(*args, **kwargs) -> Any: return outer_wrapper def __call__(self, input, *args, **kwargs): - feature_type = type(input) - if issubclass(feature_type, PIL.Image.Image): + if isinstance(input, PIL.Image.Image): if self._pil_kernel is None: raise TypeError("No PIL kernel") @@ -73,7 +72,7 @@ def __call__(self, input, *args, **kwargs): elif not isinstance(input, torch.Tensor): raise TypeError("No tensor") - if not (issubclass(type(input), features.Feature)): + if not isinstance(input, features.Feature): input = features.Image(input) if not self.supports(input): From 4cb2350193cadbea12275975c089d462b31a860f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 4 Feb 2022 12:13:37 +0100 Subject: [PATCH 07/32] remove __torch_function__ from transform dispatch --- torchvision/prototype/features/_feature.py | 21 +++++-------------- .../prototype/transforms/functional/utils.py | 21 ++++++++++--------- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index bc059bf7142..d4948048cfe 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -10,7 +10,6 @@ class Feature(torch.Tensor): _META_ATTRS: Set[str] = set() _metadata: Dict[str, Any] - _KERNELS: Dict[Callable, Callable] def __init_subclass__(cls): # In order to help static type checkers, we require subclasses of `Feature` to add the metadata attributes @@ -38,8 +37,6 @@ def __init_subclass__(cls): for name in meta_attrs: setattr(cls, name, property(lambda self, name=name: self._metadata[name])) - cls._KERNELS = {} - def __new__(cls, data, *, dtype=None, device=None): feature = torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -83,23 +80,15 @@ def __torch_function__( kwargs: Optional[Mapping[str, Any]] = None, ) -> torch.Tensor: kwargs = kwargs or dict() - if cls is not Feature and func in cls._KERNELS: - return cls._KERNELS[func](*args, **kwargs) - with DisableTorchFunction(): output = func(*args, **kwargs) - if func not in cls._TORCH_FUNCTION_ALLOW_MAP: + if func is torch.Tensor.clone: + return cls.new_like(args[0], output) + elif func is torch.Tensor.to: + return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) + else: return output - other = args - for item in cls._TORCH_FUNCTION_ALLOW_MAP[func]: - other = other[item] - - dtype = output.dtype if func in cls._DTYPE_CONVERTERS else None - device = output.device if func in cls._DTYPE_CONVERTERS else None - - return cls.new_like(other, output, dtype=dtype, device=device) - def __repr__(self): return torch.Tensor.__repr__(self).replace("tensor", type(self).__name__) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 7466aa9defe..7d0620a4a36 100644 --- 
a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -18,11 +18,11 @@ class Dispatcher: def __init__(self, dispatch_fn): self._dispatch_fn = dispatch_fn - self._support = set() + self._kernels = {} self._pil_kernel: Optional[Callable] = None def supports(self, obj: Any) -> bool: - return is_supported(obj, *self._support) + return is_supported(obj, *self._kernels.keys()) def implements(self, feature_type, *, pil_kernel=None): if pil_kernel is not None: @@ -56,26 +56,27 @@ def inner_wrapper(*args, **kwargs) -> Any: return implement_fn(*args, **kwargs) - feature_type._KERNELS[self._dispatch_fn] = inner_wrapper - self._support.add(feature_type) + self._kernels[feature_type] = inner_wrapper return inner_wrapper return outer_wrapper def __call__(self, input, *args, **kwargs): - if isinstance(input, PIL.Image.Image): + feature_type = type(input) + + if issubclass(feature_type, PIL.Image.Image): if self._pil_kernel is None: raise TypeError("No PIL kernel") return self._pil_kernel(input, *args, **kwargs) - elif not isinstance(input, torch.Tensor): + elif not issubclass(feature_type, torch.Tensor): raise TypeError("No tensor") - if not isinstance(input, features.Feature): + if not issubclass(feature_type, features.Feature): input = features.Image(input) - if not self.supports(input): - raise ValueError(f"No support for {type(input).__name__}") + if not self.supports(feature_type): + raise ValueError(f"No support for {feature_type.__name__}") - return torch.overrides.handle_torch_function(self._dispatch_fn, (input,), input, *args, **kwargs) + return self._kernels[feature_type](input, *args, **kwargs) From 158a2167f62ad60eede46c8d1885f965eb7ec522 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Feb 2022 15:19:43 +0100 Subject: [PATCH 08/32] remove auto-generation --- scripts/regenerate_transform_dispatch.py | 568 ------------------ .../transforms/functional/__init__.py | 33 +- .../transforms/functional/_augment.py | 43 ++ .../prototype/transforms/functional/_color.py | 79 +++ .../transforms/functional/_dispatch.py | 448 -------------- .../transforms/functional/_geometry.py | 127 +++- .../transforms/functional/_meta_conversion.py | 7 +- .../prototype/transforms/functional/_misc.py | 15 + .../transforms/functional/dispatch.yaml | 162 ----- .../prototype/transforms/functional/utils.py | 79 +-- 10 files changed, 330 insertions(+), 1231 deletions(-) delete mode 100644 scripts/regenerate_transform_dispatch.py delete mode 100644 torchvision/prototype/transforms/functional/_dispatch.py delete mode 100644 torchvision/prototype/transforms/functional/dispatch.yaml diff --git a/scripts/regenerate_transform_dispatch.py b/scripts/regenerate_transform_dispatch.py deleted file mode 100644 index 22c2184a940..00000000000 --- a/scripts/regenerate_transform_dispatch.py +++ /dev/null @@ -1,568 +0,0 @@ -import contextlib -import enum -import importlib -import inspect -import pathlib -import re -import sys -import typing -import warnings -from copy import copy -from typing import Any - -import torchvision.prototype.transforms.functional as F -from torchvision import transforms -from torchvision.prototype import features - -try: - import yaml -except ModuleNotFoundError: - raise ModuleNotFoundError() - - -ENUMS = [ - (features, ["BoundingBoxFormat", "ColorSpace"]), - (transforms, ["InterpolationMode"]), -] - -ENUMS_MAP = {name: getattr(module, name) for module, names in ENUMS for name in names} - -META_CONVERTER_MAP = { - (features.Image, "color_space"): 
F.convert_color_space, - (features.BoundingBox, "format"): F.convert_bounding_box_format, -} - - -class ManualAnnotation: - def __init__(self, repr): - self.repr = repr - - def __repr__(self): - return self.repr - - def __eq__(self, other): - if not isinstance(other, ManualAnnotation): - return NotImplemented - - return self.repr == other.repr - - -# TODO: typing module -FEATURE_SPECIFIC_PARAM = ManualAnnotation("Dispatcher.FEATURE_SPECIFIC_PARAM") -FEATURE_SPECIFIC_DEFAULT = ManualAnnotation("FEATURE_SPECIFIC_DEFAULT") -GENERIC_FEATURE_TYPE = ManualAnnotation("T") - - -def main(dispatch_config): - functions = [] - for dispatcher_name, feature_type_configs in dispatch_config.items(): - try: - feature_type_configs = validate_feature_type_configs(feature_type_configs) - kernel_params, implementer_params = make_kernel_and_implementer_params(feature_type_configs) - dispatcher_params = make_dispatcher_params(implementer_params) - except Exception as error: - raise RuntimeError( - f"while working on dispatcher '{dispatcher_name}' the following error was raised:\n\n" - f"{type(error).__name__}: {error}" - ) from None - - functions.append(DispatcherFunction(name=dispatcher_name, params=dispatcher_params)) - functions.extend( - [ - ImplementerFunction( - dispatcher_name=dispatcher_name, - feature_type=feature_type, - params=implementer_params[feature_type], - pil_kernel=config.get("pil_kernel"), - kernel=config["kernel"], - kernel_params=kernel_params[feature_type], - conversion_map=config["meta_conversion"], - kernel_param_name_map=config["kwargs_overwrite"], - meta_overwrite=config["meta_overwrite"], - ) - for feature_type, config in feature_type_configs.items() - ] - ) - - return ufmt_format(make_file_content(functions)) - - -def validate_feature_type_configs(feature_type_configs): - try: - feature_type_configs = { - getattr(features, feature_type_name): config for feature_type_name, config in feature_type_configs.items() - } - except AttributeError: - # unknown feature type - raise TypeError() from None - - for feature_type, config in tuple(feature_type_configs.items()): - if not isinstance(config, dict): - feature_type_configs[feature_type] = config = dict(kernel=config) - - unknown_keys = config.keys() - { - "kernel", - "pil_kernel", - "meta_conversion", - "kwargs_overwrite", - "meta_overwrite", - } - if unknown_keys: - raise KeyError(unknown_keys) - - try: - config["kernel"] = getattr(F, config["kernel"]) - except KeyError: - # no kernel provided - raise - except AttributeError: - # kernel not accessible - raise - - if "pil_kernel" in config and feature_type is not features.Image: - raise TypeError - - for key in ["meta_conversion", "kwargs_overwrite", "meta_overwrite"]: - if key not in config: - config[key] = dict() - continue - - for meta_attr, value in tuple(config[key].items()): - # if meta_attr not in feature_type._META_ATTRS: - # raise KeyError(meta_attr) - - config[key][meta_attr] = maybe_convert_to_enum(value) - - # TODO: bunchify the individual configs - return feature_type_configs - - -def make_kernel_and_implementer_params(feature_type_configs): - kernel_params = {} - implementer_params = {} - for feature_type, config in feature_type_configs.items(): - kernel_params[feature_type] = [ - Parameter.from_regular(param) for param in list(inspect.signature(config["kernel"]).parameters.values())[1:] - ] - implementer_params[feature_type] = [ - Parameter( - name=config["kwargs_overwrite"].get(kernel_param.name, kernel_param.name), - kind=inspect.Parameter.KEYWORD_ONLY, - 
default=kernel_param.default, - annotation=kernel_param.annotation, - ) - for kernel_param in kernel_params[feature_type] - if not config["kwargs_overwrite"].get(kernel_param.name, "").startswith(".") - ] - return kernel_params, implementer_params - - -def make_dispatcher_params(implementer_params): - # not using a set here to keep the order - dispatcher_param_names = [] - for params in implementer_params.values(): - dispatcher_param_names.extend([param.name for param in params]) - dispatcher_param_names = unique(dispatcher_param_names) - - dispatcher_params = [] - need_kwargs_ignore = set() - for name in dispatcher_param_names: - dispatcher_param_candidates = {} - for feature_type, params in implementer_params.items(): - params = {param.name: param for param in params} - if name not in params: - need_kwargs_ignore.add(feature_type) - continue - else: - dispatcher_param_candidates[feature_type] = params[name] - - if len(dispatcher_param_candidates) == 1: - param = next(iter(dispatcher_param_candidates.values())) - if len(implementer_params) == 1: - dispatcher_params.append(copy(param)) - else: - dispatcher_params.append( - Parameter( - name=name, - kind=Parameter.KEYWORD_ONLY, - default=FEATURE_SPECIFIC_PARAM, - annotation=param.annotation, - ) - ) - continue - - annotations = {param.annotation for param in dispatcher_param_candidates.values()} - if len(annotations) > 1: - raise TypeError( - f"Found multiple annotations for parameter `{name}`: " - f"{', '.join([str(annotation) for annotation in annotations])}" - ) - - defaults = {param.default for param in dispatcher_param_candidates.values()} - default = FEATURE_SPECIFIC_DEFAULT if len(defaults) > 1 else defaults.pop() - - dispatcher_params.append( - Parameter( - name=name, - kind=Parameter.KEYWORD_ONLY, - default=default, - annotation=annotations.pop(), - ) - ) - - without_default = [] - with_default = [] - for param in dispatcher_params: - (without_default if param.default in (Parameter.empty, FEATURE_SPECIFIC_PARAM) else with_default).append(param) - dispatcher_params = [*without_default, *with_default] - - for feature_type in need_kwargs_ignore: - implementer_params[feature_type].append(Parameter(name="_", kind=Parameter.VAR_KEYWORD, annotation=Any)) - - return dispatcher_params - - -def make_file_content(functions): - enums = "\n".join(f"from {module.__package__} import {', '.join(names)}" for module, names in ENUMS) - - header = f""" -# THIS FILE IS AUTOGENERATED -# -# FROM torchvision/prototype/transforms/functional/dispatch.yaml -# WITH scripts/regenerate_transforms_dispatch.py -# -# DO NOT CHANGE MANUALLY! - -from typing import Any, TypeVar, List, Optional, Tuple - -import torch -import torchvision.transforms.functional as _F -import torchvision.prototype.transforms.functional as F -from torchvision.prototype import features -{enums} - -Dispatcher = F.utils.Dispatcher - -# This is just a sentinel to have a default argument for a dispatcher if the feature specific implementations use -# different defaults. The actual value is never used. 
-{FEATURE_SPECIFIC_DEFAULT} = object() - -{GENERIC_FEATURE_TYPE} = TypeVar("{GENERIC_FEATURE_TYPE}", bound=features.Feature) -""" - header = "\n".join(line.strip() for line in header.splitlines()) - - __all__ = "\n".join( - ( - "__all__ = [", - *[ - indent(f"{format_value(function.name)},") - for function in functions - if isinstance(function, DispatcherFunction) - ], - "]", - ) - ) - return ( - "\n\n\n".join( - ( - header, - __all__, - *[str(function) for function in functions], - ) - ) - + "\n" - ) - - -class Parameter(inspect.Parameter): - @classmethod - def from_regular(cls, param): - return cls(param.name, param.kind, default=param.default, annotation=param.annotation) - - def __str__(self): - @contextlib.contextmanager - def tmp_override(**tmp_values): - values = {name: getattr(self, name) for name in tmp_values} - for name, tmp_value in tmp_values.items(): - setattr(self, f"_{name}", tmp_value) - try: - yield - finally: - for name, value in values.items(): - setattr(self, f"_{name}", value) - - tmp_values = dict() - - if isinstance(self.default, enum.Enum): - tmp_values["default"] = ManualAnnotation(format_value(self.default)) - - # OPtional only has one - # check docs ther ewas something about checking in the patch notes maybe? - if ( - hasattr(self.annotation, "__origin__") - and self.annotation.__origin__ is typing.Union - and type(None) in self.annotation.__args__ - ): - annotations = [ - inspect.formatannotation(arg) for arg in self.annotation.__args__ if arg is not type(None) # noqa: E721 - ] - tmp_values["annotation"] = ManualAnnotation(f"Optional[{', '.join(annotations)}]") - elif isinstance(self.annotation, enum.EnumMeta): - tmp_values["annotation"] = ManualAnnotation(self.annotation.__name__) - - with tmp_override(**tmp_values): - return super().__str__() - - -class Signature(inspect.Signature): - def __str__(self): - text = super().__str__() - for separator in [FEATURE_SPECIFIC_PARAM, FEATURE_SPECIFIC_DEFAULT]: - parts = text.split(repr(separator)) - text = f"{separator}, # type: ignore[assignment]\n".join( - [ - parts[0], - *[part.lstrip(",") for part in parts[1:]], - ] - ) - return text - - -class Function: - def __init__(self, *, decorator=None, name, signature, docstring=None, body=("pass",)): - self.decorator = decorator - self.name = name - self.signature = signature - self.docstring = docstring - self.body = body - - def __str__(self): - lines = [] - if self.decorator: - lines.append(f"@{self.decorator}") - lines.append(f"def {self.name}{self.signature}:") - if self.docstring: - lines.append(indent('"""' + self.docstring + '"""')) - lines.extend([indent(line) for line in self.body]) - return "\n".join(lines) - - -class DispatcherFunction(Function): - def __init__(self, *, name, params, input_name="input"): - for param in params: - param._kind = Parameter.KEYWORD_ONLY - signature = Signature( - parameters=[ - Parameter( - name=input_name, - kind=Parameter.POSITIONAL_OR_KEYWORD, - annotation=GENERIC_FEATURE_TYPE, - ), - *params, - ], - return_annotation=GENERIC_FEATURE_TYPE, - ) - super().__init__( - decorator="Dispatcher", - name=name, - signature=signature, - docstring="ADDME", - ) - - -class ImplementerFunction(Function): - def __init__( - self, - *, - dispatcher_name, - feature_type, - params, - pil_kernel, - kernel, - kernel_params, - conversion_map, - kernel_param_name_map, - meta_overwrite, - input_name="input", - ): - feature_type_usage = ManualAnnotation(f"features.{feature_type.__name__}") - - body = [] - - output_conversions = [] - for idx, (attr, 
intermediate_value) in enumerate(conversion_map.items()): - - converter = META_CONVERTER_MAP[(feature_type, attr)] - - def make_conversion_call(input, old, new): - return f"F.{converter.__name__}({input}, old_{attr}={old}, new_{attr}={new})" - - input_attr = f"input.{attr}" - intermediate_name = f"intermediate_{attr}" - body.extend( - [ - f"{intermediate_name} = {format_value(intermediate_value)}", - f"converted_input = {make_conversion_call(input_name, input_attr, intermediate_name)}", - "", - ] - ) - if idx == 0: - input_name = "converted_input" - - output_conversions = [f"output = {make_conversion_call('output', intermediate_name, input_attr)}"] - - kernel_call = self._make_kernel_call( - input_name=input_name, - kernel=kernel, - kernel_params=kernel_params, - kernel_param_name_map=kernel_param_name_map, - ) - body.extend( - [ - f"output = {kernel_call}", - *reversed(output_conversions), - "", - ] - ) - - feature_type_wrapper = self._make_feature_type_wrapper( - feature_type_usage=feature_type_usage, - meta_overwrite=meta_overwrite, - ) - body.append(f"return {feature_type_wrapper}") - - super().__init__( - decorator=self._make_decorator( - dispatcher_name=dispatcher_name, - feature_type_usage=feature_type_usage, - pil_kernel=pil_kernel, - ), - name=f"_{dispatcher_name}_{camel_to_snake_case(feature_type.__name__)}", - signature=Signature( - parameters=[ - Parameter( - name="input", - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=feature_type_usage, - ), - *params, - ], - return_annotation=feature_type_usage, - ), - body=body, - ) - - def _make_decorator(self, *, dispatcher_name, feature_type_usage, pil_kernel): - decorator = f"{dispatcher_name}.implements({feature_type_usage}" - if pil_kernel: - decorator += f", pil_kernel=_F.{pil_kernel}" - return f"{decorator})" - - def _make_kernel_call( - self, - *, - kernel, - input_name, - kernel_params, - kernel_param_name_map, - ): - call_args = [input_name] - for param in kernel_params: - dispatcher_param_name = kernel_param_name_map.get(param.name, param.name) - if dispatcher_param_name.startswith("."): - dispatcher_param_name = f"input{dispatcher_param_name}" - call_args.append(f"{param.name}={dispatcher_param_name}") - return f"F.{kernel.__name__}({', '.join(call_args)})" - - def _make_feature_type_wrapper(self, *, feature_type_usage, meta_overwrite): - call_args = ["input", "output"] - call_args.extend( - f"{meta_name}={format_value(dispatcher_param_name)}" - for meta_name, dispatcher_param_name in meta_overwrite.items() - ) - return f"{feature_type_usage}.new_like({', '.join(call_args)})" - - -def ufmt_format(content): - try: - import ufmt - except ModuleNotFoundError: - return content - - HERE = pathlib.Path(__file__).parent - - with open(HERE.parent / ".pre-commit-config.yaml") as file: - repo = next( - repo for repo in yaml.load(file, yaml.Loader)["repos"] for hook in repo["hooks"] if hook["id"] == "ufmt" - ) - - expected_versions = {ufmt: repo["rev"].replace("v", "")} - for dependency in repo["hooks"][0]["additional_dependencies"]: - name, version = [item.strip() for item in dependency.split("==")] - expected_versions[importlib.import_module(name)] = version - - for module, expected_version in expected_versions.items(): - if module.__version__ != expected_version: - warnings.warn("foo") - - from ufmt.core import make_black_config - from usort.config import Config as UsortConfig - - black_config = make_black_config(HERE) - usort_config = UsortConfig.find(HERE) - - return ufmt.ufmt_string(path=HERE, content=content, 
usort_config=usort_config, black_config=black_config) - - -def maybe_convert_to_enum(value): - if not isinstance(value, str): - return value - - parts = value.split(".") - if len(parts) != 2: - return value - - enum, member = parts - - try: - return ENUMS_MAP[enum][member] - except KeyError: - return value - - -def indent(text, level=1): - return "\n".join(" " * (level * 4) + line for line in text.splitlines()) - - -def camel_to_snake_case(camel_case: str) -> str: - return "_".join([part.lower() for part in re.findall("[A-Z][^A-Z]*", camel_case)]) - - -def format_value(value): - if isinstance(value, str): - return f'"{value}"' - elif isinstance(value, enum.Enum): - return f"{type(value).__name__}.{value.name}" - else: - return repr(value) - - -def unique(seq): - unique_seq = [] - for item in seq: - if item not in unique_seq: - unique_seq.append(item) - return unique_seq - - -if __name__ == "__main__": - try: - with open(pathlib.Path(F.__path__[0]) / "dispatch.yaml") as file: - dispatch_config = yaml.load(file, yaml.Loader) - content = main(dispatch_config) - with open(pathlib.Path(F.__path__[0]) / "_dispatch.py", "w") as file: - file.write(content) - except Exception as error: - msg = str(error) - print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) - sys.exit(1) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 37a8096ff7a..5ef198492c6 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,31 +1,54 @@ from . import utils # usort: skip -from ._augment import erase_image, mixup_image, mixup_one_hot_label, cutmix_image, cutmix_one_hot_label +from torchvision.transforms import InterpolationMode + +from ._augment import ( + erase_image, + erase, + mixup_image, + mixup_one_hot_label, + mixup, + cutmix_image, + cutmix_one_hot_label, + cutmix, +) from ._color import ( adjust_brightness_image, + adjust_brightness, adjust_contrast_image, + adjust_contrast, adjust_saturation_image, + adjust_saturation, adjust_sharpness_image, + adjust_sharpness, posterize_image, + posterize, solarize_image, + solarize, autocontrast_image, + autocontrast, equalize_image, + equalize, invert_image, + invert, ) from ._geometry import ( horizontal_flip_bounding_box, horizontal_flip_image, + horizontal_flip, resize_bounding_box, resize_image, resize_segmentation_mask, + resize, center_crop_image, + center_crop, resized_crop_image, - InterpolationMode, + resized_crop, affine_image, + affine, rotate_image, + rotate, ) from ._meta_conversion import convert_color_space, convert_bounding_box_format -from ._misc import normalize_image +from ._misc import normalize_image, normalize from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot - -from ._dispatch import * # usort: skip diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 842ff0cd5d6..0258ab437a6 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,10 +1,29 @@ from typing import Tuple +from typing import TypeVar import torch +from torchvision.prototype import features from torchvision.transforms import functional as _F +from .utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +@dispatch +def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, 
inplace: bool = False) -> T: + """ADDME""" + pass + erase_image = _F.erase +erase.register(features.Image, erase_image) + + +@dispatch +def mixup(input: T, *, lam: float, inplace: bool = False) -> T: + """ADDME""" + pass def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: @@ -22,6 +41,9 @@ def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) return _mixup(image_batch, -4, lam, inplace) +mixup.register(features.Image, mixup_image) + + def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") @@ -29,6 +51,21 @@ def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplac return _mixup(one_hot_label_batch, -2, lam, inplace) +mixup.register(features.OneHotLabel, mixup_one_hot_label) + + +@dispatch +def cutmix( + input: T, + *, + box: Tuple[int, int, int, int] = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + lam_adjusted: float = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + inplace: bool = False, +) -> T: + """ADDME""" + pass + + def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") @@ -43,6 +80,9 @@ def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], i return image_batch +cutmix.register(features.Image, cutmix_image) + + def cutmix_one_hot_label( one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False ) -> torch.Tensor: @@ -50,3 +90,6 @@ def cutmix_one_hot_label( raise ValueError("Need a batch of one hot labels") return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) + + +cutmix.register(features.OneHotLabel, cutmix_one_hot_label) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f2529166d9a..2401391bcb5 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,20 +1,99 @@ +from typing import TypeVar + +from torchvision.prototype import features from torchvision.transforms import functional as _F +from .utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +@dispatch +def adjust_brightness(input: T, *, brightness_factor: float) -> T: + """ADDME""" + pass + adjust_brightness_image = _F.adjust_brightness +adjust_brightness.register(features.Image, adjust_brightness_image) + + +@dispatch +def adjust_saturation(input: T, *, saturation_factor: float) -> T: + """ADDME""" + pass + + adjust_saturation_image = _F.adjust_saturation +adjust_saturation.register(features.Image, adjust_saturation_image) + + +@dispatch +def adjust_contrast(input: T, *, contrast_factor: float) -> T: + """ADDME""" + pass + adjust_contrast_image = _F.adjust_contrast +adjust_contrast.register(features.Image, adjust_contrast_image) + + +@dispatch +def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: + """ADDME""" + pass + adjust_sharpness_image = _F.adjust_sharpness +adjust_sharpness.register(features.Image, adjust_sharpness_image) + + +@dispatch +def posterize(input: T, *, bits: int) -> T: + """ADDME""" + pass + posterize_image = _F.posterize +posterize.register(features.Image, posterize_image) + + +@dispatch +def solarize(input: T, *, threshold: float) -> T: + """ADDME""" + pass + solarize_image = 
_F.solarize +solarize.register(features.Image, solarize_image) + + +@dispatch +def autocontrast(input: T) -> T: + """ADDME""" + pass + autocontrast_image = _F.autocontrast +autocontrast.register(features.Image, autocontrast_image) + + +@dispatch +def equalize(input: T) -> T: + """ADDME""" + pass + equalize_image = _F.equalize +equalize.register(features.Image, equalize_image) + + +@dispatch +def invert(input: T) -> T: + """ADDME""" + pass + invert_image = _F.invert +invert.register(features.Image, invert_image) diff --git a/torchvision/prototype/transforms/functional/_dispatch.py b/torchvision/prototype/transforms/functional/_dispatch.py deleted file mode 100644 index 2febe232091..00000000000 --- a/torchvision/prototype/transforms/functional/_dispatch.py +++ /dev/null @@ -1,448 +0,0 @@ -# THIS FILE IS AUTOGENERATED -# -# FROM torchvision/prototype/transforms/functional/dispatch.yaml -# WITH scripts/regenerate_transforms_dispatch.py -# -# DO NOT CHANGE MANUALLY! - -from typing import Any, TypeVar, List, Optional, Tuple - -import torch -import torchvision.prototype.transforms.functional as F -import torchvision.transforms.functional as _F -from torchvision.prototype import features -from torchvision.prototype.features import BoundingBoxFormat, ColorSpace -from torchvision.transforms import InterpolationMode - -Dispatcher = F.utils.Dispatcher - -# This is just a sentinel to have a default argument for a dispatcher if the feature specific implementations use -# different defaults. The actual value is never used. -FEATURE_SPECIFIC_DEFAULT = object() - -T = TypeVar("T", bound=features.Feature) - - -__all__ = [ - "horizontal_flip", - "resize", - "center_crop", - "normalize", - "resized_crop", - "erase", - "mixup", - "cutmix", - "affine", - "rotate", - "adjust_brightness", - "adjust_saturation", - "adjust_contrast", - "adjust_sharpness", - "posterize", - "solarize", - "autocontrast", - "equalize", - "invert", -] - - -@Dispatcher -def horizontal_flip(input: T) -> T: - """ADDME""" - pass - - -@horizontal_flip.implements(features.Image) -def _horizontal_flip_image(input: features.Image) -> features.Image: - output = F.horizontal_flip_image(input) - - return features.Image.new_like(input, output) - - -@horizontal_flip.implements(features.BoundingBox) -def _horizontal_flip_bounding_box(input: features.BoundingBox) -> features.BoundingBox: - intermediate_format = BoundingBoxFormat.XYXY - converted_input = F.convert_bounding_box_format(input, old_format=input.format, new_format=intermediate_format) - - output = F.horizontal_flip_bounding_box(converted_input, image_size=input.image_size) - output = F.convert_bounding_box_format(output, old_format=intermediate_format, new_format=input.format) - - return features.BoundingBox.new_like(input, output) - - -@Dispatcher -def resize( - input: T, - *, - size: List[int], - interpolation: InterpolationMode = FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> T: - """ADDME""" - pass - - -@resize.implements(features.Image, pil_kernel=_F.resize) -def _resize_image( - input: features.Image, - *, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> features.Image: - output = F.resize_image(input, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) - - return features.Image.new_like(input, output) - - -@resize.implements(features.BoundingBox) -def 
_resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: - intermediate_format = BoundingBoxFormat.XYXY - converted_input = F.convert_bounding_box_format(input, old_format=input.format, new_format=intermediate_format) - - output = F.resize_bounding_box(converted_input, old_image_size=input.image_size, new_image_size=size) - output = F.convert_bounding_box_format(output, old_format=intermediate_format, new_format=input.format) - - return features.BoundingBox.new_like(input, output, image_size="size") - - -@resize.implements(features.SegmentationMask) -def _resize_segmentation_mask( - input: features.SegmentationMask, - *, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> features.SegmentationMask: - output = F.resize_segmentation_mask( - input, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias - ) - - return features.SegmentationMask.new_like(input, output) - - -@Dispatcher -def center_crop(input: T, *, output_size: List[int]) -> T: - """ADDME""" - pass - - -@center_crop.implements(features.Image, pil_kernel=_F.center_crop) -def _center_crop_image(input: features.Image, *, output_size: List[int]) -> features.Image: - output = F.center_crop(input, output_size=output_size) - - return features.Image.new_like(input, output) - - -@Dispatcher -def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: - """ADDME""" - pass - - -@normalize.implements(features.Image) -def _normalize_image( - input: features.Image, *, mean: List[float], std: List[float], inplace: bool = False -) -> features.Image: - output = F.normalize(input, mean=mean, std=std, inplace=inplace) - - return features.Image.new_like(input, output, color_space=ColorSpace.OTHER) - - -@Dispatcher -def resized_crop( - input: T, - *, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, -) -> T: - """ADDME""" - pass - - -@resized_crop.implements(features.Image, pil_kernel=_F.resized_crop) -def _resized_crop_image( - input: features.Image, - *, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, -) -> features.Image: - output = F.resized_crop( - input, top=top, left=left, height=height, width=width, size=size, interpolation=interpolation - ) - - return features.Image.new_like(input, output) - - -@Dispatcher -def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: - """ADDME""" - pass - - -@erase.implements(features.Image) -def _erase_image( - input: features.Image, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False -) -> features.Image: - output = F.erase(input, i=i, j=j, h=h, w=w, v=v, inplace=inplace) - - return features.Image.new_like(input, output) - - -@Dispatcher -def mixup(input: T, *, lam: float, inplace: bool = False) -> T: - """ADDME""" - pass - - -@mixup.implements(features.Image) -def _mixup_image(input: features.Image, *, lam: float, inplace: bool = False) -> features.Image: - output = F.mixup_image(input, lam=lam, inplace=inplace) - - return features.Image.new_like(input, output) - - -@mixup.implements(features.OneHotLabel) -def _mixup_one_hot_label(input: features.OneHotLabel, *, lam: float, inplace: bool = False) -> features.OneHotLabel: - output = F.mixup_one_hot_label(input, 
lam=lam, inplace=inplace) - - return features.OneHotLabel.new_like(input, output) - - -@Dispatcher -def cutmix( - input: T, - *, - box: Tuple[int, int, int, int] = Dispatcher.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - lam_adjusted: float = Dispatcher.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - inplace: bool = False, -) -> T: - """ADDME""" - pass - - -@cutmix.implements(features.Image) -def _cutmix_image( - input: features.Image, *, box: Tuple[int, int, int, int], inplace: bool = False, **_: Any -) -> features.Image: - output = F.cutmix_image(input, box=box, inplace=inplace) - - return features.Image.new_like(input, output) - - -@cutmix.implements(features.OneHotLabel) -def _cutmix_one_hot_label( - input: features.OneHotLabel, *, lam_adjusted: float, inplace: bool = False, **_: Any -) -> features.OneHotLabel: - output = F.cutmix_one_hot_label(input, lam_adjusted=lam_adjusted, inplace=inplace) - - return features.OneHotLabel.new_like(input, output) - - -@Dispatcher -def affine( - input: T, - *, - angle: float, - translate: List[int], - scale: float, - shear: List[float], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, - fillcolor: Optional[List[float]] = None, - center: Optional[List[int]] = None, -) -> T: - """ADDME""" - pass - - -@affine.implements(features.Image, pil_kernel=_F.affine) -def _affine_image( - input: features.Image, - *, - angle: float, - translate: List[int], - scale: float, - shear: List[float], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, - fillcolor: Optional[List[float]] = None, - center: Optional[List[int]] = None, -) -> features.Image: - output = F.affine( - input, - angle=angle, - translate=translate, - scale=scale, - shear=shear, - interpolation=interpolation, - fill=fill, - resample=resample, - fillcolor=fillcolor, - center=center, - ) - - return features.Image.new_like(input, output) - - -@Dispatcher -def rotate( - input: T, - *, - angle: float, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[int]] = None, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, -) -> T: - """ADDME""" - pass - - -@rotate.implements(features.Image, pil_kernel=_F.rotate) -def _rotate_image( - input: features.Image, - *, - angle: float, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[int]] = None, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, -) -> features.Image: - output = F.rotate( - input, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill, resample=resample - ) - - return features.Image.new_like(input, output) - - -@Dispatcher -def adjust_brightness(input: T, *, brightness_factor: float) -> T: - """ADDME""" - pass - - -@adjust_brightness.implements(features.Image, pil_kernel=_F.adjust_brightness) -def _adjust_brightness_image(input: features.Image, *, brightness_factor: float) -> features.Image: - output = F.adjust_brightness(input, brightness_factor=brightness_factor) - - return features.Image.new_like(input, output) - - -@Dispatcher -def adjust_saturation(input: T, *, saturation_factor: float) -> T: - """ADDME""" - pass - - -@adjust_saturation.implements(features.Image, pil_kernel=_F.adjust_saturation) -def _adjust_saturation_image(input: features.Image, *, saturation_factor: float) -> 
features.Image: - output = F.adjust_saturation(input, saturation_factor=saturation_factor) - - return features.Image.new_like(input, output) - - -@Dispatcher -def adjust_contrast(input: T, *, contrast_factor: float) -> T: - """ADDME""" - pass - - -@adjust_contrast.implements(features.Image, pil_kernel=_F.adjust_contrast) -def _adjust_contrast_image(input: features.Image, *, contrast_factor: float) -> features.Image: - output = F.adjust_contrast(input, contrast_factor=contrast_factor) - - return features.Image.new_like(input, output) - - -@Dispatcher -def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: - """ADDME""" - pass - - -@adjust_sharpness.implements(features.Image, pil_kernel=_F.adjust_sharpness) -def _adjust_sharpness_image(input: features.Image, *, sharpness_factor: float) -> features.Image: - output = F.adjust_sharpness(input, sharpness_factor=sharpness_factor) - - return features.Image.new_like(input, output) - - -@Dispatcher -def posterize(input: T, *, bits: int) -> T: - """ADDME""" - pass - - -@posterize.implements(features.Image, pil_kernel=_F.posterize) -def _posterize_image(input: features.Image, *, bits: int) -> features.Image: - output = F.posterize(input, bits=bits) - - return features.Image.new_like(input, output) - - -@Dispatcher -def solarize(input: T, *, threshold: float) -> T: - """ADDME""" - pass - - -@solarize.implements(features.Image, pil_kernel=_F.solarize) -def _solarize_image(input: features.Image, *, threshold: float) -> features.Image: - output = F.solarize(input, threshold=threshold) - - return features.Image.new_like(input, output) - - -@Dispatcher -def autocontrast(input: T) -> T: - """ADDME""" - pass - - -@autocontrast.implements(features.Image, pil_kernel=_F.autocontrast) -def _autocontrast_image(input: features.Image) -> features.Image: - output = F.autocontrast(input) - - return features.Image.new_like(input, output) - - -@Dispatcher -def equalize(input: T) -> T: - """ADDME""" - pass - - -@equalize.implements(features.Image, pil_kernel=_F.equalize) -def _equalize_image(input: features.Image) -> features.Image: - output = F.equalize(input) - - return features.Image.new_like(input, output) - - -@Dispatcher -def invert(input: T) -> T: - """ADDME""" - pass - - -@invert.implements(features.Image, pil_kernel=_F.invert) -def _invert_image(input: features.Image) -> features.Image: - output = F.invert(input) - - return features.Image.new_like(input, output) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 65a6051367c..c2e7f033aba 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,18 +1,51 @@ -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, TypeVar, Any import torch -from torchvision.transforms import ( # noqa: F401 - functional as _F, - InterpolationMode, -) +from torchvision.prototype import features +from torchvision.transforms import functional as _F, InterpolationMode + +from ._meta_conversion import convert_bounding_box_format +from .utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +@dispatch +def horizontal_flip(input: T) -> T: + """ADDME""" + pass + horizontal_flip_image = _F.hflip +horizontal_flip.register(features.Image, horizontal_flip_image) -def horizontal_flip_bounding_box(bounding_box: torch.Tensor, *, image_size: Tuple[int, int]) -> torch.Tensor: - bounding_box = bounding_box.clone() +def 
horizontal_flip_bounding_box( + bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY, copy=True + ) bounding_box[..., (0, 2)] = image_size[1] - bounding_box[..., (2, 0)] - return bounding_box + return convert_bounding_box_format(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format) + + +@horizontal_flip.implements(features.BoundingBox) +def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: + return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + + +@dispatch +def resize( + input: T, + *, + size: List[int], + interpolation: InterpolationMode = dispatch.FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> T: + """ADDME""" + pass def resize_image( @@ -34,6 +67,9 @@ def resize_image( ).reshape(batch_shape + (num_channels, new_height, new_width)) +resize.register(features.Image, resize_image, pil_kernel=_F.resize) + + def resize_segmentation_mask( segmentation_mask: torch.Tensor, size: List[int], @@ -46,12 +82,12 @@ def resize_segmentation_mask( ) +resize.register(features.SegmentationMask, resize_segmentation_mask) + + # TODO: handle max_size def resize_bounding_box( - bounding_box: torch.Tensor, - *, - old_image_size: List[int], - new_image_size: List[int], + bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int] ) -> torch.Tensor: old_height, old_width = old_image_size new_height, new_width = new_image_size @@ -59,10 +95,77 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) +@resize.implements(features.BoundingBox, wrap_output=False) +def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: + output = resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) + return features.BoundingBox.new_like(input, output, image_size=size) + + +@dispatch +def center_crop(input: T, *, output_size: List[int]) -> T: + """ADDME""" + pass + + center_crop_image = _F.center_crop +center_crop.register(features.Image, center_crop_image) + + +@dispatch +def resized_crop( + input: T, + *, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, +) -> T: + """ADDME""" + pass + resized_crop_image = _F.resized_crop +resized_crop.register(features.Image, resized_crop_image) + + +@dispatch +def affine( + input: T, + *, + angle: float, + translate: List[int], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, + fillcolor: Optional[List[float]] = None, + center: Optional[List[int]] = None, +) -> T: + """ADDME""" + pass + affine_image = _F.affine +affine.register(features.Image, affine_image) + + +@dispatch +def rotate( + input: T, + *, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + center: Optional[List[int]] = None, + fill: Optional[List[float]] = None, + resample: Optional[int] = None, +) -> T: + """ADDME""" + pass + rotate_image = _F.rotate +rotate.register(features.Image, rotate_image) diff --git 
a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/functional/_meta_conversion.py index 484066a39ee..a351d03aea2 100644 --- a/torchvision/prototype/transforms/functional/_meta_conversion.py +++ b/torchvision/prototype/transforms/functional/_meta_conversion.py @@ -34,10 +34,13 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor: def convert_bounding_box_format( - bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat + bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = False ) -> torch.Tensor: if new_format == old_format: - return bounding_box + if copy: + return bounding_box.clone() + else: + return bounding_box if old_format == BoundingBoxFormat.XYWH: bounding_box = _xywh_to_xyxy(bounding_box) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index de148ab194a..721d41dbbd2 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,4 +1,19 @@ +from typing import List +from typing import TypeVar + +from torchvision.prototype import features from torchvision.transforms import functional as _F +from .utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +@dispatch +def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: + """ADDME""" + pass + normalize_image = _F.normalize +normalize.register(features.Image, normalize_image) diff --git a/torchvision/prototype/transforms/functional/dispatch.yaml b/torchvision/prototype/transforms/functional/dispatch.yaml deleted file mode 100644 index 6c963ad840b..00000000000 --- a/torchvision/prototype/transforms/functional/dispatch.yaml +++ /dev/null @@ -1,162 +0,0 @@ -# This is the configuration file to auto-generate the dispatch of feature type specific kernels of the same transform. -# The auto-generation is general enough to apply to all dispatchers that -# 1. take a feature as first input that is solely used to determine the required kernel and -# 2. are feature type preserving. -# -# After you have implemented new kernels, follow the schema explanation below to configure the auto-generation. -# Finally, run `$ python scripts/regenerate_transform_dispatch.py` to regenerate the dispatch. -# -# This configuration uses the following schema: -# -# $DISPATCHER_NAME: -# $FEATURE_NAME: -# kernel: $KERNEL_NAME -# meta_conversion: $META_CONVERSION -# kwargs_overwrite: $KWARGS_OVERWRITE -# meta_overwrite: $META_OVERWRITE -# ... -# -# KERNEL_NAME: Name of the kernel that can be accessed from `torchvision.prototype.transforms.functional`. If the -# canonical naming scheme is followed, this should be $DISPATCHER_NAME followed by $FEATURE_NAME in snake case. -# -# META_CONVERSION: Optional mapping of meta attributes to convert before the kernel call. This is needed if the kernel -# makes assumptions on meta attributes of the input tensor. For example, the `resize_bounding_box` kernel requires -# the input to be in XYXY format. This can be achieved with: -# -# resize: -# BoundingBox: -# kernel: resize_bounding_box -# meta_conversion: -# format: BoundingBoxFormat.XYXY -# -# KWARGS_OVERWRITE: Optional mapping of keyword arguments to apply to the keyword arguments of the dispatcher before -# they get passed to the kernel. If a value is prefixed with a dot, this meta attribute of the input feature is added -# under the given key. 
This is useful in two scenarios: -# -# 1. The kernel requires meta information that is available as meta attribute on the input feature or -# 2. the kernel parameters names differ from the ones of the dispatcher. -# -# For example, the `resize_bounding_box` kernel takes the `old_image_size` and `new_image_size` parameters, whereas -# the `resize` dispatcher should only take `size`. Thus, we need to take `image_size` stored on the `BoundingBox` -# feature and map it to `old_image_size` as well as mapping the `size` input to `new_image_size`. -# -# resize: -# BoundingBox: -# kernel: resize_bounding_box -# kwargs_overwrite: -# old_image_size: .image_size -# new_image_size: size -# -# META_OVERWRITE: Optional mapping of keyword arguments to overwrite meta attributes when creating the new feature from -# the kernel output. This is needed if the kernel changes a meta attribute. For example, the `resize_bounding_box` -# kernel adapts the values, but the new image size cannot be inferred from them. Thus, we set the `image_size` -# metadata of the new bounding box to be the `size` the dispatcher was called with: -# -# resize: -# BoundingBox: -# kernel: resize_bounding_box -# meta_overwrite: -# image_size: size -# -# For $FEATURE_NAME == Image, there is also the `pil_kernel` key available. This is the name of the legacy `PIL` kernel -# that can be accessed from `torchvision.transforms.functional` (note the missing `.prototype` compared to -# $KERNEL_NAME). It will be called with same arguments as the dispatcher. -# -# If no optional configuration is required, the shortcut -# -# $DISPATCHER_NAME: -# $FEATURE_NAME: $KERNEL_NAME -# -# can be used, which is equivalent to -# -# $DISPATCHER_NAME: -# $FEATURE_NAME: -# kernel: $KERNEL_NAME - -horizontal_flip: - Image: horizontal_flip_image - BoundingBox: - kernel: horizontal_flip_bounding_box - meta_conversion: - format: BoundingBoxFormat.XYXY - kwargs_overwrite: - image_size: .image_size -resize: - Image: - kernel: resize_image - pil_kernel: resize - BoundingBox: - kernel: resize_bounding_box - meta_conversion: - format: BoundingBoxFormat.XYXY - kwargs_overwrite: - old_image_size: .image_size - new_image_size: size - meta_overwrite: - image_size: size - SegmentationMask: resize_segmentation_mask -center_crop: - Image: - kernel: center_crop_image - pil_kernel: center_crop -normalize: - Image: - kernel: normalize_image - meta_overwrite: - color_space: ColorSpace.OTHER -resized_crop: - Image: - kernel: resized_crop_image - pil_kernel: resized_crop -erase: - Image: erase_image -mixup: - Image: mixup_image - OneHotLabel: mixup_one_hot_label -cutmix: - Image: cutmix_image - OneHotLabel: cutmix_one_hot_label -affine: - Image: - kernel: affine_image - pil_kernel: affine -rotate: - Image: - kernel: rotate_image - pil_kernel: rotate -adjust_brightness: - Image: - kernel: adjust_brightness_image - pil_kernel: adjust_brightness -adjust_saturation: - Image: - kernel: adjust_saturation_image - pil_kernel: adjust_saturation -adjust_contrast: - Image: - kernel: adjust_contrast_image - pil_kernel: adjust_contrast -adjust_sharpness: - Image: - kernel: adjust_sharpness_image - pil_kernel: adjust_sharpness -posterize: - Image: - kernel: posterize_image - pil_kernel: posterize -solarize: - Image: - kernel: solarize_image - pil_kernel: solarize -autocontrast: - Image: - kernel: autocontrast_image - pil_kernel: autocontrast -equalize: - Image: - kernel: equalize_image - pil_kernel: equalize -invert: - Image: - kernel: invert_image - pil_kernel: invert diff --git 
a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 7d0620a4a36..e99ece155c8 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -13,63 +13,74 @@ def is_supported(obj: Any, *types: Type) -> bool: return (obj if isinstance(obj, type) else type(obj)) in types -class Dispatcher: +class dispatch: FEATURE_SPECIFIC_PARAM = object() + FEATURE_SPECIFIC_DEFAULT = object() def __init__(self, dispatch_fn): self._dispatch_fn = dispatch_fn - self._kernels = {} - self._pil_kernel: Optional[Callable] = None + self.__doc__ = dispatch_fn.__doc__ + self.__signature__ = inspect.signature(dispatch_fn) + + self._fns = {} + self._pil_fn: Optional[Callable] = None def supports(self, obj: Any) -> bool: - return is_supported(obj, *self._kernels.keys()) + return is_supported(obj, *self._fns.keys()) - def implements(self, feature_type, *, pil_kernel=None): + def register(self, feature_type, fn, *, wrap_output: bool = True, pil_kernel=None) -> None: if pil_kernel is not None: if not issubclass(feature_type, features.Image): raise TypeError("PIL kernel can only be registered for images") - self._pil_kernel = pil_kernel - - def outer_wrapper(implement_fn): - implement_params = inspect.signature(implement_fn).parameters - feature_specific_params = [ - name - for name, param in inspect.signature(self._dispatch_fn).parameters.items() - if param.default is self.FEATURE_SPECIFIC_PARAM - and name in implement_params - and implement_params[name] is inspect.Parameter.empty + self._pil_fn = pil_kernel + + params = inspect.signature(fn).parameters + feature_specific_params = [ + name + for name, param in self.__signature__.parameters.items() + if param.default is self.FEATURE_SPECIFIC_PARAM + and name in params + and params[name].default is inspect.Parameter.empty + ] + + @functools.wraps(fn) + def wrapper(input, *args, **kwargs) -> Any: + missing = [ + param + for param in feature_specific_params + if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM ] + if missing: + raise TypeError( + f"{self._dispatch_fn.__name__}() missing {len(missing)} required keyword-only arguments " + f"for feature type {feature_type.__name__}: {sequence_to_str(missing, separate_last='and ')}" + ) + + output = fn(input, *args, **kwargs) - @functools.wraps(implement_fn) - def inner_wrapper(*args, **kwargs) -> Any: - missing = [ - param - for param in feature_specific_params - if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM - ] - if missing: - raise TypeError( - f"{self._dispatch_fn.__name__}() missing {len(missing)} required keyword-only arguments " - f"for feature type {feature_type.__name__}: {sequence_to_str(missing, separate_last='and ')}" - ) + if wrap_output: + output = feature_type.new_like(input, output) - return implement_fn(*args, **kwargs) + return output - self._kernels[feature_type] = inner_wrapper + self._fns[feature_type] = wrapper - return inner_wrapper + def implements(self, feature_type, *, wrap_output=False, pil_kernel=None): + def wrapper(fn): + self.register(feature_type, fn, wrap_output=wrap_output, pil_kernel=pil_kernel) + return fn - return outer_wrapper + return wrapper def __call__(self, input, *args, **kwargs): feature_type = type(input) if issubclass(feature_type, PIL.Image.Image): - if self._pil_kernel is None: + if self._pil_fn is None: raise TypeError("No PIL kernel") - return self._pil_kernel(input, *args, **kwargs) + return 
self._pil_fn(input, *args, **kwargs) elif not issubclass(feature_type, torch.Tensor): raise TypeError("No tensor") @@ -79,4 +90,4 @@ def __call__(self, input, *args, **kwargs): if not self.supports(feature_type): raise ValueError(f"No support for {feature_type.__name__}") - return self._kernels[feature_type](input, *args, **kwargs) + return self._fns[feature_type](input, *args, **kwargs) From 3ceb056d4c4a2d79ccdc65cbd2342b0e89a7ae47 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Feb 2022 15:29:53 +0100 Subject: [PATCH 09/32] revert unrelated changes --- references/detection/coco_utils.py | 9 +++------ references/detection/train.py | 11 +---------- torchvision/models/detection/keypoint_rcnn.py | 3 --- torchvision/prototype/features/_feature.py | 14 -------------- 4 files changed, 4 insertions(+), 33 deletions(-) diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py index b9143f47034..b0f193135ee 100644 --- a/references/detection/coco_utils.py +++ b/references/detection/coco_utils.py @@ -53,13 +53,12 @@ def __call__(self, image, target): anno = target["annotations"] - # drop all crowd annotations anno = [obj for obj in anno if obj["iscrowd"] == 0] boxes = [obj["bbox"] for obj in anno] # guard against no boxes via resizing boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) - boxes[:, 2:] += boxes[:, :2] # xywh -> xyxy + boxes[:, 2:] += boxes[:, :2] boxes[:, 0::2].clamp_(min=0, max=w) boxes[:, 1::2].clamp_(min=0, max=h) @@ -77,7 +76,7 @@ def __call__(self, image, target): if num_keypoints: keypoints = keypoints.view(num_keypoints, -1, 3) - keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) # valid boxes + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) boxes = boxes[keep] classes = classes[keep] masks = masks[keep] @@ -94,9 +93,7 @@ def __call__(self, image, target): # for conversion to coco api area = torch.tensor([obj["area"] for obj in anno]) - iscrowd = torch.tensor( - [obj["iscrowd"] for obj in anno] - ) # this makes little sense, since we already exlcuded them at the top + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) target["area"] = area target["iscrowd"] = iscrowd diff --git a/references/detection/train.py b/references/detection/train.py index 6a81d7fa332..765f8144364 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -177,8 +177,6 @@ def main(args): dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) - dataset[0] - print("Creating data loaders") if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) @@ -278,12 +276,5 @@ def main(args): if __name__ == "__main__": - args = get_args_parser().parse_args( - [ - "--data-path=/home/philip/datasets/coco", - "--device=cpu", - "--dataset=coco_kp", - "--model=keypointrcnn_resnet50_fpn", - ] - ) + args = get_args_parser().parse_args() main(args) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index c58b72217c9..93e966bae4b 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -253,9 +253,6 @@ def __init__( self.roi_heads.keypoint_head = keypoint_head self.roi_heads.keypoint_predictor = keypoint_predictor - def forward(self, *args, **kwargs): - return super().forward(*args, **kwargs) - class 
KeypointRCNNHeads(nn.Sequential): def __init__(self, in_channels, layers): diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index d4948048cfe..436ef984ba1 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -57,20 +57,6 @@ def new_like(cls, other, data, *, dtype=None, device=None, **metadata): metadata.setdefault(name, getattr(other, name)) return cls(data, dtype=dtype or other.dtype, device=device or other.device, **metadata) - _TORCH_FUNCTION_ALLOW_MAP = { - torch.Tensor.clone: (0,), - torch.stack: (0, 0), - torch.Tensor.to: (0,), - } - - _DTYPE_CONVERTERS = { - torch.Tensor.to, - } - - _DEVICE_CONVERTERS = { - torch.Tensor.to, - } - @classmethod def __torch_function__( cls, From b3cbfcae10841462ca8a99fcccf324dde551f60e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 09:14:52 +0100 Subject: [PATCH 10/32] remove implements decorator --- torchvision/prototype/transforms/functional/_geometry.py | 8 ++++++-- torchvision/prototype/transforms/functional/utils.py | 7 ------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c2e7f033aba..6e7047d0dad 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -30,11 +30,13 @@ def horizontal_flip_bounding_box( return convert_bounding_box_format(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format) -@horizontal_flip.implements(features.BoundingBox) def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) +horizontal_flip.register(features.BoundingBox, _horizontal_flip_bounding_box) + + @dispatch def resize( input: T, @@ -95,12 +97,14 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -@resize.implements(features.BoundingBox, wrap_output=False) def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: output = resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) return features.BoundingBox.new_like(input, output, image_size=size) +resize.register(features.BoundingBox, _resize_bounding_box, wrap_output=False) + + @dispatch def center_crop(input: T, *, output_size: List[int]) -> T: """ADDME""" diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index e99ece155c8..096d9210976 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -66,13 +66,6 @@ def wrapper(input, *args, **kwargs) -> Any: self._fns[feature_type] = wrapper - def implements(self, feature_type, *, wrap_output=False, pil_kernel=None): - def wrapper(fn): - self.register(feature_type, fn, wrap_output=wrap_output, pil_kernel=pil_kernel) - return fn - - return wrapper - def __call__(self, input, *args, **kwargs): feature_type = type(input) From 2a8345a2c4a84e7e6db4de2559f9f571eb322ace Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 09:15:58 +0100 Subject: [PATCH 11/32] change register parameter order --- .../transforms/functional/_augment.py | 10 +++++----- .../prototype/transforms/functional/_color.py | 18 +++++++++--------- 
.../transforms/functional/_geometry.py | 18 +++++++++--------- .../prototype/transforms/functional/_misc.py | 2 +- .../prototype/transforms/functional/utils.py | 2 +- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 0258ab437a6..3805b195c57 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -17,7 +17,7 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: erase_image = _F.erase -erase.register(features.Image, erase_image) +erase.register(erase_image, features.Image) @dispatch @@ -41,7 +41,7 @@ def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) return _mixup(image_batch, -4, lam, inplace) -mixup.register(features.Image, mixup_image) +mixup.register(mixup_image, features.Image) def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: @@ -51,7 +51,7 @@ def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplac return _mixup(one_hot_label_batch, -2, lam, inplace) -mixup.register(features.OneHotLabel, mixup_one_hot_label) +mixup.register(mixup_one_hot_label, features.OneHotLabel) @dispatch @@ -80,7 +80,7 @@ def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], i return image_batch -cutmix.register(features.Image, cutmix_image) +cutmix.register(cutmix_image, features.Image) def cutmix_one_hot_label( @@ -92,4 +92,4 @@ def cutmix_one_hot_label( return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) -cutmix.register(features.OneHotLabel, cutmix_one_hot_label) +cutmix.register(cutmix_one_hot_label, features.OneHotLabel) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 2401391bcb5..44e9a0554d4 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -16,7 +16,7 @@ def adjust_brightness(input: T, *, brightness_factor: float) -> T: adjust_brightness_image = _F.adjust_brightness -adjust_brightness.register(features.Image, adjust_brightness_image) +adjust_brightness.register(adjust_brightness_image, features.Image) @dispatch @@ -26,7 +26,7 @@ def adjust_saturation(input: T, *, saturation_factor: float) -> T: adjust_saturation_image = _F.adjust_saturation -adjust_saturation.register(features.Image, adjust_saturation_image) +adjust_saturation.register(adjust_saturation_image, features.Image) @dispatch @@ -36,7 +36,7 @@ def adjust_contrast(input: T, *, contrast_factor: float) -> T: adjust_contrast_image = _F.adjust_contrast -adjust_contrast.register(features.Image, adjust_contrast_image) +adjust_contrast.register(adjust_contrast_image, features.Image) @dispatch @@ -46,7 +46,7 @@ def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: adjust_sharpness_image = _F.adjust_sharpness -adjust_sharpness.register(features.Image, adjust_sharpness_image) +adjust_sharpness.register(adjust_sharpness_image, features.Image) @dispatch @@ -56,7 +56,7 @@ def posterize(input: T, *, bits: int) -> T: posterize_image = _F.posterize -posterize.register(features.Image, posterize_image) +posterize.register(posterize_image, features.Image) @dispatch @@ -66,7 +66,7 @@ def solarize(input: T, *, threshold: float) -> T: solarize_image = _F.solarize -solarize.register(features.Image, 
solarize_image) +solarize.register(solarize_image, features.Image) @dispatch @@ -76,7 +76,7 @@ def autocontrast(input: T) -> T: autocontrast_image = _F.autocontrast -autocontrast.register(features.Image, autocontrast_image) +autocontrast.register(autocontrast_image, features.Image) @dispatch @@ -86,7 +86,7 @@ def equalize(input: T) -> T: equalize_image = _F.equalize -equalize.register(features.Image, equalize_image) +equalize.register(equalize_image, features.Image) @dispatch @@ -96,4 +96,4 @@ def invert(input: T) -> T: invert_image = _F.invert -invert.register(features.Image, invert_image) +invert.register(invert_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 6e7047d0dad..4ebbd27ef3a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -17,7 +17,7 @@ def horizontal_flip(input: T) -> T: horizontal_flip_image = _F.hflip -horizontal_flip.register(features.Image, horizontal_flip_image) +horizontal_flip.register(horizontal_flip_image, features.Image) def horizontal_flip_bounding_box( @@ -34,7 +34,7 @@ def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) -horizontal_flip.register(features.BoundingBox, _horizontal_flip_bounding_box) +horizontal_flip.register(_horizontal_flip_bounding_box, features.BoundingBox) @dispatch @@ -69,7 +69,7 @@ def resize_image( ).reshape(batch_shape + (num_channels, new_height, new_width)) -resize.register(features.Image, resize_image, pil_kernel=_F.resize) +resize.register(resize_image, features.Image, pil_kernel=_F.resize) def resize_segmentation_mask( @@ -84,7 +84,7 @@ def resize_segmentation_mask( ) -resize.register(features.SegmentationMask, resize_segmentation_mask) +resize.register(resize_segmentation_mask, features.SegmentationMask) # TODO: handle max_size @@ -102,7 +102,7 @@ def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: A return features.BoundingBox.new_like(input, output, image_size=size) -resize.register(features.BoundingBox, _resize_bounding_box, wrap_output=False) +resize.register(_resize_bounding_box, features.BoundingBox, wrap_output=False) @dispatch @@ -112,7 +112,7 @@ def center_crop(input: T, *, output_size: List[int]) -> T: center_crop_image = _F.center_crop -center_crop.register(features.Image, center_crop_image) +center_crop.register(center_crop_image, features.Image) @dispatch @@ -131,7 +131,7 @@ def resized_crop( resized_crop_image = _F.resized_crop -resized_crop.register(features.Image, resized_crop_image) +resized_crop.register(resized_crop_image, features.Image) @dispatch @@ -153,7 +153,7 @@ def affine( affine_image = _F.affine -affine.register(features.Image, affine_image) +affine.register(affine_image, features.Image) @dispatch @@ -172,4 +172,4 @@ def rotate( rotate_image = _F.rotate -rotate.register(features.Image, rotate_image) +rotate.register(rotate_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 721d41dbbd2..5e395768193 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -16,4 +16,4 @@ def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = normalize_image = _F.normalize 
-normalize.register(features.Image, normalize_image) +normalize.register(normalize_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 096d9210976..91d1b7fd578 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -28,7 +28,7 @@ def __init__(self, dispatch_fn): def supports(self, obj: Any) -> bool: return is_supported(obj, *self._fns.keys()) - def register(self, feature_type, fn, *, wrap_output: bool = True, pil_kernel=None) -> None: + def register(self, fn, feature_type, *, wrap_output: bool = True, pil_kernel=None) -> None: if pil_kernel is not None: if not issubclass(feature_type, features.Image): raise TypeError("PIL kernel can only be registered for images") From 9518cfbfab7683f8b821f6944229258119bd3b24 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 09:19:43 +0100 Subject: [PATCH 12/32] change order of transforms for readability --- .../transforms/functional/_augment.py | 44 ++++++------ .../prototype/transforms/functional/_color.py | 37 +++++++--- .../transforms/functional/_geometry.py | 68 +++++++++++-------- .../prototype/transforms/functional/_misc.py | 4 +- 4 files changed, 89 insertions(+), 64 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 3805b195c57..3c9f9e5ec78 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -10,22 +10,18 @@ T = TypeVar("T", bound=features.Feature) +erase_image = _F.erase + + @dispatch def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: """ADDME""" pass -erase_image = _F.erase erase.register(erase_image, features.Image) -@dispatch -def mixup(input: T, *, lam: float, inplace: bool = False) -> T: - """ADDME""" - pass - - def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: if not inplace: input = input.clone() @@ -41,9 +37,6 @@ def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) return _mixup(image_batch, -4, lam, inplace) -mixup.register(mixup_image, features.Image) - - def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") @@ -51,21 +44,16 @@ def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplac return _mixup(one_hot_label_batch, -2, lam, inplace) -mixup.register(mixup_one_hot_label, features.OneHotLabel) - - @dispatch -def cutmix( - input: T, - *, - box: Tuple[int, int, int, int] = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - lam_adjusted: float = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - inplace: bool = False, -) -> T: +def mixup(input: T, *, lam: float, inplace: bool = False) -> T: """ADDME""" pass +mixup.register(mixup_image, features.Image) +mixup.register(mixup_one_hot_label, features.OneHotLabel) + + def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") @@ -80,9 +68,6 @@ def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], i return image_batch -cutmix.register(cutmix_image, features.Image) - - def 
cutmix_one_hot_label( one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False ) -> torch.Tensor: @@ -92,4 +77,17 @@ def cutmix_one_hot_label( return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) +@dispatch +def cutmix( + input: T, + *, + box: Tuple[int, int, int, int] = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + lam_adjusted: float = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + inplace: bool = False, +) -> T: + """ADDME""" + pass + + +cutmix.register(cutmix_image, features.Image) cutmix.register(cutmix_one_hot_label, features.OneHotLabel) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 44e9a0554d4..8ba1270ca9c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -8,92 +8,109 @@ T = TypeVar("T", bound=features.Feature) +adjust_brightness_image = _F.adjust_brightness + + @dispatch def adjust_brightness(input: T, *, brightness_factor: float) -> T: """ADDME""" pass -adjust_brightness_image = _F.adjust_brightness - adjust_brightness.register(adjust_brightness_image, features.Image) +adjust_saturation_image = _F.adjust_saturation + + @dispatch def adjust_saturation(input: T, *, saturation_factor: float) -> T: """ADDME""" pass -adjust_saturation_image = _F.adjust_saturation adjust_saturation.register(adjust_saturation_image, features.Image) +adjust_contrast_image = _F.adjust_contrast + + @dispatch def adjust_contrast(input: T, *, contrast_factor: float) -> T: """ADDME""" pass -adjust_contrast_image = _F.adjust_contrast adjust_contrast.register(adjust_contrast_image, features.Image) +adjust_sharpness_image = _F.adjust_sharpness + + @dispatch def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: """ADDME""" pass -adjust_sharpness_image = _F.adjust_sharpness adjust_sharpness.register(adjust_sharpness_image, features.Image) +posterize_image = _F.posterize + + @dispatch def posterize(input: T, *, bits: int) -> T: """ADDME""" pass -posterize_image = _F.posterize posterize.register(posterize_image, features.Image) +solarize_image = _F.solarize + + @dispatch def solarize(input: T, *, threshold: float) -> T: """ADDME""" pass -solarize_image = _F.solarize solarize.register(solarize_image, features.Image) +autocontrast_image = _F.autocontrast + + @dispatch def autocontrast(input: T) -> T: """ADDME""" pass -autocontrast_image = _F.autocontrast autocontrast.register(autocontrast_image, features.Image) +equalize_image = _F.equalize + + @dispatch def equalize(input: T) -> T: """ADDME""" pass -equalize_image = _F.equalize equalize.register(equalize_image, features.Image) +invert_image = _F.invert + + @dispatch def invert(input: T) -> T: """ADDME""" pass -invert_image = _F.invert invert.register(invert_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 4ebbd27ef3a..2e003959294 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -10,14 +10,7 @@ T = TypeVar("T", bound=features.Feature) -@dispatch -def horizontal_flip(input: T) -> T: - """ADDME""" - pass - - horizontal_flip_image = _F.hflip -horizontal_flip.register(horizontal_flip_image, features.Image) def horizontal_flip_bounding_box( @@ -30,6 +23,15 @@ def horizontal_flip_bounding_box( return convert_bounding_box_format(bounding_box, 
old_format=features.BoundingBoxFormat.XYXY, new_format=format) +@dispatch +def horizontal_flip(input: T) -> T: + """ADDME""" + pass + + +horizontal_flip.register(horizontal_flip_image, features.Image) + + def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) @@ -37,19 +39,6 @@ def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: horizontal_flip.register(_horizontal_flip_bounding_box, features.BoundingBox) -@dispatch -def resize( - input: T, - *, - size: List[int], - interpolation: InterpolationMode = dispatch.FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> T: - """ADDME""" - pass - - def resize_image( image: torch.Tensor, size: List[int], @@ -69,9 +58,6 @@ def resize_image( ).reshape(batch_shape + (num_channels, new_height, new_width)) -resize.register(resize_image, features.Image, pil_kernel=_F.resize) - - def resize_segmentation_mask( segmentation_mask: torch.Tensor, size: List[int], @@ -84,9 +70,6 @@ def resize_segmentation_mask( ) -resize.register(resize_segmentation_mask, features.SegmentationMask) - - # TODO: handle max_size def resize_bounding_box( bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int] @@ -97,6 +80,23 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) +@dispatch +def resize( + input: T, + *, + size: List[int], + interpolation: InterpolationMode = dispatch.FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> T: + """ADDME""" + pass + + +resize.register(resize_image, features.Image, pil_kernel=_F.resize) +resize.register(resize_segmentation_mask, features.SegmentationMask) + + def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: output = resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) return features.BoundingBox.new_like(input, output, image_size=size) @@ -105,16 +105,21 @@ def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: A resize.register(_resize_bounding_box, features.BoundingBox, wrap_output=False) +center_crop_image = _F.center_crop + + @dispatch def center_crop(input: T, *, output_size: List[int]) -> T: """ADDME""" pass -center_crop_image = _F.center_crop center_crop.register(center_crop_image, features.Image) +resized_crop_image = _F.resized_crop + + @dispatch def resized_crop( input: T, @@ -130,10 +135,12 @@ def resized_crop( pass -resized_crop_image = _F.resized_crop resized_crop.register(resized_crop_image, features.Image) +affine_image = _F.affine + + @dispatch def affine( input: T, @@ -152,10 +159,12 @@ def affine( pass -affine_image = _F.affine affine.register(affine_image, features.Image) +rotate_image = _F.rotate + + @dispatch def rotate( input: T, @@ -171,5 +180,4 @@ def rotate( pass -rotate_image = _F.rotate rotate.register(rotate_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 5e395768193..41b9d56db6d 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -9,11 +9,13 @@ T = TypeVar("T", bound=features.Feature) +normalize_image = _F.normalize + + @dispatch def normalize(input: T, *, 
mean: List[float], std: List[float], inplace: bool = False) -> T: """ADDME""" pass -normalize_image = _F.normalize normalize.register(normalize_image, features.Image) From a0352861f7e81c313fb31e9ffce5226bd0eae619 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 10:09:03 +0100 Subject: [PATCH 13/32] add documentation for __torch_function__ --- torchvision/prototype/features/_feature.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 70556132c3b..d6d4df8486e 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -84,6 +84,27 @@ def __torch_function__( args: Sequence[Any] = (), kwargs: Optional[Mapping[str, Any]] = None, ) -> torch.Tensor: + """For general information about how the __torch_function__ protocol works, + see https://pytorch.org/docs/stable/notes/extending.html#extending-torch + + TL;DR: Every time a PyTorch operator is called, it goes through the inputs and looks for the + ``__torch_function__`` method. If one is found, it is invoked with the operator as ``func`` as well as the + ``args`` and ``kwargs`` of the original call. + + The default behavior of :class:`~torch.Tensor`'s is to retain a custom tensor type. For the :class:`Feature` + use case, this has two downsides: + + 1. Since some :class:`Feature`'s require metadata to be constructed, the default wrapping, i.e. + ``return cls(func(*args, **kwargs))``, will fail for them. + 2. For most operations, there is no way of knowing if the input type is still valid for the output. + + For these reasons, the automatic output wrapping is turned off for most operators. + + Exceptions to this are: + + - :func:`torch.clone` + - :meth:`torch.Tensor.to` + """ kwargs = kwargs or dict() with DisableTorchFunction(): output = func(*args, **kwargs) From 2d10741e9661e54dc9a8bbe3b55e9fdae45033a1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 10:58:29 +0100 Subject: [PATCH 14/32] fix mypy --- .../prototype/transforms/functional/utils.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 91d1b7fd578..32f79508b57 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -1,6 +1,6 @@ import functools import inspect -from typing import Any, Type, Optional, Callable +from typing import Any, Type, Optional, Callable, TypeVar, Dict, Union import PIL.Image import torch @@ -8,6 +8,8 @@ from torchvision.prototype import features from torchvision.prototype.utils._internal import sequence_to_str +F = TypeVar("F", bound=features.Feature) + def is_supported(obj: Any, *types: Type) -> bool: return (obj if isinstance(obj, type) else type(obj)) in types @@ -17,18 +19,28 @@ class dispatch: FEATURE_SPECIFIC_PARAM = object() FEATURE_SPECIFIC_DEFAULT = object() - def __init__(self, dispatch_fn): + def __init__(self, dispatch_fn: Callable[..., F]): self._dispatch_fn = dispatch_fn self.__doc__ = dispatch_fn.__doc__ self.__signature__ = inspect.signature(dispatch_fn) - self._fns = {} + self._fns: Dict[Type[F], Callable[..., F]] = {} self._pil_fn: Optional[Callable] = None def supports(self, obj: Any) -> bool: return is_supported(obj, *self._fns.keys()) - def register(self, fn, feature_type, *, wrap_output: bool = True, pil_kernel=None) -> None: + 
def register( + self, + fn: Callable[..., Union[torch.Tensor, F]], + feature_type: Type[F], + *, + wrap_output: bool = True, + pil_kernel: Optional[Callable[..., PIL.Image.Image]] = None, + ) -> None: + if not (issubclass(feature_type, features.Feature) and feature_type is not features.Feature): + raise TypeError("Can only register kernels for subclasses of `torchvision.prototype.features.Feature`.") + if pil_kernel is not None: if not issubclass(feature_type, features.Image): raise TypeError("PIL kernel can only be registered for images") @@ -45,7 +57,7 @@ def register(self, fn, feature_type, *, wrap_output: bool = True, pil_kernel=Non ] @functools.wraps(fn) - def wrapper(input, *args, **kwargs) -> Any: + def wrapper(input: F, *args: Any, **kwargs: Any) -> Any: missing = [ param for param in feature_specific_params @@ -66,7 +78,7 @@ def wrapper(input, *args, **kwargs) -> Any: self._fns[feature_type] = wrapper - def __call__(self, input, *args, **kwargs): + def __call__(self, input: F, *args: Any, **kwargs: Any) -> Union[F, PIL.Image.Image]: feature_type = type(input) if issubclass(feature_type, PIL.Image.Image): From f3d6522ee0166a9dc8a498c8ba176416efd834c6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 11:04:01 +0100 Subject: [PATCH 15/32] inline check for support --- torchvision/prototype/transforms/functional/utils.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index 32f79508b57..b90e49562ad 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -11,10 +11,6 @@ F = TypeVar("F", bound=features.Feature) -def is_supported(obj: Any, *types: Type) -> bool: - return (obj if isinstance(obj, type) else type(obj)) in types - - class dispatch: FEATURE_SPECIFIC_PARAM = object() FEATURE_SPECIFIC_DEFAULT = object() @@ -28,7 +24,7 @@ def __init__(self, dispatch_fn: Callable[..., F]): self._pil_fn: Optional[Callable] = None def supports(self, obj: Any) -> bool: - return is_supported(obj, *self._fns.keys()) + return (obj if isinstance(obj, type) else type(obj)) in self._fns.keys() def register( self, From b8cda568dcdeee96a5145714686c9198aaab3033 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 12:03:46 +0100 Subject: [PATCH 16/32] refactor kernel registering process --- .../transforms/functional/_augment.py | 31 ++--- .../prototype/transforms/functional/_color.py | 81 ++++++------ .../transforms/functional/_geometry.py | 82 +++++++------ .../prototype/transforms/functional/_misc.py | 9 +- .../prototype/transforms/functional/utils.py | 115 ++++++++---------- 5 files changed, 161 insertions(+), 157 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 3c9f9e5ec78..315eee55fcc 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -13,15 +13,16 @@ erase_image = _F.erase -@dispatch +@dispatch( + { + features.Image: erase_image, + }, +) def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: """ADDME""" pass -erase.register(erase_image, features.Image) - - def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: if not inplace: input = input.clone() @@ -44,16 +45,17 @@ def mixup_one_hot_label(one_hot_label_batch: 
torch.Tensor, *, lam: float, inplac return _mixup(one_hot_label_batch, -2, lam, inplace) -@dispatch +@dispatch( + { + features.Image: mixup_image, + features.OneHotLabel: mixup_one_hot_label, + }, +) def mixup(input: T, *, lam: float, inplace: bool = False) -> T: """ADDME""" pass -mixup.register(mixup_image, features.Image) -mixup.register(mixup_one_hot_label, features.OneHotLabel) - - def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") @@ -77,7 +79,12 @@ def cutmix_one_hot_label( return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) -@dispatch +@dispatch( + { + features.Image: cutmix_image, + features.OneHotLabel: cutmix_one_hot_label, + }, +) def cutmix( input: T, *, @@ -87,7 +94,3 @@ def cutmix( ) -> T: """ADDME""" pass - - -cutmix.register(cutmix_image, features.Image) -cutmix.register(cutmix_one_hot_label, features.OneHotLabel) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 8ba1270ca9c..99e99323f2b 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -11,106 +11,115 @@ adjust_brightness_image = _F.adjust_brightness -@dispatch +@dispatch( + { + features.Image: adjust_brightness_image, + } +) def adjust_brightness(input: T, *, brightness_factor: float) -> T: """ADDME""" pass -adjust_brightness.register(adjust_brightness_image, features.Image) - - adjust_saturation_image = _F.adjust_saturation -@dispatch +@dispatch( + { + features.Image: adjust_saturation_image, + } +) def adjust_saturation(input: T, *, saturation_factor: float) -> T: """ADDME""" pass -adjust_saturation.register(adjust_saturation_image, features.Image) - - adjust_contrast_image = _F.adjust_contrast -@dispatch +@dispatch( + { + features.Image: adjust_contrast_image, + } +) def adjust_contrast(input: T, *, contrast_factor: float) -> T: """ADDME""" pass -adjust_contrast.register(adjust_contrast_image, features.Image) - - adjust_sharpness_image = _F.adjust_sharpness -@dispatch +@dispatch( + { + features.Image: adjust_sharpness_image, + } +) def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: """ADDME""" pass -adjust_sharpness.register(adjust_sharpness_image, features.Image) - - posterize_image = _F.posterize -@dispatch +@dispatch( + { + features.Image: posterize_image, + } +) def posterize(input: T, *, bits: int) -> T: """ADDME""" pass -posterize.register(posterize_image, features.Image) - - solarize_image = _F.solarize -@dispatch +@dispatch( + { + features.Image: solarize_image, + } +) def solarize(input: T, *, threshold: float) -> T: """ADDME""" pass -solarize.register(solarize_image, features.Image) - - autocontrast_image = _F.autocontrast -@dispatch +@dispatch( + { + features.Image: autocontrast_image, + } +) def autocontrast(input: T) -> T: """ADDME""" pass -autocontrast.register(autocontrast_image, features.Image) - - equalize_image = _F.equalize -@dispatch +@dispatch( + { + features.Image: equalize_image, + } +) def equalize(input: T) -> T: """ADDME""" pass -equalize.register(equalize_image, features.Image) - - invert_image = _F.invert -@dispatch +@dispatch( + { + features.Image: invert_image, + } +) def invert(input: T) -> T: """ADDME""" pass - - -invert.register(invert_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_geometry.py 
b/torchvision/prototype/transforms/functional/_geometry.py index 2e003959294..c854bfc2e40 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -23,20 +23,20 @@ def horizontal_flip_bounding_box( return convert_bounding_box_format(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format) -@dispatch -def horizontal_flip(input: T) -> T: - """ADDME""" - pass - - -horizontal_flip.register(horizontal_flip_image, features.Image) - - def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) -horizontal_flip.register(_horizontal_flip_bounding_box, features.BoundingBox) +@dispatch( + { + features.Image: horizontal_flip_image, + features.BoundingBox: _horizontal_flip_bounding_box, + }, + pil_kernel=_F.hflip, +) +def horizontal_flip(input: T) -> T: + """ADDME""" + pass def resize_image( @@ -80,7 +80,19 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -@dispatch +def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: + output = resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) + return features.BoundingBox.new_like(input, output, image_size=size) + + +@dispatch( + { + features.Image: resize_image, + features.SegmentationMask: resize_segmentation_mask, + features.BoundingBox: _resize_bounding_box, + }, + pil_kernel=_F.resize, +) def resize( input: T, *, @@ -93,34 +105,27 @@ def resize( pass -resize.register(resize_image, features.Image, pil_kernel=_F.resize) -resize.register(resize_segmentation_mask, features.SegmentationMask) - - -def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: - output = resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) - return features.BoundingBox.new_like(input, output, image_size=size) - - -resize.register(_resize_bounding_box, features.BoundingBox, wrap_output=False) - - center_crop_image = _F.center_crop -@dispatch +@dispatch( + { + features.Image: center_crop_image, + } +) def center_crop(input: T, *, output_size: List[int]) -> T: """ADDME""" pass -center_crop.register(center_crop_image, features.Image) - - resized_crop_image = _F.resized_crop -@dispatch +@dispatch( + { + features.Image: resized_crop_image, + } +) def resized_crop( input: T, *, @@ -135,13 +140,14 @@ def resized_crop( pass -resized_crop.register(resized_crop_image, features.Image) - - affine_image = _F.affine -@dispatch +@dispatch( + { + features.Image: affine_image, + } +) def affine( input: T, *, @@ -159,13 +165,14 @@ def affine( pass -affine.register(affine_image, features.Image) - - rotate_image = _F.rotate -@dispatch +@dispatch( + { + features.Image: rotate_image, + } +) def rotate( input: T, *, @@ -178,6 +185,3 @@ def rotate( ) -> T: """ADDME""" pass - - -rotate.register(rotate_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 41b9d56db6d..556e834fadf 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -12,10 +12,11 @@ normalize_image = _F.normalize -@dispatch +@dispatch( + { + features.Image: normalize_image, + } +) def normalize(input: T, *, mean: List[float], std: List[float], 
inplace: bool = False) -> T: """ADDME""" pass - - -normalize.register(normalize_image, features.Image) diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py index b90e49562ad..ad714c54dc5 100644 --- a/torchvision/prototype/transforms/functional/utils.py +++ b/torchvision/prototype/transforms/functional/utils.py @@ -1,6 +1,6 @@ import functools import inspect -from typing import Any, Type, Optional, Callable, TypeVar, Dict, Union +from typing import Any, Optional, Callable, TypeVar, Dict, Union import PIL.Image import torch @@ -15,80 +15,67 @@ class dispatch: FEATURE_SPECIFIC_PARAM = object() FEATURE_SPECIFIC_DEFAULT = object() - def __init__(self, dispatch_fn: Callable[..., F]): - self._dispatch_fn = dispatch_fn - self.__doc__ = dispatch_fn.__doc__ - self.__signature__ = inspect.signature(dispatch_fn) - - self._fns: Dict[Type[F], Callable[..., F]] = {} - self._pil_fn: Optional[Callable] = None - - def supports(self, obj: Any) -> bool: - return (obj if isinstance(obj, type) else type(obj)) in self._fns.keys() - - def register( + def __init__( self, - fn: Callable[..., Union[torch.Tensor, F]], - feature_type: Type[F], + kernels: Dict[Any, Callable[..., Union[torch.Tensor, F]]], *, - wrap_output: bool = True, - pil_kernel: Optional[Callable[..., PIL.Image.Image]] = None, + pil_kernel: Optional[Callable] = None, ) -> None: - if not (issubclass(feature_type, features.Feature) and feature_type is not features.Feature): - raise TypeError("Can only register kernels for subclasses of `torchvision.prototype.features.Feature`.") - - if pil_kernel is not None: - if not issubclass(feature_type, features.Image): - raise TypeError("PIL kernel can only be registered for images") - - self._pil_fn = pil_kernel - - params = inspect.signature(fn).parameters - feature_specific_params = [ - name - for name, param in self.__signature__.parameters.items() - if param.default is self.FEATURE_SPECIFIC_PARAM - and name in params - and params[name].default is inspect.Parameter.empty - ] - - @functools.wraps(fn) - def wrapper(input: F, *args: Any, **kwargs: Any) -> Any: - missing = [ - param - for param in feature_specific_params - if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM + self._kernels = kernels + if pil_kernel and features.Image not in kernels: + raise TypeError("PIL kernel can only be registered for images") + self._pil_kernel = pil_kernel + + def __call__(self, dispatch_fn: Callable[..., F]) -> Callable[..., F]: + params = {feature_type: inspect.signature(kernel).parameters for feature_type, kernel in self._kernels.items()} + feature_specific_params = { + feature_type: [ + name + for name, param in inspect.signature(dispatch_fn).parameters.items() + if param.default is self.FEATURE_SPECIFIC_PARAM + and name in params_ + and params_[name].default is inspect.Parameter.empty ] - if missing: - raise TypeError( - f"{self._dispatch_fn.__name__}() missing {len(missing)} required keyword-only arguments " - f"for feature type {feature_type.__name__}: {sequence_to_str(missing, separate_last='and ')}" - ) + for feature_type, params_ in params.items() + } - output = fn(input, *args, **kwargs) + @functools.wraps(dispatch_fn) + def wrapper(input: F, *args: Any, **kwargs: Any) -> F: + feature_type = type(input) - if wrap_output: - output = feature_type.new_like(input, output) + if issubclass(feature_type, PIL.Image.Image): + if self._pil_kernel is None: + raise TypeError("No PIL kernel") - return output + return 
self._pil_kernel(input, *args, **kwargs) # type: ignore[no-any-return] - self._fns[feature_type] = wrapper + if not issubclass(feature_type, torch.Tensor): + raise TypeError("No tensor") - def __call__(self, input: F, *args: Any, **kwargs: Any) -> Union[F, PIL.Image.Image]: - feature_type = type(input) + if not issubclass(feature_type, features.Feature): + input = features.Image(input) - if issubclass(feature_type, PIL.Image.Image): - if self._pil_fn is None: - raise TypeError("No PIL kernel") + try: + kernel = self._kernels[feature_type] + except KeyError: + raise ValueError(f"No support for {feature_type.__name__}") from None - return self._pil_fn(input, *args, **kwargs) - elif not issubclass(feature_type, torch.Tensor): - raise TypeError("No tensor") + missing_args = [ + param + for param in feature_specific_params[feature_type] + if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM + ] + if missing_args: + raise TypeError( + f"{dispatch_fn.__name__}() missing {len(missing_args)} required keyword-only arguments " + f"for feature type {feature_type.__name__}: {sequence_to_str(missing_args, separate_last='and ')}" + ) + + output = kernel(input, *args, **kwargs) - if not issubclass(feature_type, features.Feature): - input = features.Image(input) + if not isinstance(output, feature_type): + output = feature_type.new_like(input, output) - if not self.supports(feature_type): - raise ValueError(f"No support for {feature_type.__name__}") + return output - return self._fns[feature_type](input, *args, **kwargs) + return wrapper From 9a45eb0e6e8042e96df4ad10f835f941dce49e21 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 15:10:56 +0100 Subject: [PATCH 17/32] refactor dispatch to be a regular decorator --- .../transforms/functional/__init__.py | 2 - .../transforms/functional/_augment.py | 6 +- .../prototype/transforms/functional/_color.py | 2 +- .../transforms/functional/_geometry.py | 4 +- .../prototype/transforms/functional/_misc.py | 2 +- .../prototype/transforms/functional/_utils.py | 125 ++++++++++++++++++ .../prototype/transforms/functional/utils.py | 81 ------------ 7 files changed, 132 insertions(+), 90 deletions(-) create mode 100644 torchvision/prototype/transforms/functional/_utils.py delete mode 100644 torchvision/prototype/transforms/functional/utils.py diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 5ef198492c6..1afd3f4aa7c 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,5 +1,3 @@ -from . 
import utils # usort: skip - from torchvision.transforms import InterpolationMode from ._augment import ( diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 315eee55fcc..070d35f7fba 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -5,7 +5,7 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F -from .utils import dispatch +from ._utils import dispatch, FEATURE_SPECIFIC_PARAM T = TypeVar("T", bound=features.Feature) @@ -88,8 +88,8 @@ def cutmix_one_hot_label( def cutmix( input: T, *, - box: Tuple[int, int, int, int] = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - lam_adjusted: float = dispatch.FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + box: Tuple[int, int, int, int] = FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] + lam_adjusted: float = FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] inplace: bool = False, ) -> T: """ADDME""" diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 99e99323f2b..1bc0add7fef 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -3,7 +3,7 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F -from .utils import dispatch +from ._utils import dispatch T = TypeVar("T", bound=features.Feature) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index c854bfc2e40..a9f73e81958 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -5,7 +5,7 @@ from torchvision.transforms import functional as _F, InterpolationMode from ._meta_conversion import convert_bounding_box_format -from .utils import dispatch +from ._utils import dispatch, FEATURE_SPECIFIC_DEFAULT T = TypeVar("T", bound=features.Feature) @@ -97,7 +97,7 @@ def resize( input: T, *, size: List[int], - interpolation: InterpolationMode = dispatch.FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] + interpolation: InterpolationMode = FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> T: diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 556e834fadf..29656861002 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -4,7 +4,7 @@ from torchvision.prototype import features from torchvision.transforms import functional as _F -from .utils import dispatch +from ._utils import dispatch T = TypeVar("T", bound=features.Feature) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py new file mode 100644 index 00000000000..869e76845c9 --- /dev/null +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -0,0 +1,125 @@ +import functools +import inspect +from typing import Any, Optional, Callable, TypeVar, Dict, Union + +import PIL.Image +import torch +import torch.overrides +from torchvision.prototype import features +from torchvision.prototype.utils._internal import sequence_to_str + +F = TypeVar("F", bound=features.Feature) + +# Sentinel to use as 
default value of a dispatcher parameter if it is only required for a subset of the kernels. If the
+# decorated function is called without the parameter for a kernel that requires it, an expressive :class:`TypeError` is
+# raised.
+FEATURE_SPECIFIC_PARAM = object()
+
+# Sentinel to use as default value of a dispatcher parameter if the kernels use different default values for it.
+FEATURE_SPECIFIC_DEFAULT = object()
+
+
+def dispatch(
+    kernels: Dict[Any, Callable[..., Union[torch.Tensor, F]]],
+    *,
+    pil_kernel: Optional[Callable] = None,
+) -> Callable[[Callable[..., F]], Callable[..., F]]:
+    """Decorates a function to automatically dispatch to ``kernels`` based on the call arguments.
+
+    The function body of the dispatcher can be empty as it is never called. The signature and the docstring, however,
+    are used in the documentation and thus should be accurate.
+
+    The dispatch function should have this signature:
+
+    .. code:: python
+
+        @dispatch(kernels)
+        def dispatch_fn(input, *args, **kwargs):
+            ...
+
+    where ``input`` is a strict subclass of :class:`~torchvision.prototype.features.Feature` and is used to determine
+    which kernel to dispatch to.
+
+    .. note::
+
+        For backward compatibility, ``input`` can also be a ``PIL`` image in which case the call will be dispatched to
+        ``pil_kernel`` if available. Furthermore, ``input`` can also be a vanilla :class:`~torch.Tensor` in which case
+        it will be converted into a :class:`~torchvision.prototype.features.Image`.
+
+    Args:
+        kernels: Dictionary that maps strict subclasses of :class:`~torchvision.prototype.features.Feature` to the
+            kernel to call for that feature type.
+        pil_kernel: Optional kernel for ``PIL`` images.
+
+    Raises:
+        TypeError: If any key in ``kernels`` is not a strict subclass of
+            :class:`~torchvision.prototype.features.Feature`.
+        TypeError: If ``pil_kernel`` is specified, but no kernel for :class:`~torchvision.prototype.features.Image` is
+            available.
+        TypeError: If the decorated function is called with neither a ``PIL`` image nor a :class:`~torch.Tensor`.
+        TypeError: If the decorated function is called with an input that cannot be dispatched.
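+
+    Example:
+
+        The ``erase`` dispatcher, for instance, forwards :class:`~torchvision.prototype.features.Image` inputs to
+        the ``erase_image`` kernel:
+
+        .. code:: python
+
+            @dispatch(
+                {
+                    features.Image: erase_image,
+                },
+            )
+            def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T:
+                ...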
+ """ + for feature_type in kernels: + if not (issubclass(feature_type, features.Feature) and feature_type is not features.Feature): + raise TypeError("XXX") + if pil_kernel and features.Image not in kernels: + raise TypeError("PIL kernel can only be registered for images") + + params = {feature_type: inspect.signature(kernel).parameters for feature_type, kernel in kernels.items()} + + def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: + feature_specific_params = { + feature_type: [ + name + for name, param in inspect.signature(dispatch_fn).parameters.items() + if param.default is FEATURE_SPECIFIC_PARAM + and name in params_ + and params_[name].default is inspect.Parameter.empty + ] + for feature_type, params_ in params.items() + } + + @functools.wraps(dispatch_fn) + def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: + feature_type = type(input) + + if issubclass(feature_type, PIL.Image.Image): + if pil_kernel is None: + raise TypeError("No PIL kernel") + + # TODO: maybe warn or fail here if we have decided on the scope of BC and deprecations + return pil_kernel(input, *args, **kwargs) # type: ignore[no-any-return] + + if not issubclass(feature_type, torch.Tensor): + raise TypeError("No tensor") + + if not issubclass(feature_type, features.Feature): + # TODO: maybe warn or fail here if we have decided on the scope of BC and deprecations + input = features.Image(input) + + try: + kernel = kernels[feature_type] + except KeyError: + raise TypeError(f"No support for {feature_type.__name__}") from None + + missing_args = [ + param + for param in feature_specific_params[feature_type] + if kwargs.get(param, FEATURE_SPECIFIC_PARAM) is FEATURE_SPECIFIC_PARAM + ] + if missing_args: + raise TypeError( + f"{dispatch_fn.__name__}() missing {len(missing_args)} required keyword-only arguments " + f"for feature type {feature_type.__name__}: {sequence_to_str(missing_args, separate_last='and ')}" + ) + + output = kernel(input, *args, **kwargs) + + if not isinstance(output, feature_type): + output = feature_type.new_like(input, output) + + return output + + return inner_wrapper + + return outer_wrapper diff --git a/torchvision/prototype/transforms/functional/utils.py b/torchvision/prototype/transforms/functional/utils.py deleted file mode 100644 index ad714c54dc5..00000000000 --- a/torchvision/prototype/transforms/functional/utils.py +++ /dev/null @@ -1,81 +0,0 @@ -import functools -import inspect -from typing import Any, Optional, Callable, TypeVar, Dict, Union - -import PIL.Image -import torch -import torch.overrides -from torchvision.prototype import features -from torchvision.prototype.utils._internal import sequence_to_str - -F = TypeVar("F", bound=features.Feature) - - -class dispatch: - FEATURE_SPECIFIC_PARAM = object() - FEATURE_SPECIFIC_DEFAULT = object() - - def __init__( - self, - kernels: Dict[Any, Callable[..., Union[torch.Tensor, F]]], - *, - pil_kernel: Optional[Callable] = None, - ) -> None: - self._kernels = kernels - if pil_kernel and features.Image not in kernels: - raise TypeError("PIL kernel can only be registered for images") - self._pil_kernel = pil_kernel - - def __call__(self, dispatch_fn: Callable[..., F]) -> Callable[..., F]: - params = {feature_type: inspect.signature(kernel).parameters for feature_type, kernel in self._kernels.items()} - feature_specific_params = { - feature_type: [ - name - for name, param in inspect.signature(dispatch_fn).parameters.items() - if param.default is self.FEATURE_SPECIFIC_PARAM - and name in params_ - and 
params_[name].default is inspect.Parameter.empty - ] - for feature_type, params_ in params.items() - } - - @functools.wraps(dispatch_fn) - def wrapper(input: F, *args: Any, **kwargs: Any) -> F: - feature_type = type(input) - - if issubclass(feature_type, PIL.Image.Image): - if self._pil_kernel is None: - raise TypeError("No PIL kernel") - - return self._pil_kernel(input, *args, **kwargs) # type: ignore[no-any-return] - - if not issubclass(feature_type, torch.Tensor): - raise TypeError("No tensor") - - if not issubclass(feature_type, features.Feature): - input = features.Image(input) - - try: - kernel = self._kernels[feature_type] - except KeyError: - raise ValueError(f"No support for {feature_type.__name__}") from None - - missing_args = [ - param - for param in feature_specific_params[feature_type] - if kwargs.get(param, self.FEATURE_SPECIFIC_PARAM) is self.FEATURE_SPECIFIC_PARAM - ] - if missing_args: - raise TypeError( - f"{dispatch_fn.__name__}() missing {len(missing_args)} required keyword-only arguments " - f"for feature type {feature_type.__name__}: {sequence_to_str(missing_args, separate_last='and ')}" - ) - - output = kernel(input, *args, **kwargs) - - if not isinstance(output, feature_type): - output = feature_type.new_like(input, output) - - return output - - return wrapper From 71af6f89d19b31ca4984f1698a3878cb172ef230 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 16:20:05 +0100 Subject: [PATCH 18/32] split kernels and dispatchers --- torchvision/prototype/transforms/__init__.py | 6 +- .../transforms/functional/__init__.py | 44 +--------- .../transforms/functional/_augment.py | 60 ++----------- .../prototype/transforms/functional/_color.py | 47 +++------- .../transforms/functional/_geometry.py | 88 +++---------------- .../prototype/transforms/functional/_misc.py | 7 +- .../prototype/transforms/kernels/__init__.py | 33 +++++++ .../prototype/transforms/kernels/_augment.py | 52 +++++++++++ .../prototype/transforms/kernels/_color.py | 12 +++ .../prototype/transforms/kernels/_geometry.py | 70 +++++++++++++++ .../_meta_conversion.py | 0 .../prototype/transforms/kernels/_misc.py | 22 +++++ .../_type_conversion.py | 0 13 files changed, 224 insertions(+), 217 deletions(-) create mode 100644 torchvision/prototype/transforms/kernels/__init__.py create mode 100644 torchvision/prototype/transforms/kernels/_augment.py create mode 100644 torchvision/prototype/transforms/kernels/_color.py create mode 100644 torchvision/prototype/transforms/kernels/_geometry.py rename torchvision/prototype/transforms/{functional => kernels}/_meta_conversion.py (100%) create mode 100644 torchvision/prototype/transforms/kernels/_misc.py rename torchvision/prototype/transforms/{functional => kernels}/_type_conversion.py (100%) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 1fe3d010b28..40c9486334d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,4 +1,4 @@ -from . import functional -from .functional import InterpolationMode # usort: skip - +from . import kernels # usort: skip +from . 
import functional # usort: skip from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval +from .functional import InterpolationMode diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 1afd3f4aa7c..9f05f16df2d 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -1,52 +1,14 @@ -from torchvision.transforms import InterpolationMode - -from ._augment import ( - erase_image, - erase, - mixup_image, - mixup_one_hot_label, - mixup, - cutmix_image, - cutmix_one_hot_label, - cutmix, -) +from ._augment import erase, mixup, cutmix from ._color import ( - adjust_brightness_image, adjust_brightness, - adjust_contrast_image, adjust_contrast, - adjust_saturation_image, adjust_saturation, - adjust_sharpness_image, adjust_sharpness, - posterize_image, posterize, - solarize_image, solarize, - autocontrast_image, autocontrast, - equalize_image, equalize, - invert_image, invert, ) -from ._geometry import ( - horizontal_flip_bounding_box, - horizontal_flip_image, - horizontal_flip, - resize_bounding_box, - resize_image, - resize_segmentation_mask, - resize, - center_crop_image, - center_crop, - resized_crop_image, - resized_crop, - affine_image, - affine, - rotate_image, - rotate, -) -from ._meta_conversion import convert_color_space, convert_bounding_box_format -from ._misc import normalize_image, normalize -from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot +from ._geometry import horizontal_flip, resize, center_crop, resized_crop, affine, rotate +from ._misc import normalize diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 070d35f7fba..2467bc0f694 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -3,19 +3,16 @@ import torch from torchvision.prototype import features -from torchvision.transforms import functional as _F +from torchvision.prototype.transforms import kernels as K from ._utils import dispatch, FEATURE_SPECIFIC_PARAM T = TypeVar("T", bound=features.Feature) -erase_image = _F.erase - - @dispatch( { - features.Image: erase_image, + features.Image: K.erase_image, }, ) def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: @@ -23,32 +20,10 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: pass -def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: - if not inplace: - input = input.clone() - - input_rolled = input.roll(1, batch_dim) - return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) - - -def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - return _mixup(image_batch, -4, lam, inplace) - - -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup(one_hot_label_batch, -2, lam, inplace) - - @dispatch( { - features.Image: mixup_image, - features.OneHotLabel: mixup_one_hot_label, + features.Image: K.mixup_image, + features.OneHotLabel: K.mixup_one_hot_label, }, ) def mixup(input: T, *, lam: float, inplace: bool = False) 
-> T: @@ -56,33 +31,10 @@ def mixup(input: T, *, lam: float, inplace: bool = False) -> T: pass -def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - if not inplace: - image_batch = image_batch.clone() - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label( - one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False -) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) - - @dispatch( { - features.Image: cutmix_image, - features.OneHotLabel: cutmix_one_hot_label, + features.Image: K.cutmix_image, + features.OneHotLabel: K.cutmix_one_hot_label, }, ) def cutmix( diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 1bc0add7fef..4097f8f4e93 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,19 +1,16 @@ from typing import TypeVar from torchvision.prototype import features -from torchvision.transforms import functional as _F +from torchvision.prototype.transforms import kernels as K from ._utils import dispatch T = TypeVar("T", bound=features.Feature) -adjust_brightness_image = _F.adjust_brightness - - @dispatch( { - features.Image: adjust_brightness_image, + features.Image: K.adjust_brightness_image, } ) def adjust_brightness(input: T, *, brightness_factor: float) -> T: @@ -21,12 +18,9 @@ def adjust_brightness(input: T, *, brightness_factor: float) -> T: pass -adjust_saturation_image = _F.adjust_saturation - - @dispatch( { - features.Image: adjust_saturation_image, + features.Image: K.adjust_saturation_image, } ) def adjust_saturation(input: T, *, saturation_factor: float) -> T: @@ -34,12 +28,9 @@ def adjust_saturation(input: T, *, saturation_factor: float) -> T: pass -adjust_contrast_image = _F.adjust_contrast - - @dispatch( { - features.Image: adjust_contrast_image, + features.Image: K.adjust_contrast_image, } ) def adjust_contrast(input: T, *, contrast_factor: float) -> T: @@ -47,12 +38,9 @@ def adjust_contrast(input: T, *, contrast_factor: float) -> T: pass -adjust_sharpness_image = _F.adjust_sharpness - - @dispatch( { - features.Image: adjust_sharpness_image, + features.Image: K.adjust_sharpness_image, } ) def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: @@ -60,12 +48,9 @@ def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: pass -posterize_image = _F.posterize - - @dispatch( { - features.Image: posterize_image, + features.Image: K.posterize_image, } ) def posterize(input: T, *, bits: int) -> T: @@ -73,12 +58,9 @@ def posterize(input: T, *, bits: int) -> T: pass -solarize_image = _F.solarize - - @dispatch( { - features.Image: solarize_image, + features.Image: K.solarize_image, } ) def solarize(input: T, *, threshold: float) -> T: @@ -86,12 +68,9 @@ def solarize(input: T, *, threshold: float) -> T: pass -autocontrast_image = _F.autocontrast - - @dispatch( { - features.Image: autocontrast_image, + features.Image: K.autocontrast_image, } ) def autocontrast(input: T) -> T: @@ -99,12 +78,9 @@ def autocontrast(input: T) -> T: pass -equalize_image = _F.equalize - - @dispatch( { - features.Image: 
equalize_image, + features.Image: K.equalize_image, } ) def equalize(input: T) -> T: @@ -112,12 +88,9 @@ def equalize(input: T) -> T: pass -invert_image = _F.invert - - @dispatch( { - features.Image: invert_image, + features.Image: K.invert_image, } ) def invert(input: T) -> T: diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index a9f73e81958..4ccb8e5abe5 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,35 +1,22 @@ -from typing import Tuple, List, Optional, TypeVar, Any +from typing import List, Optional, TypeVar, Any import torch from torchvision.prototype import features +from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F, InterpolationMode -from ._meta_conversion import convert_bounding_box_format from ._utils import dispatch, FEATURE_SPECIFIC_DEFAULT T = TypeVar("T", bound=features.Feature) -horizontal_flip_image = _F.hflip - - -def horizontal_flip_bounding_box( - bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] -) -> torch.Tensor: - bounding_box = convert_bounding_box_format( - bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY, copy=True - ) - bounding_box[..., (0, 2)] = image_size[1] - bounding_box[..., (2, 0)] - return convert_bounding_box_format(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format) - - def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: - return horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + return K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) @dispatch( { - features.Image: horizontal_flip_image, + features.Image: K.horizontal_flip_image, features.BoundingBox: _horizontal_flip_bounding_box, }, pil_kernel=_F.hflip, @@ -39,56 +26,15 @@ def horizontal_flip(input: T) -> T: pass -def resize_image( - image: torch.Tensor, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> torch.Tensor: - new_height, new_width = size - num_channels, old_height, old_width = image.shape[-3:] - batch_shape = image.shape[:-3] - return _F.resize( - image.reshape((-1, num_channels, old_height, old_width)), - size=size, - interpolation=interpolation, - max_size=max_size, - antialias=antialias, - ).reshape(batch_shape + (num_channels, new_height, new_width)) - - -def resize_segmentation_mask( - segmentation_mask: torch.Tensor, - size: List[int], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - max_size: Optional[int] = None, - antialias: Optional[bool] = None, -) -> torch.Tensor: - return resize_image( - segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias - ) - - -# TODO: handle max_size -def resize_bounding_box( - bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int] -) -> torch.Tensor: - old_height, old_width = old_image_size - new_height, new_width = new_image_size - ratios = torch.tensor((new_width / old_width, new_height / old_height)) - return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) - - def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: - output = resize_bounding_box(input, 
old_image_size=list(input.image_size), new_image_size=size) + output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) return features.BoundingBox.new_like(input, output, image_size=size) @dispatch( { - features.Image: resize_image, - features.SegmentationMask: resize_segmentation_mask, + features.Image: K.resize_image, + features.SegmentationMask: K.resize_segmentation_mask, features.BoundingBox: _resize_bounding_box, }, pil_kernel=_F.resize, @@ -105,12 +51,9 @@ def resize( pass -center_crop_image = _F.center_crop - - @dispatch( { - features.Image: center_crop_image, + features.Image: K.center_crop_image, } ) def center_crop(input: T, *, output_size: List[int]) -> T: @@ -118,12 +61,9 @@ def center_crop(input: T, *, output_size: List[int]) -> T: pass -resized_crop_image = _F.resized_crop - - @dispatch( { - features.Image: resized_crop_image, + features.Image: K.resized_crop_image, } ) def resized_crop( @@ -140,12 +80,9 @@ def resized_crop( pass -affine_image = _F.affine - - @dispatch( { - features.Image: affine_image, + features.Image: K.affine_image, } ) def affine( @@ -165,12 +102,9 @@ def affine( pass -rotate_image = _F.rotate - - @dispatch( { - features.Image: rotate_image, + features.Image: K.rotate_image, } ) def rotate( diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 29656861002..231f56ecf88 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -2,19 +2,16 @@ from typing import TypeVar from torchvision.prototype import features -from torchvision.transforms import functional as _F +from torchvision.prototype.transforms import kernels as K from ._utils import dispatch T = TypeVar("T", bound=features.Feature) -normalize_image = _F.normalize - - @dispatch( { - features.Image: normalize_image, + features.Image: K.normalize_image, } ) def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/kernels/__init__.py new file mode 100644 index 00000000000..7d1c89d18d2 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/__init__.py @@ -0,0 +1,33 @@ +from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip + +from ._augment import ( + erase_image, + mixup_image, + mixup_one_hot_label, + cutmix_image, + cutmix_one_hot_label, +) +from ._color import ( + adjust_brightness_image, + adjust_contrast_image, + adjust_saturation_image, + adjust_sharpness_image, + posterize_image, + solarize_image, + autocontrast_image, + equalize_image, + invert_image, +) +from ._geometry import ( + horizontal_flip_bounding_box, + horizontal_flip_image, + resize_bounding_box, + resize_image, + resize_segmentation_mask, + center_crop_image, + resized_crop_image, + affine_image, + rotate_image, +) +from ._misc import normalize_image +from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py new file mode 100644 index 00000000000..842ff0cd5d6 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_augment.py @@ -0,0 +1,52 @@ +from typing import Tuple + +import torch +from torchvision.transforms import functional as _F + + +erase_image = _F.erase + + +def _mixup(input: torch.Tensor, batch_dim: int, 
lam: float, inplace: bool) -> torch.Tensor: + if not inplace: + input = input.clone() + + input_rolled = input.roll(1, batch_dim) + return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) + + +def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + return _mixup(image_batch, -4, lam, inplace) + + +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam, inplace) + + +def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: + if image_batch.ndim < 4: + raise ValueError("Need a batch of images") + + if not inplace: + image_batch = image_batch.clone() + + x1, y1, x2, y2 = box + image_rolled = image_batch.roll(1, -4) + + image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return image_batch + + +def cutmix_one_hot_label( + one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False +) -> torch.Tensor: + if one_hot_label_batch.ndim < 2: + raise ValueError("Need a batch of one hot labels") + + return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) diff --git a/torchvision/prototype/transforms/kernels/_color.py b/torchvision/prototype/transforms/kernels/_color.py new file mode 100644 index 00000000000..0d828e6d169 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_color.py @@ -0,0 +1,12 @@ +from torchvision.transforms import functional as _F + + +adjust_brightness_image = _F.adjust_brightness +adjust_saturation_image = _F.adjust_saturation +adjust_contrast_image = _F.adjust_contrast +adjust_sharpness_image = _F.adjust_sharpness +posterize_image = _F.posterize +solarize_image = _F.solarize +autocontrast_image = _F.autocontrast +equalize_image = _F.equalize +invert_image = _F.invert diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py new file mode 100644 index 00000000000..79aa5d80b41 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -0,0 +1,70 @@ +from typing import Tuple, List, Optional, TypeVar + +import torch +from torchvision.prototype import features +from torchvision.transforms import functional as _F, InterpolationMode + +from ._meta_conversion import convert_bounding_box_format + + +T = TypeVar("T", bound=features.Feature) + + +horizontal_flip_image = _F.hflip + + +def horizontal_flip_bounding_box( + bounding_box: torch.Tensor, *, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY, copy=True + ) + bounding_box[..., (0, 2)] = image_size[1] - bounding_box[..., (2, 0)] + return convert_bounding_box_format(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format) + + +def resize_image( + image: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + new_height, new_width = size + num_channels, old_height, old_width = image.shape[-3:] + batch_shape = image.shape[:-3] + return _F.resize( + image.reshape((-1, num_channels, old_height, old_width)), + size=size, + 
interpolation=interpolation, + max_size=max_size, + antialias=antialias, + ).reshape(batch_shape + (num_channels, new_height, new_width)) + + +def resize_segmentation_mask( + segmentation_mask: torch.Tensor, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> torch.Tensor: + return resize_image( + segmentation_mask, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + + +# TODO: handle max_size +def resize_bounding_box( + bounding_box: torch.Tensor, *, old_image_size: List[int], new_image_size: List[int] +) -> torch.Tensor: + old_height, old_width = old_image_size + new_height, new_width = new_image_size + ratios = torch.tensor((new_width / old_width, new_height / old_height)) + return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) + + +center_crop_image = _F.center_crop +resized_crop_image = _F.resized_crop +affine_image = _F.affine +rotate_image = _F.rotate diff --git a/torchvision/prototype/transforms/functional/_meta_conversion.py b/torchvision/prototype/transforms/kernels/_meta_conversion.py similarity index 100% rename from torchvision/prototype/transforms/functional/_meta_conversion.py rename to torchvision/prototype/transforms/kernels/_meta_conversion.py diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/kernels/_misc.py new file mode 100644 index 00000000000..29656861002 --- /dev/null +++ b/torchvision/prototype/transforms/kernels/_misc.py @@ -0,0 +1,22 @@ +from typing import List +from typing import TypeVar + +from torchvision.prototype import features +from torchvision.transforms import functional as _F + +from ._utils import dispatch + +T = TypeVar("T", bound=features.Feature) + + +normalize_image = _F.normalize + + +@dispatch( + { + features.Image: normalize_image, + } +) +def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: + """ADDME""" + pass diff --git a/torchvision/prototype/transforms/functional/_type_conversion.py b/torchvision/prototype/transforms/kernels/_type_conversion.py similarity index 100% rename from torchvision/prototype/transforms/functional/_type_conversion.py rename to torchvision/prototype/transforms/kernels/_type_conversion.py From 4c13812d4fc13530bb6174809147578744bd7038 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 22:47:58 +0100 Subject: [PATCH 19/32] remove sentinels --- torchvision/prototype/transforms/__init__.py | 1 - .../transforms/functional/_augment.py | 42 ++++++++++++++----- .../transforms/functional/_geometry.py | 34 +++++++-------- .../prototype/transforms/functional/_misc.py | 2 +- .../prototype/transforms/functional/_utils.py | 34 --------------- .../prototype/transforms/kernels/__init__.py | 1 + .../prototype/transforms/kernels/_misc.py | 18 -------- 7 files changed, 50 insertions(+), 82 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 40c9486334d..eb84789fcd1 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,4 +1,3 @@ from . import kernels # usort: skip from . 
import functional # usort: skip from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval -from .functional import InterpolationMode diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 2467bc0f694..718a90852dc 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -5,7 +5,7 @@ from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K -from ._utils import dispatch, FEATURE_SPECIFIC_PARAM +from ._utils import dispatch T = TypeVar("T", bound=features.Feature) @@ -15,7 +15,7 @@ features.Image: K.erase_image, }, ) -def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False) -> T: +def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T: """ADDME""" pass @@ -26,7 +26,7 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: features.OneHotLabel: K.mixup_one_hot_label, }, ) -def mixup(input: T, *, lam: float, inplace: bool = False) -> T: +def mixup(input: T, *, lam: float, inplace: bool) -> T: """ADDME""" pass @@ -37,12 +37,32 @@ def mixup(input: T, *, lam: float, inplace: bool = False) -> T: features.OneHotLabel: K.cutmix_one_hot_label, }, ) -def cutmix( - input: T, - *, - box: Tuple[int, int, int, int] = FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - lam_adjusted: float = FEATURE_SPECIFIC_PARAM, # type: ignore[assignment] - inplace: bool = False, -) -> T: - """ADDME""" +def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inplace: bool) -> T: + """Perform the CutMix operation as introduced in the paper + `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. + + Dispatch to the corresponding kernels happens according to this table: + + .. table:: + :widths: 30 70 + + ==================================================== ================================================================ + :class:`~torchvision.prototype.features.Image` :func:`~torch.prototype.transforms.kernels.cutmix_image` + :class:`~torchvision.prototype.features.OneHotLabel` :func:`~torch.prototype.transforms.kernels.cutmix_one_hot_label` + ==================================================== ================================================================ + + Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. + + .. note:: + + The ``box`` parameter is only required for inputs of type + + - :class:`~torchvision.prototype.features.Image` + + .. 
note:: + + The ``lam_adjusted`` parameter is only required for inputs of type + + - :class:`~torchvision.prototype.features.OneHotLabel` + """ pass diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 4ccb8e5abe5..da3a76a5017 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,11 +1,11 @@ -from typing import List, Optional, TypeVar, Any +from typing import List, Optional, TypeVar import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K from torchvision.transforms import functional as _F, InterpolationMode -from ._utils import dispatch, FEATURE_SPECIFIC_DEFAULT +from ._utils import dispatch T = TypeVar("T", bound=features.Feature) @@ -26,7 +26,7 @@ def horizontal_flip(input: T) -> T: pass -def _resize_bounding_box(input: features.BoundingBox, *, size: List[int], **_: Any) -> features.BoundingBox: +def _resize_bounding_box(input: features.BoundingBox, *, size: List[int]) -> features.BoundingBox: output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) return features.BoundingBox.new_like(input, output, image_size=size) @@ -43,9 +43,9 @@ def resize( input: T, *, size: List[int], - interpolation: InterpolationMode = FEATURE_SPECIFIC_DEFAULT, # type: ignore[assignment] - max_size: Optional[int] = None, - antialias: Optional[bool] = None, + interpolation: InterpolationMode, + max_size: Optional[int], + antialias: Optional[bool], ) -> T: """ADDME""" pass @@ -74,7 +74,7 @@ def resized_crop( height: int, width: int, size: List[int], - interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation: InterpolationMode, ) -> T: """ADDME""" pass @@ -92,11 +92,11 @@ def affine( translate: List[int], scale: float, shear: List[float], - interpolation: InterpolationMode = InterpolationMode.NEAREST, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, - fillcolor: Optional[List[float]] = None, - center: Optional[List[int]] = None, + interpolation: InterpolationMode, + fill: Optional[List[float]], + resample: Optional[int], + fillcolor: Optional[List[float]], + center: Optional[List[int]], ) -> T: """ADDME""" pass @@ -111,11 +111,11 @@ def rotate( input: T, *, angle: float, - interpolation: InterpolationMode = InterpolationMode.NEAREST, - expand: bool = False, - center: Optional[List[int]] = None, - fill: Optional[List[float]] = None, - resample: Optional[int] = None, + interpolation: InterpolationMode, + expand: bool, + center: Optional[List[int]], + fill: Optional[List[float]], + resample: Optional[int], ) -> T: """ADDME""" pass diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 231f56ecf88..9929109623a 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -14,6 +14,6 @@ features.Image: K.normalize_image, } ) -def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: +def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool) -> T: """ADDME""" pass diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 869e76845c9..6ded898ee1b 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ 
b/torchvision/prototype/transforms/functional/_utils.py @@ -1,23 +1,13 @@ import functools -import inspect from typing import Any, Optional, Callable, TypeVar, Dict, Union import PIL.Image import torch import torch.overrides from torchvision.prototype import features -from torchvision.prototype.utils._internal import sequence_to_str F = TypeVar("F", bound=features.Feature) -# Sentinel to use as default value of a dispatcher parameter if it is only required for a subset of the kernels. If the -# decorated function is called without the parameter for a kernel that requires it, an expressive :class:`TypeError` is -# raised. -FEATURE_SPECIFIC_PARAM = object() - -# Sentinel to use as default value of a dispatcher parameter if the kernels use different default values for it. -FEATURE_SPECIFIC_DEFAULT = object() - def dispatch( kernels: Dict[Any, Callable[..., Union[torch.Tensor, F]]], @@ -65,20 +55,7 @@ def dispatch_fn(input, *args, **kwargs): if pil_kernel and features.Image not in kernels: raise TypeError("PIL kernel can only be registered for images") - params = {feature_type: inspect.signature(kernel).parameters for feature_type, kernel in kernels.items()} - def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]: - feature_specific_params = { - feature_type: [ - name - for name, param in inspect.signature(dispatch_fn).parameters.items() - if param.default is FEATURE_SPECIFIC_PARAM - and name in params_ - and params_[name].default is inspect.Parameter.empty - ] - for feature_type, params_ in params.items() - } - @functools.wraps(dispatch_fn) def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: feature_type = type(input) @@ -102,17 +79,6 @@ def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F: except KeyError: raise TypeError(f"No support for {feature_type.__name__}") from None - missing_args = [ - param - for param in feature_specific_params[feature_type] - if kwargs.get(param, FEATURE_SPECIFIC_PARAM) is FEATURE_SPECIFIC_PARAM - ] - if missing_args: - raise TypeError( - f"{dispatch_fn.__name__}() missing {len(missing_args)} required keyword-only arguments " - f"for feature type {feature_type.__name__}: {sequence_to_str(missing_args, separate_last='and ')}" - ) - output = kernel(input, *args, **kwargs) if not isinstance(output, feature_type): diff --git a/torchvision/prototype/transforms/kernels/__init__.py b/torchvision/prototype/transforms/kernels/__init__.py index 7d1c89d18d2..6f74f6af0e9 100644 --- a/torchvision/prototype/transforms/kernels/__init__.py +++ b/torchvision/prototype/transforms/kernels/__init__.py @@ -1,3 +1,4 @@ +from torchvision.transforms import InterpolationMode # usort: skip from ._meta_conversion import convert_bounding_box_format, convert_color_space # usort: skip from ._augment import ( diff --git a/torchvision/prototype/transforms/kernels/_misc.py b/torchvision/prototype/transforms/kernels/_misc.py index 29656861002..de148ab194a 100644 --- a/torchvision/prototype/transforms/kernels/_misc.py +++ b/torchvision/prototype/transforms/kernels/_misc.py @@ -1,22 +1,4 @@ -from typing import List -from typing import TypeVar - -from torchvision.prototype import features from torchvision.transforms import functional as _F -from ._utils import dispatch - -T = TypeVar("T", bound=features.Feature) - normalize_image = _F.normalize - - -@dispatch( - { - features.Image: normalize_image, - } -) -def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool = False) -> T: - """ADDME""" - pass From f5df19461c0fcd1c9ce673e5dc88c096dcb87861 Mon 
Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 22:48:29 +0100 Subject: [PATCH 20/32] replace pass with ... --- .../transforms/functional/_augment.py | 6 +++--- .../prototype/transforms/functional/_color.py | 18 +++++++++--------- .../transforms/functional/_geometry.py | 12 ++++++------ .../prototype/transforms/functional/_misc.py | 2 +- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 718a90852dc..7fbd02dd388 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -17,7 +17,7 @@ ) def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T: """ADDME""" - pass + ... @dispatch( @@ -28,7 +28,7 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: ) def mixup(input: T, *, lam: float, inplace: bool) -> T: """ADDME""" - pass + ... @dispatch( @@ -65,4 +65,4 @@ def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inp - :class:`~torchvision.prototype.features.OneHotLabel` """ - pass + ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 4097f8f4e93..a493b4ef726 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -15,7 +15,7 @@ ) def adjust_brightness(input: T, *, brightness_factor: float) -> T: """ADDME""" - pass + ... @dispatch( @@ -25,7 +25,7 @@ def adjust_brightness(input: T, *, brightness_factor: float) -> T: ) def adjust_saturation(input: T, *, saturation_factor: float) -> T: """ADDME""" - pass + ... @dispatch( @@ -35,7 +35,7 @@ def adjust_saturation(input: T, *, saturation_factor: float) -> T: ) def adjust_contrast(input: T, *, contrast_factor: float) -> T: """ADDME""" - pass + ... @dispatch( @@ -45,7 +45,7 @@ def adjust_contrast(input: T, *, contrast_factor: float) -> T: ) def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: """ADDME""" - pass + ... @dispatch( @@ -55,7 +55,7 @@ def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: ) def posterize(input: T, *, bits: int) -> T: """ADDME""" - pass + ... @dispatch( @@ -65,7 +65,7 @@ def posterize(input: T, *, bits: int) -> T: ) def solarize(input: T, *, threshold: float) -> T: """ADDME""" - pass + ... @dispatch( @@ -75,7 +75,7 @@ def solarize(input: T, *, threshold: float) -> T: ) def autocontrast(input: T) -> T: """ADDME""" - pass + ... @dispatch( @@ -85,7 +85,7 @@ def autocontrast(input: T) -> T: ) def equalize(input: T) -> T: """ADDME""" - pass + ... @dispatch( @@ -95,4 +95,4 @@ def equalize(input: T) -> T: ) def invert(input: T) -> T: """ADDME""" - pass + ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index da3a76a5017..baa30689879 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -23,7 +23,7 @@ def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: ) def horizontal_flip(input: T) -> T: """ADDME""" - pass + ... def _resize_bounding_box(input: features.BoundingBox, *, size: List[int]) -> features.BoundingBox: @@ -48,7 +48,7 @@ def resize( antialias: Optional[bool], ) -> T: """ADDME""" - pass + ... 
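The ``pass`` → ``...`` swap is purely cosmetic: the decorated bodies are never executed, because the ``dispatch`` decorator replaces the call with a kernel lookup. A minimal, self-contained sketch of that pattern (a simplified stand-in for illustration only, not the decorator from this patch):

.. code:: python

    import torch


    def dispatch(kernels):
        # Simplified stand-in: pick the kernel by the exact input type and call it.
        # The body of the decorated function below is never reached.
        def decorator(dispatch_fn):
            def wrapper(input, *args, **kwargs):
                return kernels[type(input)](input, *args, **kwargs)

            return wrapper

        return decorator


    @dispatch({torch.Tensor: lambda t, *, factor: t * factor})
    def scale(input, *, factor):
        """The body is intentionally empty."""
        ...


    print(scale(torch.ones(2), factor=3.0))  # tensor([3., 3.])

The decorator in this patch series additionally routes PIL images to a dedicated ``pil_kernel`` and wraps plain tensor outputs back into the input's feature type.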
@dispatch( @@ -58,7 +58,7 @@ def resize( ) def center_crop(input: T, *, output_size: List[int]) -> T: """ADDME""" - pass + ... @dispatch( @@ -77,7 +77,7 @@ def resized_crop( interpolation: InterpolationMode, ) -> T: """ADDME""" - pass + ... @dispatch( @@ -99,7 +99,7 @@ def affine( center: Optional[List[int]], ) -> T: """ADDME""" - pass + ... @dispatch( @@ -118,4 +118,4 @@ def rotate( resample: Optional[int], ) -> T: """ADDME""" - pass + ... diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 9929109623a..70ee1a76c46 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -16,4 +16,4 @@ ) def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool) -> T: """ADDME""" - pass + ... From 9014e209e49711424f6deb5f7f8e95e32c462d90 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Feb 2022 23:01:34 +0100 Subject: [PATCH 21/32] appease mypy --- docs/source/transforms.rst | 4 ++++ torchvision/prototype/features/_bounding_box.py | 2 +- torchvision/prototype/features/_encoded.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 7f835267200..935f74ed8cc 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -3,6 +3,10 @@ Transforming and augmenting images ================================== +.. currentmodule:: torchvision.prototype.transforms.functional + +.. autofunction:: cutmix + .. currentmodule:: torchvision.transforms Transforms are common image transformations available in the diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index fbf4522be93..1ffd1fb84dc 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -40,7 +40,7 @@ def __new__( def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.functional import convert_bounding_box_format + from torchvision.prototype.transforms.kernels import convert_bounding_box_format if isinstance(format, str): format = BoundingBoxFormat[format] diff --git a/torchvision/prototype/features/_encoded.py b/torchvision/prototype/features/_encoded.py index ed2ede62921..ea8bdeae32e 100644 --- a/torchvision/prototype/features/_encoded.py +++ b/torchvision/prototype/features/_encoded.py @@ -40,7 +40,7 @@ def image_size(self) -> Tuple[int, int]: def decode(self) -> Image: # import at runtime to avoid cyclic imports - from torchvision.prototype.transforms.functional import decode_image_with_pil + from torchvision.prototype.transforms.kernels import decode_image_with_pil return Image(decode_image_with_pil(self)) From 0238184ed24d518da8fe5c80520c933cf7d52e51 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 09:12:45 +0100 Subject: [PATCH 22/32] make single kernel dispatchers more concise --- .../transforms/functional/_augment.py | 6 +-- .../prototype/transforms/functional/_color.py | 54 ++++--------------- .../transforms/functional/_geometry.py | 24 ++------- .../prototype/transforms/functional/_misc.py | 6 +-- 4 files changed, 15 insertions(+), 75 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 7fbd02dd388..c93aa03214a 100644 --- 
a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -10,11 +10,7 @@ T = TypeVar("T", bound=features.Feature) -@dispatch( - { - features.Image: K.erase_image, - }, -) +@dispatch({features.Image: K.erase_image}) def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T: """ADDME""" ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a493b4ef726..3e520b5e7a2 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -8,91 +8,55 @@ T = TypeVar("T", bound=features.Feature) -@dispatch( - { - features.Image: K.adjust_brightness_image, - } -) +@dispatch({features.Image: K.adjust_brightness_image}) def adjust_brightness(input: T, *, brightness_factor: float) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.adjust_saturation_image, - } -) +@dispatch({features.Image: K.adjust_saturation_image}) def adjust_saturation(input: T, *, saturation_factor: float) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.adjust_contrast_image, - } -) +@dispatch({features.Image: K.adjust_contrast_image}) def adjust_contrast(input: T, *, contrast_factor: float) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.adjust_sharpness_image, - } -) +@dispatch({features.Image: K.adjust_sharpness_image}) def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.posterize_image, - } -) +@dispatch({features.Image: K.posterize_image}) def posterize(input: T, *, bits: int) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.solarize_image, - } -) +@dispatch({features.Image: K.solarize_image}) def solarize(input: T, *, threshold: float) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.autocontrast_image, - } -) +@dispatch({features.Image: K.autocontrast_image}) def autocontrast(input: T) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.equalize_image, - } -) +@dispatch({features.Image: K.equalize_image}) def equalize(input: T) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.invert_image, - } -) +@dispatch({features.Image: K.invert_image}) def invert(input: T) -> T: """ADDME""" ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index baa30689879..e8efb85a525 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -51,21 +51,13 @@ def resize( ... -@dispatch( - { - features.Image: K.center_crop_image, - } -) +@dispatch({features.Image: K.center_crop_image}) def center_crop(input: T, *, output_size: List[int]) -> T: """ADDME""" ... -@dispatch( - { - features.Image: K.resized_crop_image, - } -) +@dispatch({features.Image: K.resized_crop_image}) def resized_crop( input: T, *, @@ -80,11 +72,7 @@ def resized_crop( ... -@dispatch( - { - features.Image: K.affine_image, - } -) +@dispatch({features.Image: K.affine_image}) def affine( input: T, *, @@ -102,11 +90,7 @@ def affine( ... 
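With only an image kernel registered, these single-kernel dispatchers reject every other feature type at call time. An illustrative sketch of that behavior (input values are made up, and the feature constructors are assumed to accept plain tensors):

.. code:: python

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.rand(3, 8, 8))
    mask = features.SegmentationMask(torch.zeros(1, 8, 8))

    F.center_crop(image, output_size=[4, 4])  # dispatches to K.center_crop_image

    try:
        F.center_crop(mask, output_size=[4, 4])  # no kernel is registered for segmentation masks
    except TypeError as error:
        print(error)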
-@dispatch( - { - features.Image: K.rotate_image, - } -) +@dispatch({features.Image: K.rotate_image}) def rotate( input: T, *, diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 70ee1a76c46..a7fc2084c3f 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -9,11 +9,7 @@ T = TypeVar("T", bound=features.Feature) -@dispatch( - { - features.Image: K.normalize_image, - } -) +@dispatch({features.Image: K.normalize_image}) def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool) -> T: """ADDME""" ... From 22f4d29964d64b7234e1f51bcb77e42e7bdfa2b8 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 09:17:27 +0100 Subject: [PATCH 23/32] make dispatcher signatures more generic --- torchvision/prototype/transforms/__init__.py | 2 + .../transforms/functional/_augment.py | 22 ++------ .../prototype/transforms/functional/_color.py | 20 +++---- .../transforms/functional/_geometry.py | 53 +++---------------- .../prototype/transforms/functional/_misc.py | 5 +- 5 files changed, 26 insertions(+), 76 deletions(-) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index eb84789fcd1..c9988be1930 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -1,3 +1,5 @@ from . import kernels # usort: skip from . import functional # usort: skip +from .kernels import InterpolationMode # usort: skip + from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index c93aa03214a..164e7b2ec5e 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,7 +1,5 @@ -from typing import Tuple -from typing import TypeVar +from typing import TypeVar, Any -import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K @@ -11,7 +9,7 @@ @dispatch({features.Image: K.erase_image}) -def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool) -> T: +def erase(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @@ -22,7 +20,7 @@ def erase(input: T, *, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: features.OneHotLabel: K.mixup_one_hot_label, }, ) -def mixup(input: T, *, lam: float, inplace: bool) -> T: +def mixup(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @@ -33,7 +31,7 @@ def mixup(input: T, *, lam: float, inplace: bool) -> T: features.OneHotLabel: K.cutmix_one_hot_label, }, ) -def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inplace: bool) -> T: +def cutmix(input: T, *args: Any, **kwargs: Any) -> T: """Perform the CutMix operation as introduced in the paper `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features" `_. @@ -48,17 +46,5 @@ def cutmix(input: T, *, box: Tuple[int, int, int, int], lam_adjusted: float, inp ==================================================== ================================================================ Please refer to the kernel documentations for a detailed explanation of the functionality and parameters. - - .. 
note:: - - The ``box`` parameter is only required for inputs of type - - - :class:`~torchvision.prototype.features.Image` - - .. note:: - - The ``lam_adjusted`` parameter is only required for inputs of type - - - :class:`~torchvision.prototype.features.OneHotLabel` """ ... diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 3e520b5e7a2..923f9ea0e3b 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,4 +1,4 @@ -from typing import TypeVar +from typing import TypeVar, Any from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K @@ -9,54 +9,54 @@ @dispatch({features.Image: K.adjust_brightness_image}) -def adjust_brightness(input: T, *, brightness_factor: float) -> T: +def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.adjust_saturation_image}) -def adjust_saturation(input: T, *, saturation_factor: float) -> T: +def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.adjust_contrast_image}) -def adjust_contrast(input: T, *, contrast_factor: float) -> T: +def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.adjust_sharpness_image}) -def adjust_sharpness(input: T, *, sharpness_factor: float) -> T: +def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.posterize_image}) -def posterize(input: T, *, bits: int) -> T: +def posterize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.solarize_image}) -def solarize(input: T, *, threshold: float) -> T: +def solarize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.autocontrast_image}) -def autocontrast(input: T) -> T: +def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.equalize_image}) -def equalize(input: T) -> T: +def equalize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.invert_image}) -def invert(input: T) -> T: +def invert(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index e8efb85a525..ffe8c73c088 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,9 +1,9 @@ -from typing import List, Optional, TypeVar +from typing import List, TypeVar, Any import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K -from torchvision.transforms import functional as _F, InterpolationMode +from torchvision.transforms import functional as _F from ._utils import dispatch @@ -21,7 +21,7 @@ def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: }, pil_kernel=_F.hflip, ) -def horizontal_flip(input: T) -> T: +def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... 
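After this change the dispatchers no longer spell out the kernel parameters; everything besides the input is forwarded through ``**kwargs`` and validated by whichever kernel is picked. A hedged usage sketch (tensor values and sizes are illustrative, and the feature constructors are assumed to accept the metadata shown):

.. code:: python

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.rand(3, 32, 32))
    bounding_box = features.BoundingBox(
        torch.tensor([2.0, 4.0, 10.0, 20.0]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(32, 32),
    )

    # The same dispatcher serves both feature types without naming their parameters.
    flipped_image = F.horizontal_flip(image)
    flipped_box = F.horizontal_flip(bounding_box)

    # Keyword arguments are forwarded untouched: the image goes to K.resize_image,
    # the bounding box to the helper that also updates its image_size metadata.
    resized_image = F.resize(image, size=[16, 16])
    resized_box = F.resize(bounding_box, size=[16, 16])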
@@ -39,67 +39,30 @@ def _resize_bounding_box(input: features.BoundingBox, *, size: List[int]) -> fea }, pil_kernel=_F.resize, ) -def resize( - input: T, - *, - size: List[int], - interpolation: InterpolationMode, - max_size: Optional[int], - antialias: Optional[bool], -) -> T: +def resize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.center_crop_image}) -def center_crop(input: T, *, output_size: List[int]) -> T: +def center_crop(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.resized_crop_image}) -def resized_crop( - input: T, - *, - top: int, - left: int, - height: int, - width: int, - size: List[int], - interpolation: InterpolationMode, -) -> T: +def resized_crop(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.affine_image}) -def affine( - input: T, - *, - angle: float, - translate: List[int], - scale: float, - shear: List[float], - interpolation: InterpolationMode, - fill: Optional[List[float]], - resample: Optional[int], - fillcolor: Optional[List[float]], - center: Optional[List[int]], -) -> T: +def affine(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @dispatch({features.Image: K.rotate_image}) -def rotate( - input: T, - *, - angle: float, - interpolation: InterpolationMode, - expand: bool, - center: Optional[List[int]], - fill: Optional[List[float]], - resample: Optional[int], -) -> T: +def rotate(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index a7fc2084c3f..44b84b499a8 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -1,5 +1,4 @@ -from typing import List -from typing import TypeVar +from typing import TypeVar, Any from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K @@ -10,6 +9,6 @@ @dispatch({features.Image: K.normalize_image}) -def normalize(input: T, *, mean: List[float], std: List[float], inplace: bool) -> T: +def normalize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... From 1cd2166780bb1a351bc245d50a304d7d55bee7a5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 09:35:17 +0100 Subject: [PATCH 24/32] make kernel checking more strict --- .../prototype/transforms/functional/_utils.py | 33 +++++++++++++++++-- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py index 6ded898ee1b..1f46a586b6e 100644 --- a/torchvision/prototype/transforms/functional/_utils.py +++ b/torchvision/prototype/transforms/functional/_utils.py @@ -1,4 +1,5 @@ import functools +import inspect from typing import Any, Optional, Callable, TypeVar, Dict, Union import PIL.Image @@ -23,8 +24,14 @@ def dispatch( .. code:: python + from typing import Any, TypeVar + + from torchvision.protoype import features + + T = TypeVar("T", bound=features.Feature) + @dispatch - def dispatch_fn(input, *args, **kwargs): + def dispatch_fn(input: T, *args: Any, **kwargs: Any) -> T: ... where ``input`` is a strict subclass of :class:`~torchvision.prototype.features.Feature` and is used to determine @@ -44,14 +51,34 @@ def dispatch_fn(input, *args, **kwargs): Raises: TypeError: If any key in ``kernels`` is not a strict subclass of :class:`~torchvision.prototype.features.Feature`. 
+ TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``. TypeError: If ``pil_kernel`` is specified, but no kernel for :class:`~torchvision.prototype.features.Image` is available. TypeError: If the decorated function is called with neither a ``PIL`` image nor a :class:`~torch.Tensor`. TypeError: If the decorated function is called with an input that cannot be dispatched. """ - for feature_type in kernels: + + def check_kernel(kernel: Any) -> bool: + if not callable(kernel): + return False + + params = list(inspect.signature(kernel).parameters.values()) + if not params: + return False + + return params[0].kind != inspect.Parameter.KEYWORD_ONLY + + for feature_type, kernel in kernels.items(): if not (issubclass(feature_type, features.Feature) and feature_type is not features.Feature): - raise TypeError("XXX") + raise TypeError( + "Can only register kernels for strict subclasses of `torchvision.prototype.features.Feature`." + ) + + if not check_kernel(kernel): + raise TypeError( + f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)." + ) + if pil_kernel and features.Image not in kernels: raise TypeError("PIL kernel can only be registered for images") From cca5040cc954e0c597a31e9b2aac84a25d427a0f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 09:36:10 +0100 Subject: [PATCH 25/32] revert doc changes --- docs/source/transforms.rst | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 935f74ed8cc..7f835267200 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -3,10 +3,6 @@ Transforming and augmenting images ================================== -.. currentmodule:: torchvision.prototype.transforms.functional - -.. autofunction:: cutmix - .. 
currentmodule:: torchvision.transforms Transforms are common image transformations available in the From 4216d9150d9e87a0c4041e41cb68f8f42b02b7de Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 10:53:30 +0100 Subject: [PATCH 26/32] address Franciscos comments --- torchvision/prototype/transforms/kernels/_geometry.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/kernels/_geometry.py b/torchvision/prototype/transforms/kernels/_geometry.py index 5fff969e8d2..c3cbbb34b02 100644 --- a/torchvision/prototype/transforms/kernels/_geometry.py +++ b/torchvision/prototype/transforms/kernels/_geometry.py @@ -22,7 +22,7 @@ def horizontal_flip_bounding_box( bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY ).view(-1, 4) - bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [0, 2]] + bounding_box[:, [0, 2]] = image_size[1] - bounding_box[:, [2, 0]] return convert_bounding_box_format( bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format @@ -66,7 +66,7 @@ def resize_bounding_box( ) -> torch.Tensor: old_height, old_width = old_image_size new_height, new_width = new_image_size - ratios = torch.tensor((new_width / old_width, new_height / old_height)) + ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) From ecd1425f1655a988958bfe43c6ed94837e19873f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 10:59:20 +0100 Subject: [PATCH 27/32] remove inplace --- .../prototype/transforms/kernels/_augment.py | 24 +++++++------------ .../transforms/kernels/_meta_conversion.py | 4 ++-- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py index 842ff0cd5d6..c8120eb161a 100644 --- a/torchvision/prototype/transforms/kernels/_augment.py +++ b/torchvision/prototype/transforms/kernels/_augment.py @@ -7,35 +7,29 @@ erase_image = _F.erase -def _mixup(input: torch.Tensor, batch_dim: int, lam: float, inplace: bool) -> torch.Tensor: - if not inplace: - input = input.clone() - +def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: input_rolled = input.roll(1, batch_dim) return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) -def mixup_image(image_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: +def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") - return _mixup(image_batch, -4, lam, inplace) + return _mixup(image_batch, -4, lam) -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float, inplace: bool = False) -> torch.Tensor: +def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") - return _mixup(one_hot_label_batch, -2, lam, inplace) + return _mixup(one_hot_label_batch, -2, lam) -def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], inplace: bool = False) -> torch.Tensor: +def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: if image_batch.ndim < 4: raise ValueError("Need a batch of images") - if not inplace: - image_batch = image_batch.clone() - x1, y1, x2, y2 = box image_rolled = image_batch.roll(1, -4) @@ 
-43,10 +37,8 @@ def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int], i return image_batch -def cutmix_one_hot_label( - one_hot_label_batch: torch.Tensor, *, lam_adjusted: float, inplace: bool = False -) -> torch.Tensor: +def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: if one_hot_label_batch.ndim < 2: raise ValueError("Need a batch of one hot labels") - return _mixup(one_hot_label_batch, -2, lam_adjusted, inplace) + return _mixup(one_hot_label_batch, -2, lam_adjusted) diff --git a/torchvision/prototype/transforms/kernels/_meta_conversion.py b/torchvision/prototype/transforms/kernels/_meta_conversion.py index 484066a39ee..4acaf9fe9e4 100644 --- a/torchvision/prototype/transforms/kernels/_meta_conversion.py +++ b/torchvision/prototype/transforms/kernels/_meta_conversion.py @@ -37,7 +37,7 @@ def convert_bounding_box_format( bounding_box: torch.Tensor, *, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat ) -> torch.Tensor: if new_format == old_format: - return bounding_box + return bounding_box.clone() if old_format == BoundingBoxFormat.XYWH: bounding_box = _xywh_to_xyxy(bounding_box) @@ -58,7 +58,7 @@ def _grayscale_to_rgb(grayscale: torch.Tensor) -> torch.Tensor: def convert_color_space(image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace) -> torch.Tensor: if new_color_space == old_color_space: - return image + return image.clone() if old_color_space == ColorSpace.GRAYSCALE: image = _grayscale_to_rgb(image) From 8771f404f200efd22d3e39a195d81e308f5aefee Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 11:03:41 +0100 Subject: [PATCH 28/32] rename kernel test module --- ...nsforms_functional.py => test_prototype_transforms_kernels.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{test_prototype_transforms_functional.py => test_prototype_transforms_kernels.py} (100%) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_kernels.py similarity index 100% rename from test/test_prototype_transforms_functional.py rename to test/test_prototype_transforms_kernels.py From 0de4ba7bfadf94e81dddd5d329e6dc0702e896ae Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 11:27:11 +0100 Subject: [PATCH 29/32] fix inplace --- torchvision/prototype/transforms/kernels/_augment.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/kernels/_augment.py b/torchvision/prototype/transforms/kernels/_augment.py index c8120eb161a..526ed85ffd8 100644 --- a/torchvision/prototype/transforms/kernels/_augment.py +++ b/torchvision/prototype/transforms/kernels/_augment.py @@ -8,8 +8,8 @@ def _mixup(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: - input_rolled = input.roll(1, batch_dim) - return input.mul_(lam).add_(input_rolled.mul_(1 - lam)) + input = input.clone() + return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) def mixup_image(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: @@ -33,6 +33,7 @@ def cutmix_image(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) - x1, y1, x2, y2 = box image_rolled = image_batch.roll(1, -4) + image_batch = image_batch.clone() image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] return image_batch From 6ef6bf1a2b75c057e923b12b7058918b215d0c43 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Feb 2022 17:01:34 +0100 Subject: [PATCH 30/32] remove special casing for pil 
and vanilla tensors --- .../transforms/functional/_augment.py | 13 +++- .../prototype/transforms/functional/_color.py | 75 ++++++++++++++++--- .../transforms/functional/_geometry.py | 73 +++++++++++++----- .../prototype/transforms/functional/_misc.py | 9 ++- .../prototype/transforms/functional/_utils.py | 60 +++++++-------- 5 files changed, 162 insertions(+), 68 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 164e7b2ec5e..bbae796c1c9 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,14 +1,21 @@ from typing import TypeVar, Any +import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F from ._utils import dispatch T = TypeVar("T", bound=features.Feature) -@dispatch({features.Image: K.erase_image}) +@dispatch( + { + torch.Tensor: _F.erase, + features.Image: K.erase_image, + } +) def erase(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... @@ -18,7 +25,7 @@ def erase(input: T, *args: Any, **kwargs: Any) -> T: { features.Image: K.mixup_image, features.OneHotLabel: K.mixup_one_hot_label, - }, + } ) def mixup(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" @@ -29,7 +36,7 @@ def mixup(input: T, *args: Any, **kwargs: Any) -> T: { features.Image: K.cutmix_image, features.OneHotLabel: K.cutmix_one_hot_label, - }, + } ) def cutmix(input: T, *args: Any, **kwargs: Any) -> T: """Perform the CutMix operation as introduced in the paper diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 923f9ea0e3b..479b55a1b03 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,62 +1,119 @@ from typing import TypeVar, Any +import PIL.Image +import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K +from torchvision.transforms import functional as _F from ._utils import dispatch T = TypeVar("T", bound=features.Feature) -@dispatch({features.Image: K.adjust_brightness_image}) +@dispatch( + { + torch.Tensor: _F.adjust_brightness, + PIL.Image.Image: _F.adjust_brightness, + features.Image: K.adjust_brightness_image, + } +) def adjust_brightness(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.adjust_saturation_image}) +@dispatch( + { + torch.Tensor: _F.adjust_saturation, + PIL.Image.Image: _F.adjust_saturation, + features.Image: K.adjust_saturation_image, + } +) def adjust_saturation(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.adjust_contrast_image}) +@dispatch( + { + torch.Tensor: _F.adjust_contrast, + PIL.Image.Image: _F.adjust_contrast, + features.Image: K.adjust_contrast_image, + } +) def adjust_contrast(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.adjust_sharpness_image}) +@dispatch( + { + torch.Tensor: _F.adjust_sharpness, + PIL.Image.Image: _F.adjust_sharpness, + features.Image: K.adjust_sharpness_image, + } +) def adjust_sharpness(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... 
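With the special casing gone, plain tensors and PIL images are just regular entries in the kernel table, so a single dispatcher covers all three kinds of input. A small sketch of the intended call pattern (illustrative inputs only):

.. code:: python

    import PIL.Image
    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    plain_tensor = torch.rand(3, 16, 16)
    pil_image = PIL.Image.new("RGB", (16, 16))
    feature_image = features.Image(torch.rand(3, 16, 16))

    # Each input type hits the entry registered for it: torch.Tensor and
    # PIL.Image.Image go to the stable torchvision.transforms.functional kernel,
    # features.Image to the prototype kernel.
    for inp in (plain_tensor, pil_image, feature_image):
        output = F.adjust_brightness(inp, brightness_factor=1.5)
        print(type(inp).__name__, "->", type(output).__name__)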
-@dispatch({features.Image: K.posterize_image}) +@dispatch( + { + torch.Tensor: _F.posterize, + PIL.Image.Image: _F.posterize, + features.Image: K.posterize_image, + } +) def posterize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.solarize_image}) +@dispatch( + { + torch.Tensor: _F.solarize, + PIL.Image.Image: _F.solarize, + features.Image: K.solarize_image, + } +) def solarize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.autocontrast_image}) +@dispatch( + { + torch.Tensor: _F.autocontrast, + PIL.Image.Image: _F.autocontrast, + features.Image: K.autocontrast_image, + } +) def autocontrast(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.equalize_image}) +@dispatch( + { + torch.Tensor: _F.equalize, + PIL.Image.Image: _F.equalize, + features.Image: K.equalize_image, + } +) def equalize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... -@dispatch({features.Image: K.invert_image}) +@dispatch( + { + torch.Tensor: _F.invert, + PIL.Image.Image: _F.invert, + features.Image: K.invert_image, + } +) def invert(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" ... diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ffe8c73c088..fa0cc993525 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,5 +1,6 @@ -from typing import List, TypeVar, Any +from typing import TypeVar, Any, cast +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import kernels as K @@ -10,59 +11,91 @@ T = TypeVar("T", bound=features.Feature) -def _horizontal_flip_bounding_box(input: features.BoundingBox) -> torch.Tensor: - return K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - - @dispatch( { + torch.Tensor: _F.hflip, + PIL.Image.Image: _F.hflip, features.Image: K.horizontal_flip_image, - features.BoundingBox: _horizontal_flip_bounding_box, + features.BoundingBox: None, }, - pil_kernel=_F.hflip, ) def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" - ... + if isinstance(input, features.BoundingBox): + output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) + return cast(T, features.BoundingBox.new_like(input, output)) - -def _resize_bounding_box(input: features.BoundingBox, *, size: List[int]) -> features.BoundingBox: - output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size) - return features.BoundingBox.new_like(input, output, image_size=size) + raise RuntimeError( + f"horizontal_flip() did not handle inputs of type {type(input).__name__} " + f"although it was configured to do so." + ) @dispatch( { + torch.Tensor: _F.resize, + PIL.Image.Image: _F.resize, features.Image: K.resize_image, features.SegmentationMask: K.resize_segmentation_mask, - features.BoundingBox: _resize_bounding_box, - }, - pil_kernel=_F.resize, + features.BoundingBox: None, + } ) def resize(input: T, *args: Any, **kwargs: Any) -> T: """ADDME""" - ... 
+    if isinstance(input, features.BoundingBox):
+        size = kwargs.pop("size")
+        output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size)
+        return cast(T, features.BoundingBox.new_like(input, output, image_size=size))
+
+    raise RuntimeError(
+        f"resize() did not handle inputs of type {type(input).__name__} "
+        f"although it was configured to do so."
+    )
 
 
-@dispatch({features.Image: K.center_crop_image})
+@dispatch(
+    {
+        torch.Tensor: _F.center_crop,
+        PIL.Image.Image: _F.center_crop,
+        features.Image: K.center_crop_image,
+    }
+)
 def center_crop(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     ...
 
 
-@dispatch({features.Image: K.resized_crop_image})
+@dispatch(
+    {
+        torch.Tensor: _F.resized_crop,
+        PIL.Image.Image: _F.resized_crop,
+        features.Image: K.resized_crop_image,
+    }
+)
 def resized_crop(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     ...
 
 
-@dispatch({features.Image: K.affine_image})
+@dispatch(
+    {
+        torch.Tensor: _F.affine,
+        PIL.Image.Image: _F.affine,
+        features.Image: K.affine_image,
+    }
+)
 def affine(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     ...
 
 
-@dispatch({features.Image: K.rotate_image})
+@dispatch(
+    {
+        torch.Tensor: _F.rotate,
+        PIL.Image.Image: _F.rotate,
+        features.Image: K.rotate_image,
+    }
+)
 def rotate(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     ...
diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py
index 44b84b499a8..40fc5894f03 100644
--- a/torchvision/prototype/transforms/functional/_misc.py
+++ b/torchvision/prototype/transforms/functional/_misc.py
@@ -1,14 +1,21 @@
 from typing import TypeVar, Any
 
+import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import kernels as K
+from torchvision.transforms import functional as _F
 
 from ._utils import dispatch
 
 T = TypeVar("T", bound=features.Feature)
 
 
-@dispatch({features.Image: K.normalize_image})
+@dispatch(
+    {
+        torch.Tensor: _F.normalize,
+        features.Image: K.normalize_image,
+    }
+)
 def normalize(input: T, *args: Any, **kwargs: Any) -> T:
     """ADDME"""
     ...
diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py
index 1f46a586b6e..81f75c36bb5 100644
--- a/torchvision/prototype/transforms/functional/_utils.py
+++ b/torchvision/prototype/transforms/functional/_utils.py
@@ -1,8 +1,7 @@
 import functools
 import inspect
-from typing import Any, Optional, Callable, TypeVar, Dict, Union
+from typing import Any, Optional, Callable, TypeVar, Dict
 
-import PIL.Image
 import torch
 import torch.overrides
 from torchvision.prototype import features
@@ -10,11 +9,7 @@
 F = TypeVar("F", bound=features.Feature)
 
 
-def dispatch(
-    kernels: Dict[Any, Callable[..., Union[torch.Tensor, F]]],
-    *,
-    pil_kernel: Optional[Callable] = None,
-) -> Callable[[Callable[..., F]], Callable[..., F]]:
+def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]:
     """Decorates a function to automatically dispatch to ``kernels`` based on the call arguments.
 
     The function body of the dispatcher can be empty as it is never called. The signature and the docstring however are
@@ -59,6 +54,9 @@ def dispatch_fn(input: T, *args: Any, **kwargs: Any) -> T:
     """
 
     def check_kernel(kernel: Any) -> bool:
+        if kernel is None:
+            return True
+
         if not callable(kernel):
             return False
 
@@ -69,46 +67,38 @@ def check_kernel(kernel: Any) -> bool:
         return params[0].kind != inspect.Parameter.KEYWORD_ONLY
 
     for feature_type, kernel in kernels.items():
-        if not (issubclass(feature_type, features.Feature) and feature_type is not features.Feature):
-            raise TypeError(
-                "Can only register kernels for strict subclasses of `torchvision.prototype.features.Feature`."
-            )
-
         if not check_kernel(kernel):
             raise TypeError(
                 f"Kernel for feature type {feature_type.__name__} is not callable with kernel(input, *args, **kwargs)."
             )
 
-    if pil_kernel and features.Image not in kernels:
-        raise TypeError("PIL kernel can only be registered for images")
-
     def outer_wrapper(dispatch_fn: Callable[..., F]) -> Callable[..., F]:
         @functools.wraps(dispatch_fn)
         def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F:
             feature_type = type(input)
-
-            if issubclass(feature_type, PIL.Image.Image):
-                if pil_kernel is None:
-                    raise TypeError("No PIL kernel")
-
-                # TODO: maybe warn or fail here if we have decided on the scope of BC and deprecations
-                return pil_kernel(input, *args, **kwargs)  # type: ignore[no-any-return]
-
-            if not issubclass(feature_type, torch.Tensor):
-                raise TypeError("No tensor")
-
-            if not issubclass(feature_type, features.Feature):
-                # TODO: maybe warn or fail here if we have decided on the scope of BC and deprecations
-                input = features.Image(input)
-
             try:
                 kernel = kernels[feature_type]
             except KeyError:
-                raise TypeError(f"No support for {feature_type.__name__}") from None
-
-            output = kernel(input, *args, **kwargs)
-
-            if not isinstance(output, feature_type):
+                try:
+                    feature_type, kernel = next(
+                        (feature_type, kernel)
+                        for feature_type, kernel in kernels.items()
+                        if isinstance(input, feature_type)
+                    )
+                except StopIteration:
+                    raise TypeError(f"No support for {type(input).__name__}") from None
+
+            if kernel is None:
+                output = dispatch_fn(input, *args, **kwargs)
+                if output is None:
+                    raise RuntimeError(
+                        f"dispatch_fn() did not handle inputs of type {type(input).__name__} "
+                        f"although it was configured to do so."
+                    )
+            else:
+                output = kernel(input, *args, **kwargs)
+
+            if issubclass(feature_type, features.Feature) and type(output) is torch.Tensor:
                 output = feature_type.new_like(input, output)
 
             return output

From 886552ce909e1597198dda8419ecbd3ffdd527e3 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 10 Feb 2022 18:13:54 +0100
Subject: [PATCH 31/32] address comments

---
 .../prototype/transforms/functional/_geometry.py      | 10 ++--------
 torchvision/prototype/transforms/functional/_utils.py |  2 +-
 2 files changed, 3 insertions(+), 9 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index fa0cc993525..2f9f0f76e39 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -25,10 +25,7 @@ def horizontal_flip(input: T, *args: Any, **kwargs: Any) -> T:
         output = K.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size)
         return cast(T, features.BoundingBox.new_like(input, output))
 
-    raise RuntimeError(
-        f"horizontal_flip() did not handle inputs of type {type(input).__name__} "
-        f"although it was configured to do so."
-    )
+    raise RuntimeError
 
 
 @dispatch(
@@ -47,10 +44,7 @@ def resize(input: T, *args: Any, **kwargs: Any) -> T:
         output = K.resize_bounding_box(input, old_image_size=list(input.image_size), new_image_size=size)
         return cast(T, features.BoundingBox.new_like(input, output, image_size=size))
 
-    raise RuntimeError(
-        f"resize() did not handle inputs of type {type(input).__name__} "
-        f"although it was configured to do so."
-    )
+    raise RuntimeError
 
 
 @dispatch(
diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py
index 81f75c36bb5..86490c1df15 100644
--- a/torchvision/prototype/transforms/functional/_utils.py
+++ b/torchvision/prototype/transforms/functional/_utils.py
@@ -92,7 +92,7 @@ def inner_wrapper(input: F, *args: Any, **kwargs: Any) -> F:
             output = dispatch_fn(input, *args, **kwargs)
             if output is None:
                 raise RuntimeError(
-                    f"dispatch_fn() did not handle inputs of type {type(input).__name__} "
+                    f"{dispatch_fn.__name__}() did not handle inputs of type {type(input).__name__} "
                     f"although it was configured to do so."
                 )
             else:

From c7785b0bad6135c0de0310fa04d3ae6012f19b91 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 10 Feb 2022 18:30:44 +0100
Subject: [PATCH 32/32] update docs

---
 .../prototype/transforms/functional/_utils.py | 37 +++++--------------
 1 file changed, 9 insertions(+), 28 deletions(-)

diff --git a/torchvision/prototype/transforms/functional/_utils.py b/torchvision/prototype/transforms/functional/_utils.py
index 86490c1df15..eb44b3421bf 100644
--- a/torchvision/prototype/transforms/functional/_utils.py
+++ b/torchvision/prototype/transforms/functional/_utils.py
@@ -10,46 +10,27 @@
 
 
 def dispatch(kernels: Dict[Any, Optional[Callable]]) -> Callable[[Callable[..., F]], Callable[..., F]]:
-    """Decorates a function to automatically dispatch to ``kernels`` based on the call arguments.
-
-    The function body of the dispatcher can be empty as it is never called. The signature and the docstring however are
-    used in the documentation and thus should be accurate.
+    """Decorates a function to automatically dispatch to registered kernels based on the call arguments.
 
     The dispatch function should have this signature
 
     .. code:: python
 
-        from typing import Any, TypeVar
-
-        from torchvision.protoype import features
-
-        T = TypeVar("T", bound=features.Feature)
-
-        @dispatch
-        def dispatch_fn(input: T, *args: Any, **kwargs: Any) -> T:
+        @dispatch(
+            ...
+        )
+        def dispatch_fn(input, *args, **kwargs):
             ...
 
-    where ``input`` is a strict subclass of :class:`~torchvision.prototype.features.Feature` and is used to determine
-    which kernel to dispatch to.
-
-    .. note::
-
-        For backward compatibility, ``input`` can also be a ``PIL`` image in which case the call will be dispatched to
-        ``pil_kernel`` if available. Furthermore, ``input`` can also be a vanilla :class:`~torch.Tensor` in which case
-        it will be converted into a :class:`~torchvision.prototype.features.Image`.
+    where ``input`` is used to determine which kernel to dispatch to.
 
     Args:
-        kernels: Dictionary of subclasses of :class:`~torchvision.prototype.features.Feature` that maps to a kernel
-            to call for this feature type.
-        pil_kernel: Optional kernel for ``PIL`` images.
+        kernels: Dictionary that maps types to the kernel to call for them. The resolution order checks for an
+            exact type match first and, if none is found, falls back to checking for subclasses. If a value is
+            ``None``, the decorated function itself is called.
 
     Raises:
-        TypeError: If any key in ``kernels`` is not a strict subclass of
-            :class:`~torchvision.prototype.features.Feature`.
         TypeError: If any value in ``kernels`` is not callable with ``kernel(input, *args, **kwargs)``.
-        TypeError: If ``pil_kernel`` is specified, but no kernel for :class:`~torchvision.prototype.features.Image` is
-            available.
-        TypeError: If the decorated function is called with neither a ``PIL`` image nor a :class:`~torch.Tensor`.
        TypeError: If the decorated function is called with an input that cannot be dispatched.
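
For illustration, a minimal sketch of how the dispatchers defined in these patches behave at call
time. It is not part of the patch series: the ``features.Image`` and ``features.BoundingBox``
constructor calls below are assumptions about the prototype API made only for this example, the
shapes and values are arbitrary, and ``horizontal_flip`` is assumed to be re-exported from
``torchvision.prototype.transforms.functional``.

.. code:: python

    import PIL.Image
    import torch

    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # Plain tensor: exact type match on the torch.Tensor entry, so the call is
    # forwarded to the legacy torchvision.transforms.functional.hflip kernel and
    # a plain tensor comes back.
    flipped_tensor = F.horizontal_flip(torch.rand(3, 32, 32))

    # Image feature: exact type match on features.Image dispatches to
    # kernels.horizontal_flip_image; a plain tensor result is wrapped back into
    # features.Image via new_like().
    image = features.Image(torch.rand(3, 32, 32))  # assumed constructor
    flipped_image = F.horizontal_flip(image)
    assert isinstance(flipped_image, features.Image)

    # Bounding box: registered with a None kernel, so the decorated
    # horizontal_flip body handles the input itself and forwards the format and
    # image_size metadata to kernels.horizontal_flip_bounding_box.
    bounding_box = features.BoundingBox(  # assumed constructor
        torch.tensor([0.0, 0.0, 10.0, 10.0]),
        format=features.BoundingBoxFormat.XYXY,
        image_size=(32, 32),
    )
    flipped_box = F.horizontal_flip(bounding_box)

    # PIL image: handled by the PIL.Image.Image entry; for subclasses such as
    # the ones returned by PIL.Image.open(), the isinstance() fallback applies.
    pil_image = PIL.Image.new("RGB", (32, 32))
    flipped_pil = F.horizontal_flip(pil_image)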