From 56ae0abc52861f7741ff0f95243567a0aedf9167 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 10:54:39 +0200 Subject: [PATCH 01/11] move passthrough for unknown types from dispatchers to transforms --- test/test_transforms_v2_refactored.py | 56 ++++------- .../transforms/v2/functional/_augment.py | 4 +- .../transforms/v2/functional/_color.py | 14 +-- .../transforms/v2/functional/_geometry.py | 27 +++--- torchvision/transforms/v2/functional/_meta.py | 5 +- torchvision/transforms/v2/functional/_misc.py | 6 +- .../transforms/v2/functional/_temporal.py | 6 +- .../transforms/v2/functional/_utils.py | 94 +++++++------------ 8 files changed, 74 insertions(+), 138 deletions(-) diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py index c910882f9fd..53b21c33e51 100644 --- a/test/test_transforms_v2_refactored.py +++ b/test/test_transforms_v2_refactored.py @@ -39,7 +39,7 @@ from torchvision.transforms._functional_tensor import _max_value as get_max_value from torchvision.transforms.functional import pil_modes_mapping from torchvision.transforms.v2 import functional as F -from torchvision.transforms.v2.functional._utils import _get_kernel, _KERNEL_REGISTRY, _noop, _register_kernel_internal +from torchvision.transforms.v2.functional._utils import _get_kernel, _noop, _register_kernel_internal @pytest.fixture(autouse=True) @@ -384,35 +384,6 @@ def transform(bbox): return torch.stack([transform(b) for b in bounding_boxes.reshape(-1, 4).unbind()]).reshape(bounding_boxes.shape) -@pytest.mark.parametrize( - ("dispatcher", "registered_input_types"), - [(dispatcher, set(registry.keys())) for dispatcher, registry in _KERNEL_REGISTRY.items()], -) -def test_exhaustive_kernel_registration(dispatcher, registered_input_types): - missing = { - torch.Tensor, - PIL.Image.Image, - datapoints.Image, - datapoints.BoundingBoxes, - datapoints.Mask, - datapoints.Video, - } - registered_input_types - if missing: - names = sorted(str(t) for t in missing) - raise AssertionError( - "\n".join( - [ - f"The dispatcher '{dispatcher.__name__}' has no kernel registered for", - "", - *[f"- {name}" for name in names], - "", - f"If available, register the kernels with @_register_kernel_internal({dispatcher.__name__}, ...).", - f"If not, register explicit no-ops with @_register_explicit_noop({', '.join(names)})", - ] - ) - ) - - class TestResize: INPUT_SIZE = (17, 11) OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)] @@ -2188,9 +2159,20 @@ def test_errors(self): with pytest.raises(ValueError, match="Kernels can only be registered for subclasses"): F.register_kernel(F.resize, object) - with pytest.raises(ValueError, match="already has a kernel registered for type"): + with pytest.raises(ValueError, match="cannot be registered for the builtin datapoint classes"): F.register_kernel(F.resize, datapoints.Image)(F.resize_image_tensor) + class CustomDatapoint(datapoints.Datapoint): + pass + + def resize_custom_datapoint(): + pass + + F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint) + + with pytest.raises(ValueError, match="already has a kernel registered for type"): + F.register_kernel(F.resize, CustomDatapoint)(resize_custom_datapoint) + class TestGetKernel: # We are using F.resize as dispatcher and the kernels below as proxy. 
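The reworked test_errors above pins down the registration contract: kernels cannot be registered for the builtin datapoint classes, and a custom datapoint class can be registered at most once per dispatcher. A minimal sketch of that contract from the user's side (hypothetical kernel body; only the APIs exercised by the test are assumed):

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

class CustomDatapoint(datapoints.Datapoint):
    pass

@F.register_kernel(F.resize, CustomDatapoint)
def resize_custom_datapoint(inpt, size, **kwargs):
    # hypothetical no-frills kernel: return the input unchanged
    return inpt

# F.register_kernel(F.resize, datapoints.Image)(...)   # ValueError: builtin class
# F.register_kernel(F.resize, CustomDatapoint)(...)    # ValueError: already registered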
Any other dispatcher / kernels combination @@ -2212,13 +2194,7 @@ class MyPILImage(PIL.Image.Image): pass for input_type in [str, int, object, MyTensor, MyPILImage]: - with pytest.raises( - TypeError, - match=( - "supports inputs of type torch.Tensor, PIL.Image.Image, " - "and subclasses of torchvision.datapoints.Datapoint" - ), - ): + with pytest.raises(TypeError, match="supports inputs of type"): _get_kernel(F.resize, input_type) def test_exact_match(self): @@ -2271,8 +2247,8 @@ def test_datapoint_subclass(self): class MyDatapoint(datapoints.Datapoint): pass - # Note that this will be an error in the future - assert _get_kernel(F.resize, MyDatapoint) is _noop + with pytest.raises(TypeError, match="supports inputs of type"): + assert _get_kernel(F.resize, MyDatapoint) is _noop def resize_my_datapoint(): pass diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index 89fa254374d..119dd0b4740 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -7,10 +7,10 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal +from ._utils import _get_kernel, _register_kernel_internal, _register_temporary_passthrough_kernels_internal -@_register_explicit_noop(datapoints.Mask, datapoints.BoundingBoxes, warn_passthrough=True) +@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def erase( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], i: int, diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index 71797fd2500..21974304f98 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -10,10 +10,9 @@ from torchvision.utils import _log_api_usage_once from ._misc import _num_value_bits, to_dtype_image_tensor -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal +from ._utils import _get_kernel, _register_kernel_internal -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, datapoints.Video) def rgb_to_grayscale( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]: @@ -70,7 +69,6 @@ def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Te return output if fp else output.to(image1.dtype) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_brightness(inpt: datapoints._InputTypeJIT, brightness_factor: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) @@ -107,7 +105,6 @@ def adjust_brightness_video(video: torch.Tensor, brightness_factor: float) -> to return adjust_brightness_image_tensor(video, brightness_factor=brightness_factor) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_saturation(inpt: datapoints._InputTypeJIT, saturation_factor: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) @@ -146,7 +143,6 @@ def adjust_saturation_video(video: torch.Tensor, saturation_factor: float) -> to return adjust_saturation_image_tensor(video, 
saturation_factor=saturation_factor) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_contrast(inpt: datapoints._InputTypeJIT, contrast_factor: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) @@ -185,7 +181,6 @@ def adjust_contrast_video(video: torch.Tensor, contrast_factor: float) -> torch. return adjust_contrast_image_tensor(video, contrast_factor=contrast_factor) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_sharpness(inpt: datapoints._InputTypeJIT, sharpness_factor: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) @@ -258,7 +253,6 @@ def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torc return adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_hue(inpt: datapoints._InputTypeJIT, hue_factor: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) @@ -370,7 +364,6 @@ def adjust_hue_video(video: torch.Tensor, hue_factor: float) -> torch.Tensor: return adjust_hue_image_tensor(video, hue_factor=hue_factor) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def adjust_gamma(inpt: datapoints._InputTypeJIT, gamma: float, gain: float = 1) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) @@ -410,7 +403,6 @@ def adjust_gamma_video(video: torch.Tensor, gamma: float, gain: float = 1) -> to return adjust_gamma_image_tensor(video, gamma=gamma, gain=gain) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def posterize(inpt: datapoints._InputTypeJIT, bits: int) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return posterize_image_tensor(inpt, bits=bits) @@ -444,7 +436,6 @@ def posterize_video(video: torch.Tensor, bits: int) -> torch.Tensor: return posterize_image_tensor(video, bits=bits) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def solarize(inpt: datapoints._InputTypeJIT, threshold: float) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return solarize_image_tensor(inpt, threshold=threshold) @@ -472,7 +463,6 @@ def solarize_video(video: torch.Tensor, threshold: float) -> torch.Tensor: return solarize_image_tensor(video, threshold=threshold) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def autocontrast(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return autocontrast_image_tensor(inpt) @@ -522,7 +512,6 @@ def autocontrast_video(video: torch.Tensor) -> torch.Tensor: return autocontrast_image_tensor(video) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def equalize(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return equalize_image_tensor(inpt) @@ -612,7 +601,6 @@ def equalize_video(video: torch.Tensor) -> torch.Tensor: return equalize_image_tensor(video) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def invert(inpt: datapoints._InputTypeJIT) -> datapoints._InputTypeJIT: if torch.jit.is_scripting(): return invert_image_tensor(inpt) diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 
bb19def2c93..02dfaadc0fd 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,7 +25,12 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import _get_kernel, _register_explicit_noop, _register_five_ten_crop_kernel, _register_kernel_internal +from ._utils import ( + _get_kernel, + _register_five_ten_crop_kernel_internal, + _register_kernel_internal, + _register_temporary_passthrough_kernels_internal, +) def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -2199,7 +2204,7 @@ def resized_crop_video( ) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) +@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def five_crop( inpt: datapoints._InputTypeJIT, size: List[int] ) -> Tuple[ @@ -2232,8 +2237,8 @@ def _parse_five_crop_size(size: List[int]) -> List[int]: return size -@_register_five_ten_crop_kernel(five_crop, torch.Tensor) -@_register_five_ten_crop_kernel(five_crop, datapoints.Image) +@_register_five_ten_crop_kernel_internal(five_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Image) def five_crop_image_tensor( image: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2252,7 +2257,7 @@ def five_crop_image_tensor( return tl, tr, bl, br, center -@_register_five_ten_crop_kernel(five_crop, PIL.Image.Image) +@_register_five_ten_crop_kernel_internal(five_crop, PIL.Image.Image) def five_crop_image_pil( image: PIL.Image.Image, size: List[int] ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]: @@ -2271,14 +2276,14 @@ def five_crop_image_pil( return tl, tr, bl, br, center -@_register_five_ten_crop_kernel(five_crop, datapoints.Video) +@_register_five_ten_crop_kernel_internal(five_crop, datapoints.Video) def five_crop_video( video: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return five_crop_image_tensor(video, size) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True) +@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def ten_crop( inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT], size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2302,8 +2307,8 @@ def ten_crop( return kernel(inpt, size=size, vertical_flip=vertical_flip) -@_register_five_ten_crop_kernel(ten_crop, torch.Tensor) -@_register_five_ten_crop_kernel(ten_crop, datapoints.Image) +@_register_five_ten_crop_kernel_internal(ten_crop, torch.Tensor) +@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Image) def ten_crop_image_tensor( image: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2330,7 +2335,7 @@ def ten_crop_image_tensor( return non_flipped + flipped -@_register_five_ten_crop_kernel(ten_crop, PIL.Image.Image) +@_register_five_ten_crop_kernel_internal(ten_crop, PIL.Image.Image) def ten_crop_image_pil( image: PIL.Image.Image, size: List[int], vertical_flip: bool = False ) -> Tuple[ @@ -2357,7 +2362,7 @@ def ten_crop_image_pil( return non_flipped + flipped -@_register_five_ten_crop_kernel(ten_crop, datapoints.Video) +@_register_five_ten_crop_kernel_internal(ten_crop, datapoints.Video) def ten_crop_video( video: torch.Tensor, size: List[int], 
vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_meta.py b/torchvision/transforms/v2/functional/_meta.py index fc1aa05f319..0dbb6f41c33 100644 --- a/torchvision/transforms/v2/functional/_meta.py +++ b/torchvision/transforms/v2/functional/_meta.py @@ -8,10 +8,9 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal, _register_unsupported_type, is_simple_tensor +from ._utils import _get_kernel, _register_kernel_internal, is_simple_tensor -@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_dimensions(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> List[int]: if torch.jit.is_scripting(): return get_dimensions_image_tensor(inpt) @@ -44,7 +43,6 @@ def get_dimensions_video(video: torch.Tensor) -> List[int]: return get_dimensions_image_tensor(video) -@_register_unsupported_type(datapoints.BoundingBoxes, datapoints.Mask) def get_num_channels(inpt: Union[datapoints._ImageTypeJIT, datapoints._VideoTypeJIT]) -> int: if torch.jit.is_scripting(): return get_num_channels_image_tensor(inpt) @@ -123,7 +121,6 @@ def get_size_bounding_boxes(bounding_box: datapoints.BoundingBoxes) -> List[int] return list(bounding_box.canvas_size) -@_register_unsupported_type(PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask) def get_num_frames(inpt: datapoints._VideoTypeJIT) -> int: if torch.jit.is_scripting(): return get_num_frames_video(inpt) diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index e3a800ea79c..b40b9737cb5 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -11,11 +11,9 @@ from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal, _register_unsupported_type +from ._utils import _get_kernel, _register_kernel_internal -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) -@_register_unsupported_type(PIL.Image.Image) def normalize( inpt: Union[datapoints._TensorImageTypeJIT, datapoints._TensorVideoTypeJIT], mean: List[float], @@ -73,7 +71,6 @@ def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], in return normalize_image_tensor(video, mean, std, inplace=inplace) -@_register_explicit_noop(datapoints.BoundingBoxes, datapoints.Mask) def gaussian_blur( inpt: datapoints._InputTypeJIT, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> datapoints._InputTypeJIT: @@ -184,7 +181,6 @@ def gaussian_blur_video( return gaussian_blur_image_tensor(video, kernel_size, sigma) -@_register_unsupported_type(PIL.Image.Image) def to_dtype( inpt: datapoints._InputTypeJIT, dtype: torch.dtype = torch.float, scale: bool = False ) -> datapoints._InputTypeJIT: diff --git a/torchvision/transforms/v2/functional/_temporal.py b/torchvision/transforms/v2/functional/_temporal.py index 62d12cb4b4e..ae988b39315 100644 --- a/torchvision/transforms/v2/functional/_temporal.py +++ b/torchvision/transforms/v2/functional/_temporal.py @@ -1,16 +1,12 @@ -import PIL.Image import torch from torchvision import datapoints from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_explicit_noop, _register_kernel_internal +from ._utils import _get_kernel, _register_kernel_internal -@_register_explicit_noop( - PIL.Image.Image, datapoints.Image, datapoints.BoundingBoxes, datapoints.Mask, warn_passthrough=True 
-) def uniform_temporal_subsample(inpt: datapoints._VideoTypeJIT, num_samples: int) -> datapoints._VideoTypeJIT: if torch.jit.is_scripting(): return uniform_temporal_subsample_video(inpt, num_samples=num_samples) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 576a2b99dbf..0a1e2e070a2 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -2,6 +2,8 @@ import warnings from typing import Any, Callable, Dict, Type +import PIL.Image + import torch from torchvision import datapoints @@ -50,6 +52,11 @@ def _name_to_dispatcher(name): ) from None +_BUILTIN_DATAPOINT_TYPES = { + obj for obj in datapoints.__dict__.values() if isinstance(obj, type) and issubclass(obj, datapoints.Datapoint) +} + + def register_kernel(dispatcher, datapoint_cls): """Decorate a kernel to register it for a dispatcher and a (custom) datapoint type. @@ -67,20 +74,25 @@ def register_kernel(dispatcher, datapoint_cls): f"but got {dispatcher}." ) - if not ( - isinstance(datapoint_cls, type) - and issubclass(datapoint_cls, datapoints.Datapoint) - and datapoint_cls is not datapoints.Datapoint - ): + if not (isinstance(datapoint_cls, type) and issubclass(datapoint_cls, datapoints.Datapoint)): raise ValueError( f"Kernels can only be registered for subclasses of torchvision.datapoints.Datapoint, " f"but got {datapoint_cls}." ) + if datapoint_cls in _BUILTIN_DATAPOINT_TYPES: + raise ValueError(f"Kernels cannot be registered for the builtin datapoint classes, but got {datapoint_cls}") + return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) -def _get_kernel(dispatcher, input_type): +def _noop(inpt, *args, __msg__=None, **kwargs): + if __msg__: + warnings.warn(__msg__, UserWarning, stacklevel=2) + return inpt + + +def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: raise ValueError(f"No kernel registered for dispatcher {dispatcher.__name__}.") @@ -101,70 +113,36 @@ def _get_kernel(dispatcher, input_type): elif cls in registry: return registry[cls] - # Note that in the future we are not going to return a noop here, but rather raise the error below + if allow_passthrough: return _noop - raise TypeError( - f"Dispatcher {dispatcher} supports inputs of type torch.Tensor, PIL.Image.Image, " - f"and subclasses of torchvision.datapoints.Datapoint, " - f"but got {input_type} instead." - ) - + supported_datapoint_types = registry.keys() - {torch.Tensor, PIL.Image.Image} + builtin_datapoint_types = supported_datapoint_types & _BUILTIN_DATAPOINT_TYPES + custom_datapoint_types = supported_datapoint_types - builtin_datapoint_types -# Everything below this block is stuff that we need right now, since it looks like we need to release in an intermediate -# stage. See https://github.com/pytorch/vision/pull/7747#issuecomment-1661698450 for details. 
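For orientation, the lookup that _get_kernel performs above resolves kernels along the input type's MRO, so subclasses of registered types inherit their kernels; only when allow_passthrough=True does an unknown type fall back to the no-op. A condensed, self-contained sketch of that resolution order (illustrative only; the real implementation also special-cases exact matches and datapoint wrapping not shown in this hunk):

from typing import Any, Callable, Dict, Type

def lookup(registry: Dict[Type, Callable], input_type: Type, *, allow_passthrough: bool = False) -> Callable:
    # walk the MRO so that e.g. MyImage -> datapoints.Image -> torch.Tensor
    # picks the most specific registered kernel
    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
    if allow_passthrough:
        # the branch transforms opt into: unknown inputs are returned unchanged
        return lambda inpt, *args, **kwargs: inpt
    raise TypeError(f"no kernel registered for {input_type}")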
+ builtin_datapoint_type_names = ", ".join(f"datapoints.{t.__name__}" for t in builtin_datapoint_types) + custom_datapoint_types_names = ", ".join(str(t) for t in custom_datapoint_types) + supported_type_names = "torch.Tensor" + supported_type_names += ", PIL.Image.Image, " if PIL.Image.Image in registry else " " + supported_type_names += f"and subclasses of the builtin {builtin_datapoint_type_names}" + if custom_datapoint_types: + supported_type_names += f", and the custom datapoint types {custom_datapoint_types_names}" -# In the future, the default behavior will be to error on unsupported types in dispatchers. The noop behavior that we -# need for transforms will be handled by _get_kernel rather than actually registering no-ops on the dispatcher. -# Finally, the use case of preventing users from registering kernels for our builtin types will be handled inside -# register_kernel. -def _register_explicit_noop(*datapoints_classes, warn_passthrough=False): - """ - Although this looks redundant with the no-op behavior of _get_kernel, this explicit registration prevents users - from registering kernels for builtin datapoints on builtin dispatchers that rely on the no-op behavior. - - For example, without explicit no-op registration the following would be valid user code: - - .. code:: - from torchvision.transforms.v2 import functional as F + raise TypeError( + f"Dispatcher F.{dispatcher.__name__} supports inputs of type {supported_type_names}, " + f"but got {input_type} instead." + ) - @F.register_kernel(F.adjust_brightness, datapoints.BoundingBox) - def lol(...): - ... - """ +def _register_temporary_passthrough_kernels_internal(*datapoints_classes): def decorator(dispatcher): for cls in datapoints_classes: msg = ( f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " f"This will likely change in the future." ) - _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)( - functools.partial(_noop, __msg__=msg if warn_passthrough else None) - ) - return dispatcher - - return decorator - - -def _noop(inpt, *args, __msg__=None, **kwargs): - if __msg__: - warnings.warn(__msg__, UserWarning, stacklevel=2) - return inpt - - -# TODO: we only need this, since our default behavior in case no kernel is found is passthrough. 
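With _register_temporary_passthrough_kernels_internal in place, dispatchers such as F.erase, F.five_crop, and F.ten_crop keep accepting bounding boxes and masks for now, but warn and return them untouched. A small illustration of the resulting behavior (a sketch, assuming the registrations made in _augment.py above; the argument values are arbitrary since the no-op ignores them):

import torch
from torchvision import datapoints
from torchvision.transforms.v2 import functional as F

mask = datapoints.Mask(torch.zeros(1, 16, 16, dtype=torch.uint8))
out = F.erase(mask, i=0, j=0, h=8, w=8, v=torch.tensor(0.0))
# UserWarning: F.erase is currently passing through inputs of type datapoints.Mask.
# This will likely change in the future.
# `out` is the mask itself, returned unchanged by the registered no-op.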
When we change that -# to error later, this decorator can be removed, since the error will be raised by _get_kernel -def _register_unsupported_type(*input_types): - def kernel(inpt, *args, __dispatcher_name__, **kwargs): - raise TypeError(f"F.{__dispatcher_name__} does not support inputs of type {type(inpt)}.") - - def decorator(dispatcher): - for input_type in input_types: - _register_kernel_internal(dispatcher, input_type, datapoint_wrapper=False)( - functools.partial(kernel, __dispatcher_name__=dispatcher.__name__) - ) + _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(functools.partial(_noop, __msg__=msg)) return dispatcher return decorator @@ -172,7 +150,7 @@ def decorator(dispatcher): # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool -def _register_five_ten_crop_kernel(dispatcher, input_type): +def _register_five_ten_crop_kernel_internal(dispatcher, input_type): registry = _KERNEL_REGISTRY.setdefault(dispatcher, {}) if input_type in registry: raise TypeError(f"Dispatcher '{dispatcher}' already has a kernel registered for type '{input_type}'.") From 044e6d1cbdb358e1d474ed9af7868d67344b58a2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 7 Aug 2023 12:47:25 +0200 Subject: [PATCH 02/11] implement noop behavior for transforms --- torchvision/transforms/v2/_augment.py | 2 +- torchvision/transforms/v2/_color.py | 48 +++++++------------ torchvision/transforms/v2/_geometry.py | 63 ++++++++++++++++--------- torchvision/transforms/v2/_misc.py | 13 ++--- torchvision/transforms/v2/_temporal.py | 2 +- torchvision/transforms/v2/_transform.py | 6 +++ 6 files changed, 70 insertions(+), 64 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index 87a43b118ce..f475ac34c6c 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -131,7 +131,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: - inpt = F.erase(inpt, **params, inplace=self.inplace) + inpt = self._call_or_noop(F.erase, inpt, **params, inplace=self.inplace) return inpt diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 7dd8eeae236..60ecda56d3c 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -24,19 +24,12 @@ class Grayscale(Transform): _v1_transform_cls = _transforms.Grayscale - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, num_output_channels: int = 1): super().__init__() self.num_output_channels = num_output_channels def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) + return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) class RandomGrayscale(_RandomApplyTransform): @@ -55,13 +48,6 @@ class RandomGrayscale(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomGrayscale - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - def __init__(self, p: float = 0.1) -> None: super().__init__(p=p) @@ -70,7 +56,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: 
return dict(num_input_channels=num_input_channels) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) + return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) class ColorJitter(Transform): @@ -167,13 +153,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: hue_factor = params["hue_factor"] for fn_id in params["fn_idx"]: if fn_id == 0 and brightness_factor is not None: - output = F.adjust_brightness(output, brightness_factor=brightness_factor) + output = self._call_or_noop(F.adjust_brightness, output, brightness_factor=brightness_factor) elif fn_id == 1 and contrast_factor is not None: - output = F.adjust_contrast(output, contrast_factor=contrast_factor) + output = self._call_or_noop(F.adjust_contrast, output, contrast_factor=contrast_factor) elif fn_id == 2 and saturation_factor is not None: - output = F.adjust_saturation(output, saturation_factor=saturation_factor) + output = self._call_or_noop(F.adjust_saturation, output, saturation_factor=saturation_factor) elif fn_id == 3 and hue_factor is not None: - output = F.adjust_hue(output, hue_factor=hue_factor) + output = self._call_or_noop(F.adjust_hue, output, hue_factor=hue_factor) return output @@ -260,15 +246,15 @@ def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: if params["brightness_factor"] is not None: - inpt = F.adjust_brightness(inpt, brightness_factor=params["brightness_factor"]) + inpt = self._call_or_noop(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) if params["contrast_factor"] is not None and params["contrast_before"]: - inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) + inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) if params["saturation_factor"] is not None: - inpt = F.adjust_saturation(inpt, saturation_factor=params["saturation_factor"]) + inpt = self._call_or_noop(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"]) if params["hue_factor"] is not None: - inpt = F.adjust_hue(inpt, hue_factor=params["hue_factor"]) + inpt = self._call_or_noop(F.adjust_hue, inpt, hue_factor=params["hue_factor"]) if params["contrast_factor"] is not None and not params["contrast_before"]: - inpt = F.adjust_contrast(inpt, contrast_factor=params["contrast_factor"]) + inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) if params["channel_permutation"] is not None: inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) return inpt @@ -290,7 +276,7 @@ class RandomEqualize(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomEqualize def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.equalize(inpt) + return self._call_or_noop(F.equalize, inpt) class RandomInvert(_RandomApplyTransform): @@ -309,7 +295,7 @@ class RandomInvert(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomInvert def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.invert(inpt) + return self._call_or_noop(F.invert, inpt) class RandomPosterize(_RandomApplyTransform): @@ -334,7 +320,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None: self.bits = bits def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.posterize(inpt, bits=self.bits) + 
return self._call_or_noop(F.posterize, inpt, bits=self.bits) class RandomSolarize(_RandomApplyTransform): @@ -359,7 +345,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None: self.threshold = threshold def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.solarize(inpt, threshold=self.threshold) + return self._call_or_noop(F.solarize, inpt, threshold=self.threshold) class RandomAutocontrast(_RandomApplyTransform): @@ -378,7 +364,7 @@ class RandomAutocontrast(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomAutocontrast def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.autocontrast(inpt) + return self._call_or_noop(F.autocontrast, inpt) class RandomAdjustSharpness(_RandomApplyTransform): @@ -403,4 +389,4 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: self.sharpness_factor = sharpness_factor def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) + return self._call_or_noop(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index e43aa868a34..09f4429a8a5 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -43,7 +43,7 @@ class RandomHorizontalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomHorizontalFlip def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.horizontal_flip(inpt) + return self._call_or_noop(F.horizontal_flip, inpt) class RandomVerticalFlip(_RandomApplyTransform): @@ -63,7 +63,7 @@ class RandomVerticalFlip(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomVerticalFlip def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.vertical_flip(inpt) + return self._call_or_noop(F.vertical_flip, inpt) class Resize(Transform): @@ -151,7 +151,8 @@ def __init__( self.antialias = antialias def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize( + return self._call_or_noop( + F.resize, inpt, self.size, interpolation=self.interpolation, @@ -185,7 +186,7 @@ def __init__(self, size: Union[int, Sequence[int]]): self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.center_crop(inpt, output_size=self.size) + return self._call_or_noop(F.center_crop, inpt, output_size=self.size) class RandomResizedCrop(Transform): @@ -306,8 +307,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resized_crop( - inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + return self._call_or_noop( + F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias ) @@ -360,7 +361,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.five_crop(inpt, self.size) + return self._call_or_noop(F.five_crop, inpt, self.size) def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): @@ -403,7 +404,7 @@ def _check_inputs(self, 
flat_inputs: List[Any]) -> None: raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip) + return self._call_or_noop(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip) class Pad(Transform): @@ -477,7 +478,7 @@ def __init__( def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.pad(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] + return self._call_or_noop(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode) # type: ignore[arg-type] class RandomZoomOut(_RandomApplyTransform): @@ -547,7 +548,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.pad(inpt, **params, fill=fill) + return self._call_or_noop(F.pad, inpt, **params, fill=fill) class RandomRotation(Transform): @@ -613,7 +614,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.rotate( + return self._call_or_noop( + F.rotate, inpt, **params, interpolation=self.interpolation, @@ -735,7 +737,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.affine( + return self._call_or_noop( + F.affine, inpt, **params, interpolation=self.interpolation, @@ -891,10 +894,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) - inpt = F.pad(inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) + inpt = self._call_or_noop(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode) if params["needs_crop"]: - inpt = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + inpt = self._call_or_noop( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) return inpt @@ -975,7 +980,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.perspective( + return self._call_or_noop( + F.perspective, inpt, None, None, @@ -1052,7 +1058,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # if kernel size is even we have to make it odd if kx % 2 == 0: kx += 1 - dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = self._call_or_noop(F.gaussian_blur, dx, [kx, kx], list(self.sigma)) dx = dx * self.alpha[0] / size[0] dy = torch.rand([1, 1] + size) * 2 - 1 @@ -1061,14 +1067,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # if kernel size is even we have to make it odd if ky % 2 == 0: ky += 1 - dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = self._call_or_noop(F.gaussian_blur, dy, [ky, ky], list(self.sigma)) dy = dy * self.alpha[1] / size[1] displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 return dict(displacement=displacement) def _transform(self, inpt: Any, params: 
Dict[str, Any]) -> Any: fill = _get_fill(self._fill, type(inpt)) - return F.elastic( + return self._call_or_noop( + F.elastic, inpt, **params, fill=fill, @@ -1166,7 +1173,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: # check for any valid boxes with centers within the crop area xyxy_bboxes = F.convert_format_bounding_boxes( - bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY + bboxes.as_subclass(torch.Tensor), + bboxes.format, + datapoints.BoundingBoxFormat.XYXY, ) cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2]) cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3]) @@ -1190,7 +1199,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if len(params) < 1: return inpt - output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) + output = self._call_or_noop( + F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"] + ) if isinstance(output, datapoints.BoundingBoxes): # We "mark" the invalid boxes as degenerate, and they can be @@ -1264,7 +1275,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_or_noop( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) class RandomShortestSize(Transform): @@ -1332,7 +1345,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=(new_height, new_width)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_or_noop( + F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias + ) class RandomResize(Transform): @@ -1402,4 +1417,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(size=[size]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.resize(inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias) + return self._call_or_noop( + F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias + ) diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py index a799070ee1e..a66cb5e0865 100644 --- a/torchvision/transforms/v2/_misc.py +++ b/torchvision/transforms/v2/_misc.py @@ -106,7 +106,7 @@ def __init__(self, transformation_matrix: torch.Tensor, mean_vector: torch.Tenso def _check_inputs(self, sample: Any) -> Any: if has_any(sample, PIL.Image.Image): - raise TypeError("LinearTransformation does not work on PIL Images") + raise TypeError(f"{type(self).__name__}() does not support PIL images.") def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: shape = inpt.shape @@ -157,7 +157,6 @@ class Normalize(Transform): """ _v1_transform_cls = _transforms.Normalize - _transformed_types = (datapoints.Image, is_simple_tensor, datapoints.Video) def __init__(self, mean: Sequence[float], std: Sequence[float], inplace: bool = False): super().__init__() @@ -172,7 +171,7 @@ def _check_inputs(self, sample: Any) -> Any: def _transform( self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any] ) -> Any: - return
F.normalize(inpt, mean=self.mean, std=self.std, inplace=self.inplace) + return self._call_or_noop(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace) class GaussianBlur(Transform): @@ -219,7 +218,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(sigma=[sigma, sigma]) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.gaussian_blur(inpt, self.kernel_size, **params) + return self._call_or_noop(F.gaussian_blur, inpt, self.kernel_size, **params) class ToDtype(Transform): @@ -292,7 +291,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: ) return inpt - return F.to_dtype(inpt, dtype=dtype, scale=self.scale) + return self._call_or_noop(F.to_dtype, inpt, dtype=dtype, scale=self.scale) class ConvertImageDtype(Transform): @@ -322,14 +321,12 @@ class ConvertImageDtype(Transform): _v1_transform_cls = _transforms.ConvertImageDtype - _transformed_types = (is_simple_tensor, datapoints.Image) - def __init__(self, dtype: torch.dtype = torch.float32) -> None: super().__init__() self.dtype = dtype def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.to_dtype(inpt, dtype=self.dtype, scale=True) + return self._call_or_noop(F.to_dtype, inpt, dtype=self.dtype, scale=True) class SanitizeBoundingBoxes(Transform): diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index 868314e9e33..3e3d332e024 100644 --- a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -26,4 +26,4 @@ def __init__(self, num_samples: int): self.num_samples = num_samples def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType: - return F.uniform_temporal_subsample(inpt, self.num_samples) + return self._call_or_noop(F.uniform_temporal_subsample, inpt, self.num_samples) diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py index f83ed5d6e11..aa337cdce3d 100644 --- a/torchvision/transforms/v2/_transform.py +++ b/torchvision/transforms/v2/_transform.py @@ -11,6 +11,8 @@ from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor from torchvision.utils import _log_api_usage_once +from .functional._utils import _get_kernel + class Transform(nn.Module): @@ -28,6 +30,10 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None: def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict() + def _call_or_noop(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True) + return kernel(inpt, *args, **kwargs) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: raise NotImplementedError From 38f85c5806129c4b96b7e93b877dba1537c8696a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 09:51:59 +0200 Subject: [PATCH 03/11] simplify error message --- torchvision/transforms/v2/functional/_utils.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 0a1e2e070a2..a1f4720da30 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -2,8 +2,6 @@ import warnings from typing import Any, Callable, Dict, Type -import PIL.Image - import torch from torchvision import datapoints @@ -116,21 +114,8 @@ def _get_kernel(dispatcher, input_type, *, 
allow_passthrough=False): if allow_passthrough: return _noop - supported_datapoint_types = registry.keys() - {torch.Tensor, PIL.Image.Image} - builtin_datapoint_types = supported_datapoint_types & _BUILTIN_DATAPOINT_TYPES - custom_datapoint_types = supported_datapoint_types - builtin_datapoint_types - - builtin_datapoint_type_names = ", ".join(f"datapoints.{t.__name__}" for t in builtin_datapoint_types) - custom_datapoint_types_names = ", ".join(str(t) for t in custom_datapoint_types) - - supported_type_names = "torch.Tensor" - supported_type_names += ", PIL.Image.Image, " if PIL.Image.Image in registry else " " - supported_type_names += f"and subclasses of the builtin {builtin_datapoint_type_names}" - if custom_datapoint_types: - supported_type_names += f", and the custom datapoint types {custom_datapoint_types_names}" - raise TypeError( - f"Dispatcher F.{dispatcher.__name__} supports inputs of type {supported_type_names}, " + f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, " f"but got {input_type} instead." ) From 9f512301bdcc7128c97d5cf95501b2dbf108d155 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 9 Aug 2023 09:55:38 +0200 Subject: [PATCH 04/11] inline kernel call --- torchvision/transforms/v2/_augment.py | 3 +- torchvision/transforms/v2/_color.py | 55 ++++++++++++++------- torchvision/transforms/v2/_geometry.py | 64 ++++++++++++------------- torchvision/transforms/v2/_misc.py | 11 +++-- torchvision/transforms/v2/_temporal.py | 4 +- torchvision/transforms/v2/_transform.py | 6 --- 6 files changed, 82 insertions(+), 61 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index f475ac34c6c..ee067e467d3 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -9,6 +9,7 @@ from torch.utils._pytree import tree_flatten, tree_unflatten from torchvision import datapoints, transforms as _transforms from torchvision.transforms.v2 import functional as F +from torchvision.transforms.v2.functional._utils import _get_kernel from ._transform import _RandomApplyTransform, Transform from ._utils import _parse_labels_getter @@ -131,7 +132,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["v"] is not None: - inpt = self._call_or_noop(F.erase, inpt, **params, inplace=self.inplace) + inpt = _get_kernel(F.erase, type(inpt), allow_passthrough=True)(inpt, **params, inplace=self.inplace) return inpt diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 60ecda56d3c..6b6a89ec666 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -5,6 +5,7 @@ import torch from torchvision import datapoints, transforms as _transforms from torchvision.transforms.v2 import functional as F, Transform +from torchvision.transforms.v2.functional._utils import _get_kernel from ._transform import _RandomApplyTransform from .utils import is_simple_tensor, query_chw @@ -29,7 +30,9 @@ def __init__(self, num_output_channels: int = 1): self.num_output_channels = num_output_channels def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels) + return _get_kernel(F.rgb_to_grayscale, type(inpt), allow_passthrough=True)( + inpt, num_output_channels=self.num_output_channels + ) class RandomGrayscale(_RandomApplyTransform): @@ -56,7 +59,9 @@ def 
_get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(num_input_channels=num_input_channels) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"]) + return _get_kernel(F.rgb_to_grayscale, type(inpt), allow_passthrough=True)( + inpt, num_output_channels=params["num_input_channels"] + ) class ColorJitter(Transform): @@ -153,13 +158,19 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: hue_factor = params["hue_factor"] for fn_id in params["fn_idx"]: if fn_id == 0 and brightness_factor is not None: - output = self._call_or_noop(F.adjust_brightness, output, brightness_factor=brightness_factor) + output = _get_kernel(F.adjust_brightness, type(output), allow_passthrough=True)( + output, brightness_factor=brightness_factor + ) elif fn_id == 1 and contrast_factor is not None: - output = self._call_or_noop(F.adjust_contrast, output, contrast_factor=contrast_factor) + output = _get_kernel(F.adjust_contrast, type(output), allow_passthrough=True)( + output, contrast_factor=contrast_factor + ) elif fn_id == 2 and saturation_factor is not None: - output = self._call_or_noop(F.adjust_saturation, output, saturation_factor=saturation_factor) + output = _get_kernel(F.adjust_saturation, type(output), allow_passthrough=True)( + output, saturation_factor=saturation_factor + ) elif fn_id == 3 and hue_factor is not None: - output = self._call_or_noop(F.adjust_hue, output, hue_factor=hue_factor) + output = _get_kernel(F.adjust_hue, type(output), allow_passthrough=True)(output, hue_factor=hue_factor) return output @@ -246,15 +257,23 @@ def _transform( self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any] ) -> Union[datapoints._ImageType, datapoints._VideoType]: if params["brightness_factor"] is not None: - inpt = self._call_or_noop(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"]) + inpt = _get_kernel(F.adjust_brightness, type(inpt), allow_passthrough=True)( + inpt, brightness_factor=params["brightness_factor"] + ) if params["contrast_factor"] is not None and params["contrast_before"]: - inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) + inpt = _get_kernel(F.adjust_contrast, type(inpt), allow_passthrough=True)( + inpt, contrast_factor=params["contrast_factor"] + ) if params["saturation_factor"] is not None: - inpt = self._call_or_noop(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"]) + inpt = _get_kernel(F.adjust_saturation, type(inpt), allow_passthrough=True)( + inpt, saturation_factor=params["saturation_factor"] + ) if params["hue_factor"] is not None: - inpt = self._call_or_noop(F.adjust_hue, inpt, hue_factor=params["hue_factor"]) + inpt = _get_kernel(F.adjust_hue, type(inpt), allow_passthrough=True)(inpt, hue_factor=params["hue_factor"]) if params["contrast_factor"] is not None and not params["contrast_before"]: - inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"]) + inpt = _get_kernel(F.adjust_contrast, type(inpt), allow_passthrough=True)( + inpt, contrast_factor=params["contrast_factor"] + ) if params["channel_permutation"] is not None: inpt = self._permute_channels(inpt, permutation=params["channel_permutation"]) return inpt @@ -276,7 +295,7 @@ class RandomEqualize(_RandomApplyTransform): _v1_transform_cls = _transforms.RandomEqualize def _transform(self, inpt: Any, params: Dict[str, Any]) -> 
Any:
-        return self._call_or_noop(F.equalize, inpt)
+        return _get_kernel(F.equalize, type(inpt), allow_passthrough=True)(inpt)


 class RandomInvert(_RandomApplyTransform):
@@ -295,7 +314,7 @@ class RandomInvert(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomInvert

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.invert, inpt)
+        return _get_kernel(F.invert, type(inpt), allow_passthrough=True)(inpt)


 class RandomPosterize(_RandomApplyTransform):
@@ -320,7 +339,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None:
         self.bits = bits

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.posterize, inpt, bits=self.bits)
+        return _get_kernel(F.posterize, type(inpt), allow_passthrough=True)(inpt, bits=self.bits)


 class RandomSolarize(_RandomApplyTransform):
@@ -345,7 +364,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None:
         self.threshold = threshold

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.solarize, inpt, threshold=self.threshold)
+        return _get_kernel(F.solarize, type(inpt), allow_passthrough=True)(inpt, threshold=self.threshold)


 class RandomAutocontrast(_RandomApplyTransform):
@@ -364,7 +383,7 @@ class RandomAutocontrast(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomAutocontrast

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.autocontrast, inpt)
+        return _get_kernel(F.autocontrast, type(inpt), allow_passthrough=True)(inpt)


 class RandomAdjustSharpness(_RandomApplyTransform):
@@ -389,4 +408,6 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         self.sharpness_factor = sharpness_factor

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
+        return _get_kernel(F.adjust_sharpness, type(inpt), allow_passthrough=True)(
+            inpt, sharpness_factor=self.sharpness_factor
+        )
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 9cf5dadf64c..78108aa69d7 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -11,6 +11,7 @@
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
+from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._transform import _RandomApplyTransform
 from ._utils import (
@@ -43,7 +44,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomHorizontalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.horizontal_flip, inpt)
+        return _get_kernel(F.horizontal_flip, type(inpt), allow_passthrough=True)(inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
@@ -63,7 +64,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomVerticalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.vertical_flip, inpt)
+        return _get_kernel(F.vertical_flip, type(inpt), allow_passthrough=True)(inpt)


 class Resize(Transform):
@@ -151,8 +152,7 @@ def __init__(
         self.antialias = antialias

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
-            F.resize,
+        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
             inpt,
             self.size,
             interpolation=self.interpolation,
@@ -186,7 +186,7 @@ def __init__(self, size: Union[int, Sequence[int]]):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.center_crop, inpt, output_size=self.size)
+        return _get_kernel(F.center_crop, type(inpt), allow_passthrough=True)(inpt, output_size=self.size)


 class RandomResizedCrop(Transform):
@@ -307,8 +307,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
-            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
+        return _get_kernel(F.resized_crop, type(inpt), allow_passthrough=True)(
+            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
         )


@@ -361,7 +361,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.five_crop, inpt, self.size)
+        return _get_kernel(F.five_crop, type(inpt), allow_passthrough=True)(inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
@@ -404,7 +404,9 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
+        return _get_kernel(F.ten_crop, type(inpt), allow_passthrough=True)(
+            inpt, self.size, vertical_flip=self.vertical_flip
+        )


 class Pad(Transform):
@@ -478,7 +480,7 @@ def __init__(

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+        return _get_kernel(F.pad, type(inpt), allow_passthrough=True)(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]


 class RandomZoomOut(_RandomApplyTransform):
@@ -548,7 +550,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(F.pad, inpt, **params, fill=fill)
+        return _get_kernel(F.pad, type(inpt), allow_passthrough=True)(inpt, **params, fill=fill)


 class RandomRotation(Transform):
@@ -614,8 +616,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
-            F.rotate,
+        return _get_kernel(F.rotate, type(inpt), allow_passthrough=True)(
             inpt,
             **params,
             interpolation=self.interpolation,
@@ -737,8 +738,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
-            F.affine,
+        return _get_kernel(F.affine, type(inpt), allow_passthrough=True)(
             inpt,
             **params,
             interpolation=self.interpolation,
@@ -894,11 +894,13 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = self._call_or_noop(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = _get_kernel(F.pad, type(inpt), allow_passthrough=True)(
+                inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode
+            )

         if params["needs_crop"]:
-            inpt = self._call_or_noop(
-                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            inpt = _get_kernel(F.crop, type(inpt), allow_passthrough=True)(
+                inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
             )

         return inpt
@@ -980,8 +982,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
-            F.perspective,
+        return _get_kernel(F.perspective, type(inpt), allow_passthrough=True)(
             inpt,
             None,
             None,
@@ -1058,7 +1059,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if kx % 2 == 0:
                 kx += 1
-            dx = self._call_or_noop(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
+            dx = _get_kernel(F.gaussian_blur, type(dx), allow_passthrough=True)(dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / size[0]

         dy = torch.rand([1, 1] + size) * 2 - 1
@@ -1067,15 +1068,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if ky % 2 == 0:
                 ky += 1
-            dy = self._call_or_noop(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
+            dy = _get_kernel(F.gaussian_blur, type(dy), allow_passthrough=True)(dy, [ky, ky], list(self.sigma))
         dy = dy * self.alpha[1] / size[1]
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
-            F.elastic,
+        return _get_kernel(F.elastic, type(inpt), allow_passthrough=True)(
             inpt,
             **params,
             fill=fill,
@@ -1199,8 +1199,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if len(params) < 1:
             return inpt

-        output = self._call_or_noop(
-            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+        output = _get_kernel(F.crop, type(inpt), allow_passthrough=True)(
+            inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
         )

         if isinstance(output, datapoints.BoundingBoxes):
@@ -1275,8 +1275,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
-            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
+            inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )


@@ -1345,8 +1345,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
-            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
+            inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )


@@ -1417,6 +1417,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=[size])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
-            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
+            inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 668e4408a97..3cc93315779 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -8,6 +8,7 @@

 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
+from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
 from .utils import get_bounding_boxes, has_any, is_simple_tensor
@@ -171,7 +172,9 @@ def _check_inputs(self, sample: Any) -> Any:
     def _transform(
         self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> Any:
-        return self._call_or_noop(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return _get_kernel(F.normalize, type(inpt), allow_passthrough=True)(
+            inpt, mean=self.mean, std=self.std, inplace=self.inplace
+        )


 class GaussianBlur(Transform):
@@ -218,7 +221,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.gaussian_blur, inpt, self.kernel_size, **params)
+        return _get_kernel(F.gaussian_blur, type(inpt), allow_passthrough=True)(inpt, self.kernel_size, **params)


 class ToDtype(Transform):
@@ -291,7 +294,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             )
             return inpt

-        return self._call_or_noop(F.to_dtype, inpt, dtype=dtype, scale=self.scale)
+        return _get_kernel(F.to_dtype, type(inpt), allow_passthrough=True)(inpt, dtype=dtype, scale=self.scale)


 class ConvertImageDtype(Transform):
@@ -326,7 +329,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.to_dtype, inpt, dtype=self.dtype, scale=True)
+        return _get_kernel(F.to_dtype, type(inpt), allow_passthrough=True)(inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBoxes(Transform):
diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py
index 3e3d332e024..49ad3deaaeb 100644
--- a/torchvision/transforms/v2/_temporal.py
+++ b/torchvision/transforms/v2/_temporal.py
@@ -4,6 +4,8 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform

+from torchvision.transforms.v2.functional._utils import _get_kernel
+

 class UniformTemporalSubsample(Transform):
     """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video.
@@ -26,4 +28,4 @@ def __init__(self, num_samples: int):
         self.num_samples = num_samples

     def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
-        return self._call_or_noop(F.uniform_temporal_subsample, inpt, self.num_samples)
+        return _get_kernel(F.uniform_temporal_subsample, type(inpt), allow_passthrough=True)(inpt, self.num_samples)
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index aa337cdce3d..f83ed5d6e11 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -11,8 +11,6 @@
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

-from .functional._utils import _get_kernel
-

 class Transform(nn.Module):
@@ -30,10 +28,6 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

-    def _call_or_noop(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
-        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
-        return kernel(inpt, *args, **kwargs)
-
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
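Note on the hunks above: every call site now follows the same resolve-then-call pattern. A minimal sketch of the behavior it relies on (`_get_kernel` is a private helper, not public API, and the `str` input below is just a stand-in for any type without a registered kernel):

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F
    from torchvision.transforms.v2.functional._utils import _get_kernel

    image = datapoints.Image(torch.rand(3, 8, 8))

    # A registered input type resolves to its kernel; calling it is roughly
    # equivalent to calling the dispatcher F.invert directly.
    kernel = _get_kernel(F.invert, type(image), allow_passthrough=True)
    inverted = kernel(image)

    # An unregistered input type resolves to the passthrough instead of
    # raising, and the input comes back unchanged.
    noop = _get_kernel(F.invert, str, allow_passthrough=True)
    assert noop("unsupported input") == "unsupported input"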
From c0c2517db242fa88e51b82da7fba2937b96b9d0f Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 9 Aug 2023 10:08:13 +0200
Subject: [PATCH 05/11] fix test

---
 test/test_transforms_v2_refactored.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_transforms_v2_refactored.py b/test/test_transforms_v2_refactored.py
index 53b21c33e51..156dbc77f06 100644
--- a/test/test_transforms_v2_refactored.py
+++ b/test/test_transforms_v2_refactored.py
@@ -39,7 +39,7 @@
 from torchvision.transforms._functional_tensor import _max_value as get_max_value
 from torchvision.transforms.functional import pil_modes_mapping
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _get_kernel, _noop, _register_kernel_internal
+from torchvision.transforms.v2.functional._utils import _get_kernel, _register_kernel_internal


 @pytest.fixture(autouse=True)
@@ -2248,7 +2248,7 @@ class MyDatapoint(datapoints.Datapoint):
             pass

         with pytest.raises(TypeError, match="supports inputs of type"):
-            assert _get_kernel(F.resize, MyDatapoint) is _noop
+            _get_kernel(F.resize, MyDatapoint)

         def resize_my_datapoint():
             pass
From 99bf83f747a7b7b92d6e1b87bf9a746d8d65142e Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 9 Aug 2023 10:38:15 +0200
Subject: [PATCH 06/11] noop -> passthrough

---
 torchvision/transforms/v2/functional/_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py
index a1f4720da30..6f96197c8c9 100644
--- a/torchvision/transforms/v2/functional/_utils.py
+++ b/torchvision/transforms/v2/functional/_utils.py
@@ -84,7 +84,7 @@ def register_kernel(dispatcher, datapoint_cls):
     return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False)


-def _noop(inpt, *args, __msg__=None, **kwargs):
+def _passthrough(inpt, *args, __msg__=None, **kwargs):
     if __msg__:
         warnings.warn(__msg__, UserWarning, stacklevel=2)
     return inpt
@@ -112,7 +112,7 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False):
             return registry[cls]

     if allow_passthrough:
-        return _noop
+        return _passthrough

     raise TypeError(
         f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, "
@@ -127,7 +127,8 @@ def decorator(dispatcher):
                 f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. "
                 f"This will likely change in the future."
             )
-            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(functools.partial(_noop, __msg__=msg))
+            kernel = functools.partial(_passthrough, __msg__=msg)
+            _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(kernel)
         return dispatcher

     return decorator
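Note: with this rename in place, a dispatcher that had passthrough kernels registered through the decorator above warns once and hands the input back instead of raising. A rough sketch, assuming `F.erase` keeps the temporary passthrough registration for `datapoints.Mask` from earlier in this series:

    import warnings

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F

    mask = datapoints.Mask(torch.zeros(16, 16, dtype=torch.uint8))

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        out = F.erase(mask, i=0, j=0, h=4, w=4, v=torch.zeros(1, 4, 4))

    assert out is mask  # passed through unchanged
    assert "passing through" in str(caught[0].message)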
From 84af424cb34e91abefe8874f7bfc46944779c67d Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 9 Aug 2023 10:57:04 +0200
Subject: [PATCH 07/11] Revert "inline kernel call"

This reverts commit 9f512301bdcc7128c97d5cf95501b2dbf108d155.

---
 torchvision/transforms/v2/_augment.py   |  3 +-
 torchvision/transforms/v2/_color.py     | 55 +++++++--------------
 torchvision/transforms/v2/_geometry.py  | 64 ++++++++++++-------------
 torchvision/transforms/v2/_misc.py      | 11 ++---
 torchvision/transforms/v2/_temporal.py  |  4 +-
 torchvision/transforms/v2/_transform.py |  6 +++
 6 files changed, 61 insertions(+), 82 deletions(-)

diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index ee067e467d3..f475ac34c6c 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -9,7 +9,6 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F
-from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._transform import _RandomApplyTransform, Transform
 from ._utils import _parse_labels_getter
@@ -132,7 +131,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["v"] is not None:
-            inpt = _get_kernel(F.erase, type(inpt), allow_passthrough=True)(inpt, **params, inplace=self.inplace)
+            inpt = self._call_or_noop(F.erase, inpt, **params, inplace=self.inplace)

         return inpt
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 24cafdd4905..6444f5613f4 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -4,7 +4,6 @@
 import torch
 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._transform import _RandomApplyTransform
 from .utils import query_chw
@@ -29,9 +28,7 @@ def __init__(self, num_output_channels: int = 1):
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.rgb_to_grayscale, type(inpt), allow_passthrough=True)(
-            inpt, num_output_channels=self.num_output_channels
-        )
+        return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)


 class RandomGrayscale(_RandomApplyTransform):
@@ -58,9 +55,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(num_input_channels=num_input_channels)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.rgb_to_grayscale, type(inpt), allow_passthrough=True)(
-            inpt, num_output_channels=params["num_input_channels"]
-        )
+        return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])


 class ColorJitter(Transform):
@@ -157,19 +152,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         hue_factor = params["hue_factor"]
         for fn_id in params["fn_idx"]:
             if fn_id == 0 and brightness_factor is not None:
-                output = _get_kernel(F.adjust_brightness, type(output), allow_passthrough=True)(
-                    output, brightness_factor=brightness_factor
-                )
+                output = self._call_or_noop(F.adjust_brightness, output, brightness_factor=brightness_factor)
             elif fn_id == 1 and contrast_factor is not None:
-                output = _get_kernel(F.adjust_contrast, type(output), allow_passthrough=True)(
-                    output, contrast_factor=contrast_factor
-                )
+                output = self._call_or_noop(F.adjust_contrast, output, contrast_factor=contrast_factor)
             elif fn_id == 2 and saturation_factor is not None:
-                output = _get_kernel(F.adjust_saturation, type(output), allow_passthrough=True)(
-                    output, saturation_factor=saturation_factor
-                )
+                output = self._call_or_noop(F.adjust_saturation, output, saturation_factor=saturation_factor)
             elif fn_id == 3 and hue_factor is not None:
-                output = _get_kernel(F.adjust_hue, type(output), allow_passthrough=True)(output, hue_factor=hue_factor)
+                output = self._call_or_noop(F.adjust_hue, output, hue_factor=hue_factor)

         return output

@@ -247,23 +236,15 @@ def _transform(
         self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         if params["brightness_factor"] is not None:
-            inpt = _get_kernel(F.adjust_brightness, type(inpt), allow_passthrough=True)(
-                inpt, brightness_factor=params["brightness_factor"]
-            )
+            inpt = self._call_or_noop(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
-            inpt = _get_kernel(F.adjust_contrast, type(inpt), allow_passthrough=True)(
-                inpt, contrast_factor=params["contrast_factor"]
-            )
+            inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["saturation_factor"] is not None:
-            inpt = _get_kernel(F.adjust_saturation, type(inpt), allow_passthrough=True)(
-                inpt, saturation_factor=params["saturation_factor"]
-            )
+            inpt = self._call_or_noop(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
         if params["hue_factor"] is not None:
-            inpt = _get_kernel(F.adjust_hue, type(inpt), allow_passthrough=True)(inpt, hue_factor=params["hue_factor"])
+            inpt = self._call_or_noop(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
         if params["contrast_factor"] is not None and not params["contrast_before"]:
-            inpt = _get_kernel(F.adjust_contrast, type(inpt), allow_passthrough=True)(
-                inpt, contrast_factor=params["contrast_factor"]
-            )
+            inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["channel_permutation"] is not None:
             inpt = _get_kernel(F.permute_channels, type(inpt), allow_passthrough=True)(
                 inpt, permutation=params["channel_permutation"]
@@ -287,7 +268,7 @@ class RandomEqualize(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomEqualize

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.equalize, type(inpt), allow_passthrough=True)(inpt)
+        return self._call_or_noop(F.equalize, inpt)


 class RandomInvert(_RandomApplyTransform):
@@ -306,7 +287,7 @@ class RandomInvert(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomInvert

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.invert, type(inpt), allow_passthrough=True)(inpt)
+        return self._call_or_noop(F.invert, inpt)


 class RandomPosterize(_RandomApplyTransform):
@@ -331,7 +312,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None:
         self.bits = bits

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.posterize, type(inpt), allow_passthrough=True)(inpt, bits=self.bits)
+        return self._call_or_noop(F.posterize, inpt, bits=self.bits)


 class RandomSolarize(_RandomApplyTransform):
@@ -356,7 +337,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None:
         self.threshold = threshold

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.solarize, type(inpt), allow_passthrough=True)(inpt, threshold=self.threshold)
+        return self._call_or_noop(F.solarize, inpt, threshold=self.threshold)


 class RandomAutocontrast(_RandomApplyTransform):
@@ -375,7 +356,7 @@ class RandomAutocontrast(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomAutocontrast

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.autocontrast, type(inpt), allow_passthrough=True)(inpt)
+        return self._call_or_noop(F.autocontrast, inpt)


 class RandomAdjustSharpness(_RandomApplyTransform):
@@ -400,6 +381,4 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         self.sharpness_factor = sharpness_factor

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.adjust_sharpness, type(inpt), allow_passthrough=True)(
-            inpt, sharpness_factor=self.sharpness_factor
-        )
+        return self._call_or_noop(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 78108aa69d7..9cf5dadf64c 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -11,7 +11,6 @@
 from torchvision.transforms.functional import _get_perspective_coeffs
 from torchvision.transforms.v2 import functional as F, InterpolationMode, Transform
 from torchvision.transforms.v2.functional._geometry import _check_interpolation
-from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._transform import _RandomApplyTransform
 from ._utils import (
@@ -44,7 +43,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomHorizontalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.horizontal_flip, type(inpt), allow_passthrough=True)(inpt)
+        return self._call_or_noop(F.horizontal_flip, inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
@@ -64,7 +63,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomVerticalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.vertical_flip, type(inpt), allow_passthrough=True)(inpt)
+        return self._call_or_noop(F.vertical_flip, inpt)


 class Resize(Transform):
@@ -152,7 +151,8 @@ def __init__(
         self.antialias = antialias

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
+        return self._call_or_noop(
+            F.resize,
             inpt,
             self.size,
             interpolation=self.interpolation,
@@ -186,7 +186,7 @@ def __init__(self, size: Union[int, Sequence[int]]):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.center_crop, type(inpt), allow_passthrough=True)(inpt, output_size=self.size)
+        return self._call_or_noop(F.center_crop, inpt, output_size=self.size)


 class RandomResizedCrop(Transform):
@@ -307,8 +307,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.resized_crop, type(inpt), allow_passthrough=True)(
-            inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
+        return self._call_or_noop(
+            F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
         )


@@ -361,7 +361,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.five_crop, type(inpt), allow_passthrough=True)(inpt, self.size)
+        return self._call_or_noop(F.five_crop, inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
@@ -404,9 +404,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.ten_crop, type(inpt), allow_passthrough=True)(
-            inpt, self.size, vertical_flip=self.vertical_flip
-        )
+        return self._call_or_noop(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)


 class Pad(Transform):
@@ -480,7 +478,7 @@ def __init__(

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.pad, type(inpt), allow_passthrough=True)(inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+        return self._call_or_noop(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]


 class RandomZoomOut(_RandomApplyTransform):
@@ -550,7 +548,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.pad, type(inpt), allow_passthrough=True)(inpt, **params, fill=fill)
+        return self._call_or_noop(F.pad, inpt, **params, fill=fill)


 class RandomRotation(Transform):
@@ -616,7 +614,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.rotate, type(inpt), allow_passthrough=True)(
+        return self._call_or_noop(
+            F.rotate,
             inpt,
             **params,
             interpolation=self.interpolation,
@@ -738,7 +737,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.affine, type(inpt), allow_passthrough=True)(
+        return self._call_or_noop(
+            F.affine,
             inpt,
             **params,
             interpolation=self.interpolation,
@@ -894,13 +894,11 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = _get_kernel(F.pad, type(inpt), allow_passthrough=True)(
-                inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode
-            )
+            inpt = self._call_or_noop(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

         if params["needs_crop"]:
-            inpt = _get_kernel(F.crop, type(inpt), allow_passthrough=True)(
-                inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+            inpt = self._call_or_noop(
+                F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
             )

         return inpt
@@ -982,7 +980,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.perspective, type(inpt), allow_passthrough=True)(
+        return self._call_or_noop(
+            F.perspective,
             inpt,
             None,
             None,
@@ -1059,7 +1058,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if kx % 2 == 0:
                 kx += 1
-            dx = _get_kernel(F.gaussian_blur, type(dx), allow_passthrough=True)(dx, [kx, kx], list(self.sigma))
+            dx = self._call_or_noop(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / size[0]

         dy = torch.rand([1, 1] + size) * 2 - 1
@@ -1068,14 +1067,15 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if ky % 2 == 0:
                 ky += 1
-            dy = _get_kernel(F.gaussian_blur, type(dy), allow_passthrough=True)(dy, [ky, ky], list(self.sigma))
+            dy = self._call_or_noop(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
         dy = dy * self.alpha[1] / size[1]
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return _get_kernel(F.elastic, type(inpt), allow_passthrough=True)(
+        return self._call_or_noop(
+            F.elastic,
             inpt,
             **params,
             fill=fill,
@@ -1199,8 +1199,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if len(params) < 1:
             return inpt

-        output = _get_kernel(F.crop, type(inpt), allow_passthrough=True)(
-            inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
+        output = self._call_or_noop(
+            F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
         )

         if isinstance(output, datapoints.BoundingBoxes):
@@ -1275,8 +1275,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
-            inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return self._call_or_noop(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )


@@ -1345,8 +1345,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
-            inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return self._call_or_noop(
+            F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )


@@ -1417,6 +1417,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=[size])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.resize, type(inpt), allow_passthrough=True)(
-            inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
+        return self._call_or_noop(
+            F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 3cc93315779..668e4408a97 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -8,7 +8,6 @@

 from torchvision import datapoints, transforms as _transforms
 from torchvision.transforms.v2 import functional as F, Transform
-from torchvision.transforms.v2.functional._utils import _get_kernel

 from ._utils import _parse_labels_getter, _setup_float_or_seq, _setup_size
 from .utils import get_bounding_boxes, has_any, is_simple_tensor
@@ -172,9 +171,7 @@ def _check_inputs(self, sample: Any) -> Any:
     def _transform(
         self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> Any:
-        return _get_kernel(F.normalize, type(inpt), allow_passthrough=True)(
-            inpt, mean=self.mean, std=self.std, inplace=self.inplace
-        )
+        return self._call_or_noop(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)


 class GaussianBlur(Transform):
@@ -221,7 +218,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.gaussian_blur, type(inpt), allow_passthrough=True)(inpt, self.kernel_size, **params)
+        return self._call_or_noop(F.gaussian_blur, inpt, self.kernel_size, **params)


 class ToDtype(Transform):
@@ -294,7 +291,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             )
             return inpt

-        return _get_kernel(F.to_dtype, type(inpt), allow_passthrough=True)(inpt, dtype=dtype, scale=self.scale)
+        return self._call_or_noop(F.to_dtype, inpt, dtype=dtype, scale=self.scale)


 class ConvertImageDtype(Transform):
@@ -329,7 +326,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return _get_kernel(F.to_dtype, type(inpt), allow_passthrough=True)(inpt, dtype=self.dtype, scale=True)
+        return self._call_or_noop(F.to_dtype, inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBoxes(Transform):
diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py
index 49ad3deaaeb..3e3d332e024 100644
--- a/torchvision/transforms/v2/_temporal.py
+++ b/torchvision/transforms/v2/_temporal.py
@@ -4,8 +4,6 @@
 from torchvision import datapoints
 from torchvision.transforms.v2 import functional as F, Transform

-from torchvision.transforms.v2.functional._utils import _get_kernel
-

 class UniformTemporalSubsample(Transform):
     """[BETA] Uniformly subsample ``num_samples`` indices from the temporal dimension of the video.
@@ -28,4 +26,4 @@ def __init__(self, num_samples: int):
         self.num_samples = num_samples

     def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
-        return _get_kernel(F.uniform_temporal_subsample, type(inpt), allow_passthrough=True)(inpt, self.num_samples)
+        return self._call_or_noop(F.uniform_temporal_subsample, inpt, self.num_samples)
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index f83ed5d6e11..aa337cdce3d 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -11,6 +11,8 @@
 from torchvision.transforms.v2.utils import check_type, has_any, is_simple_tensor
 from torchvision.utils import _log_api_usage_once

+from .functional._utils import _get_kernel
+

 class Transform(nn.Module):
@@ -28,6 +30,10 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

+    def _call_or_noop(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+        kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
+        return kernel(inpt, *args, **kwargs)
+
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         raise NotImplementedError
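Note: the revert trades the repeated inline `_get_kernel(..., allow_passthrough=True)(...)` calls for the single `_call_or_noop` helper on `Transform`, so the lookup-and-fallback policy lives in one place. A hypothetical subclass written against the restored helper (illustration only; the class name and the fixed size are invented, and the helper is renamed to `_call_kernel` in the next patch):

    from typing import Any, Dict

    import torch
    from torchvision import datapoints
    from torchvision.transforms.v2 import functional as F, Transform


    class FixedResize(Transform):
        # Resize supported inputs to 16x16; unknown types pass through.
        def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
            return self._call_or_noop(F.resize, inpt, size=[16, 16], antialias=True)


    out = FixedResize()(datapoints.Image(torch.rand(3, 32, 32)))
    assert out.shape[-2:] == (16, 16)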
From afd48df7e14ecc9d6a837a1aad8090ee761deef7 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 9 Aug 2023 10:57:44 +0200
Subject: [PATCH 08/11] _call_or_noop -> _call_kernel

---
 torchvision/transforms/v2/_augment.py   |  2 +-
 torchvision/transforms/v2/_color.py     | 34 ++++++++++----------
 torchvision/transforms/v2/_geometry.py  | 42 ++++++++++++-------------
 torchvision/transforms/v2/_misc.py      |  8 ++---
 torchvision/transforms/v2/_temporal.py  |  2 +-
 torchvision/transforms/v2/_transform.py |  2 +-
 6 files changed, 45 insertions(+), 45 deletions(-)

diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py
index f475ac34c6c..ca424adb6c9 100644
--- a/torchvision/transforms/v2/_augment.py
+++ b/torchvision/transforms/v2/_augment.py
@@ -131,7 +131,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["v"] is not None:
-            inpt = self._call_or_noop(F.erase, inpt, **params, inplace=self.inplace)
+            inpt = self._call_kernel(F.erase, inpt, **params, inplace=self.inplace)

         return inpt
diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py
index 6444f5613f4..463bd87bd94 100644
--- a/torchvision/transforms/v2/_color.py
+++ b/torchvision/transforms/v2/_color.py
@@ -28,7 +28,7 @@ def __init__(self, num_output_channels: int = 1):
         self.num_output_channels = num_output_channels

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=self.num_output_channels)


 class RandomGrayscale(_RandomApplyTransform):
@@ -55,7 +55,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(num_input_channels=num_input_channels)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])
+        return self._call_kernel(F.rgb_to_grayscale, inpt, num_output_channels=params["num_input_channels"])


 class ColorJitter(Transform):
@@ -152,13 +152,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         hue_factor = params["hue_factor"]
         for fn_id in params["fn_idx"]:
             if fn_id == 0 and brightness_factor is not None:
-                output = self._call_or_noop(F.adjust_brightness, output, brightness_factor=brightness_factor)
+                output = self._call_kernel(F.adjust_brightness, output, brightness_factor=brightness_factor)
             elif fn_id == 1 and contrast_factor is not None:
-                output = self._call_or_noop(F.adjust_contrast, output, contrast_factor=contrast_factor)
+                output = self._call_kernel(F.adjust_contrast, output, contrast_factor=contrast_factor)
             elif fn_id == 2 and saturation_factor is not None:
-                output = self._call_or_noop(F.adjust_saturation, output, saturation_factor=saturation_factor)
+                output = self._call_kernel(F.adjust_saturation, output, saturation_factor=saturation_factor)
             elif fn_id == 3 and hue_factor is not None:
-                output = self._call_or_noop(F.adjust_hue, output, hue_factor=hue_factor)
+                output = self._call_kernel(F.adjust_hue, output, hue_factor=hue_factor)

         return output

@@ -236,15 +236,15 @@ def _transform(
         self, inpt: Union[datapoints._ImageType, datapoints._VideoType], params: Dict[str, Any]
     ) -> Union[datapoints._ImageType, datapoints._VideoType]:
         if params["brightness_factor"] is not None:
-            inpt = self._call_or_noop(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
+            inpt = self._call_kernel(F.adjust_brightness, inpt, brightness_factor=params["brightness_factor"])
         if params["contrast_factor"] is not None and params["contrast_before"]:
-            inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["saturation_factor"] is not None:
-            inpt = self._call_or_noop(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
+            inpt = self._call_kernel(F.adjust_saturation, inpt, saturation_factor=params["saturation_factor"])
         if params["hue_factor"] is not None:
-            inpt = self._call_or_noop(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
+            inpt = self._call_kernel(F.adjust_hue, inpt, hue_factor=params["hue_factor"])
         if params["contrast_factor"] is not None and not params["contrast_before"]:
-            inpt = self._call_or_noop(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
+            inpt = self._call_kernel(F.adjust_contrast, inpt, contrast_factor=params["contrast_factor"])
         if params["channel_permutation"] is not None:
             inpt = _get_kernel(F.permute_channels, type(inpt), allow_passthrough=True)(
                 inpt, permutation=params["channel_permutation"]
@@ -268,7 +268,7 @@ class RandomEqualize(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomEqualize

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.equalize, inpt)
+        return self._call_kernel(F.equalize, inpt)


 class RandomInvert(_RandomApplyTransform):
@@ -287,7 +287,7 @@ class RandomInvert(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomInvert

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.invert, inpt)
+        return self._call_kernel(F.invert, inpt)


 class RandomPosterize(_RandomApplyTransform):
@@ -312,7 +312,7 @@ def __init__(self, bits: int, p: float = 0.5) -> None:
         self.bits = bits

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.posterize, inpt, bits=self.bits)
+        return self._call_kernel(F.posterize, inpt, bits=self.bits)


 class RandomSolarize(_RandomApplyTransform):
@@ -337,7 +337,7 @@ def __init__(self, threshold: float, p: float = 0.5) -> None:
         self.threshold = threshold

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.solarize, inpt, threshold=self.threshold)
+        return self._call_kernel(F.solarize, inpt, threshold=self.threshold)


 class RandomAutocontrast(_RandomApplyTransform):
@@ -356,7 +356,7 @@ class RandomAutocontrast(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomAutocontrast

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.autocontrast, inpt)
+        return self._call_kernel(F.autocontrast, inpt)


 class RandomAdjustSharpness(_RandomApplyTransform):
@@ -381,4 +381,4 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None:
         self.sharpness_factor = sharpness_factor

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
+        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=self.sharpness_factor)
diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py
index 9cf5dadf64c..8806efa6bb5 100644
--- a/torchvision/transforms/v2/_geometry.py
+++ b/torchvision/transforms/v2/_geometry.py
@@ -43,7 +43,7 @@ class RandomHorizontalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomHorizontalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.horizontal_flip, inpt)
+        return self._call_kernel(F.horizontal_flip, inpt)


 class RandomVerticalFlip(_RandomApplyTransform):
@@ -63,7 +63,7 @@ class RandomVerticalFlip(_RandomApplyTransform):
     _v1_transform_cls = _transforms.RandomVerticalFlip

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.vertical_flip, inpt)
+        return self._call_kernel(F.vertical_flip, inpt)


 class Resize(Transform):
@@ -151,7 +151,7 @@ def __init__(
         self.antialias = antialias

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
+        return self._call_kernel(
             F.resize,
             inpt,
             self.size,
@@ -186,7 +186,7 @@ def __init__(self, size: Union[int, Sequence[int]]):
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.center_crop, inpt, output_size=self.size)
+        return self._call_kernel(F.center_crop, inpt, output_size=self.size)


 class RandomResizedCrop(Transform):
@@ -307,7 +307,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(top=i, left=j, height=h, width=w)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
+        return self._call_kernel(
             F.resized_crop, inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias
         )

@@ -361,7 +361,7 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None:
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.five_crop, inpt, self.size)
+        return self._call_kernel(F.five_crop, inpt, self.size)

     def _check_inputs(self, flat_inputs: List[Any]) -> None:
         if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask):
@@ -404,7 +404,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
             raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()")

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)
+        return self._call_kernel(F.ten_crop, inpt, self.size, vertical_flip=self.vertical_flip)


 class Pad(Transform):
@@ -478,7 +478,7 @@ def __init__(

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]
+        return self._call_kernel(F.pad, inpt, padding=self.padding, fill=fill, padding_mode=self.padding_mode)  # type: ignore[arg-type]


 class RandomZoomOut(_RandomApplyTransform):
@@ -548,7 +548,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(F.pad, inpt, **params, fill=fill)
+        return self._call_kernel(F.pad, inpt, **params, fill=fill)


 class RandomRotation(Transform):
@@ -614,7 +614,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
+        return self._call_kernel(
             F.rotate,
             inpt,
             **params,
@@ -737,7 +737,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
+        return self._call_kernel(
             F.affine,
             inpt,
             **params,
@@ -894,10 +894,10 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if params["needs_pad"]:
             fill = _get_fill(self._fill, type(inpt))
-            inpt = self._call_or_noop(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)
+            inpt = self._call_kernel(F.pad, inpt, padding=params["padding"], fill=fill, padding_mode=self.padding_mode)

         if params["needs_crop"]:
-            inpt = self._call_or_noop(
+            inpt = self._call_kernel(
                 F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
             )

@@ -980,7 +980,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
+        return self._call_kernel(
             F.perspective,
             inpt,
             None,
@@ -1058,7 +1058,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if kx % 2 == 0:
                 kx += 1
-            dx = self._call_or_noop(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
+            dx = self._call_kernel(F.gaussian_blur, dx, [kx, kx], list(self.sigma))
         dx = dx * self.alpha[0] / size[0]

         dy = torch.rand([1, 1] + size) * 2 - 1
@@ -1067,14 +1067,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
             # if kernel size is even we have to make it odd
             if ky % 2 == 0:
                 ky += 1
-            dy = self._call_or_noop(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
+            dy = self._call_kernel(F.gaussian_blur, dy, [ky, ky], list(self.sigma))
         dy = dy * self.alpha[1] / size[1]
         displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1])  # 1 x H x W x 2
         return dict(displacement=displacement)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         fill = _get_fill(self._fill, type(inpt))
-        return self._call_or_noop(
+        return self._call_kernel(
             F.elastic,
             inpt,
             **params,
@@ -1199,7 +1199,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         if len(params) < 1:
             return inpt

-        output = self._call_or_noop(
+        output = self._call_kernel(
             F.crop, inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]
         )

@@ -1275,7 +1275,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
+        return self._call_kernel(
             F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )

@@ -1345,7 +1345,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=(new_height, new_width))

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
+        return self._call_kernel(
             F.resize, inpt, size=params["size"], interpolation=self.interpolation, antialias=self.antialias
         )

@@ -1417,6 +1417,6 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(size=[size])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(
+        return self._call_kernel(
             F.resize, inpt, params["size"], interpolation=self.interpolation, antialias=self.antialias
         )
diff --git a/torchvision/transforms/v2/_misc.py b/torchvision/transforms/v2/_misc.py
index 668e4408a97..11ea00935dc 100644
--- a/torchvision/transforms/v2/_misc.py
+++ b/torchvision/transforms/v2/_misc.py
@@ -171,7 +171,7 @@ def _check_inputs(self, sample: Any) -> Any:
     def _transform(
         self, inpt: Union[datapoints._TensorImageType, datapoints._TensorVideoType], params: Dict[str, Any]
     ) -> Any:
-        return self._call_or_noop(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)
+        return self._call_kernel(F.normalize, inpt, mean=self.mean, std=self.std, inplace=self.inplace)


 class GaussianBlur(Transform):
@@ -218,7 +218,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict(sigma=[sigma, sigma])

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.gaussian_blur, inpt, self.kernel_size, **params)
+        return self._call_kernel(F.gaussian_blur, inpt, self.kernel_size, **params)


 class ToDtype(Transform):
@@ -291,7 +291,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             )
             return inpt

-        return self._call_or_noop(F.to_dtype, inpt, dtype=dtype, scale=self.scale)
+        return self._call_kernel(F.to_dtype, inpt, dtype=dtype, scale=self.scale)


 class ConvertImageDtype(Transform):
@@ -326,7 +326,7 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
         self.dtype = dtype

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        return self._call_or_noop(F.to_dtype, inpt, dtype=self.dtype, scale=True)
+        return self._call_kernel(F.to_dtype, inpt, dtype=self.dtype, scale=True)


 class SanitizeBoundingBoxes(Transform):
diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py
index 3e3d332e024..b55305426d2 100644
--- a/torchvision/transforms/v2/_temporal.py
+++ b/torchvision/transforms/v2/_temporal.py
@@ -26,4 +26,4 @@ def __init__(self, num_samples: int):
         self.num_samples = num_samples

     def _transform(self, inpt: datapoints._VideoType, params: Dict[str, Any]) -> datapoints._VideoType:
-        return self._call_or_noop(F.uniform_temporal_subsample, inpt, self.num_samples)
+        return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples)
diff --git a/torchvision/transforms/v2/_transform.py b/torchvision/transforms/v2/_transform.py
index aa337cdce3d..5a310ddbd4c 100644
--- a/torchvision/transforms/v2/_transform.py
+++ b/torchvision/transforms/v2/_transform.py
@@ -30,7 +30,7 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
     def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         return dict()

-    def _call_or_noop(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
+    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
         kernel = _get_kernel(dispatcher, type(inpt), allow_passthrough=True)
         return kernel(inpt, *args, **kwargs)
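Note on the test changes that follow: with `_call_kernel` resolving kernels from a registry keyed on the dispatcher object itself, `mocker.patch`-ing a dispatcher such as `F.pad` hands `_get_kernel` an unregistered `MagicMock`; the patched function is never called and the lookup fails instead. A sketch of the now-broken pattern (hypothetical test, assuming pytest-mock; the exact exception type is an implementation detail):

    import pytest
    import torch
    from torchvision import datapoints
    from torchvision.transforms import v2 as transforms


    def test_mocking_the_dispatcher_no_longer_works(mocker):
        mocker.patch("torchvision.transforms.v2.functional.pad")
        image = datapoints.Image(torch.rand(3, 8, 8))
        with pytest.raises(Exception):  # kernel lookup fails for the mock
            transforms.Pad(1)(image)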
From 06d422f14c9a156ca705144a97005801aa6284db Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Thu, 10 Aug 2023 12:23:13 +0200
Subject: [PATCH 09/11] fix tests

---
 test/test_prototype_transforms.py             |  63 ---
 test/test_transforms_v2.py                    | 362 +-----------------
 test/test_transforms_v2_consistency.py        |  31 +-
 torchvision/prototype/transforms/_geometry.py |   5 +-
 torchvision/transforms/v2/_color.py           |   2 +-
 torchvision/transforms/v2/_temporal.py        |   2 +-
 6 files changed, 15 insertions(+), 450 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index d395c224785..43a7df4f3a2 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,5 +1,3 @@
-import itertools
-
 import re

 import PIL.Image
@@ -19,7 +17,6 @@
 from torchvision.datapoints import BoundingBoxes, BoundingBoxFormat, Image, Mask, Video
 from torchvision.prototype import datapoints, transforms
-from torchvision.transforms.v2._utils import _convert_fill_arg
 from torchvision.transforms.v2.functional import clamp_bounding_boxes, InterpolationMode, pil_to_tensor, to_image_pil
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor
@@ -187,66 +184,6 @@ def test__get_params(self, mocker):
         assert params["needs_pad"]
         assert any(pad > 0 for pad in params["padding"])

-    @pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
-    def test__transform(self, mocker, needs):
-        fill_sentinel = 12
-        padding_mode_sentinel = mocker.MagicMock()
-
-        transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        transform._transformed_types = (mocker.MagicMock,)
-        mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
-
-        needs_crop, needs_pad = needs
-        top_sentinel = mocker.MagicMock()
-        left_sentinel = mocker.MagicMock()
-        height_sentinel = mocker.MagicMock()
-        width_sentinel = mocker.MagicMock()
-        is_valid = mocker.MagicMock() if needs_crop else None
-        padding_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
-            return_value=dict(
-                needs_crop=needs_crop,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-                is_valid=is_valid,
-                padding=padding_sentinel,
-                needs_pad=needs_pad,
-            ),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
-        mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
-        transform(inpt_sentinel)
-
-        if needs_crop:
-            mock_crop.assert_called_once_with(
-                inpt_sentinel,
-                top=top_sentinel,
-                left=left_sentinel,
-                height=height_sentinel,
-                width=width_sentinel,
-            )
-        else:
-            mock_crop.assert_not_called()
-
-        if needs_pad:
-            # If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
-            # `MagicMock.assert_called_once_with` and have to perform the checks manually
-            mock_pad.assert_called_once()
-            args, kwargs = mock_pad.call_args
-            if not needs_crop:
-                assert args[0] is inpt_sentinel
-            assert args[1] is padding_sentinel
-            fill_sentinel = _convert_fill_arg(fill_sentinel)
-            assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
-        else:
-            mock_pad.assert_not_called()
-
     def test__transform_culling(self, mocker):
         batch_size = 10
         canvas_size = (10, 10)
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 5f4a9b62898..4db2abe7fc4 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -27,7 +27,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision import datapoints
 from torchvision.ops.boxes import box_iou
-from torchvision.transforms.functional import InterpolationMode, to_pil_image
+from torchvision.transforms.functional import to_pil_image
 from torchvision.transforms.v2 import functional as F
 from torchvision.transforms.v2.utils import check_type, is_simple_tensor, query_chw
@@ -419,46 +419,6 @@ def test_assertions(self):
         with pytest.raises(ValueError, match="Padding mode should be either"):
             transforms.Pad(12, padding_mode="abc")

-    @pytest.mark.parametrize("padding", [1, (1, 2), [1, 2, 3, 4]])
-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, fill, padding_mode, mocker):
-        transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        if isinstance(padding, tuple):
-            padding = list(padding)
-        fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.Pad(1, fill=fill, padding_mode="constant")
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-        _ = transform(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, padding=1, fill=fill, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill, padding_mode="constant"),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, padding=1, fill=fill_img, padding_mode="constant"),
-                mocker.call(mask, padding=1, fill=fill_mask, padding_mode="constant"),
-            ]
-        fn.assert_has_calls(calls)
-

 class TestRandomZoomOut:
     def test_assertions(self):
@@ -487,56 +447,6 @@ def test__get_params(self, fill, side_range):
         assert 0 <= params["padding"][2] <= (side_range[1] - 1) * w
         assert 0 <= params["padding"][3] <= (side_range[1] - 1) * h

-    @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
-    @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__transform(self, fill, side_range, mocker):
-        inpt = make_image((24, 32))
-
-        transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill)
-
-    @pytest.mark.parametrize("fill", [12, {datapoints.Image: 12, datapoints.Mask: 34}])
-    def test__transform_image_mask(self, fill, mocker):
-        transform = transforms.RandomZoomOut(fill=fill, p=1.0)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.pad")
-        image = datapoints.Image(torch.rand(3, 32, 32))
-        mask = datapoints.Mask(torch.randint(0, 5, size=(32, 32)))
-        inpt = [image, mask]
-
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params(inpt)
-
-        if isinstance(fill, int):
-            fill = transforms._utils._convert_fill_arg(fill)
-            calls = [
-                mocker.call(image, **params, fill=fill),
-                mocker.call(mask, **params, fill=fill),
-            ]
-        else:
-            fill_img = transforms._utils._convert_fill_arg(fill[type(image)])
-            fill_mask = transforms._utils._convert_fill_arg(fill[type(mask)])
-            calls = [
-                mocker.call(image, **params, fill=fill_img),
-                mocker.call(mask, **params, fill=fill_mask),
-            ]
-        fn.assert_has_calls(calls)
-

 class TestRandomCrop:
     def test_assertions(self):
@@ -599,51 +509,6 @@ def test__get_params(self, padding, pad_if_needed, size):
         assert params["needs_pad"] is any(padding)
         assert params["padding"] == padding

-    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
-    @pytest.mark.parametrize("pad_if_needed", [False, True])
-    @pytest.mark.parametrize("fill", [False, True])
-    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
-    def test__transform(self, padding, pad_if_needed, fill, padding_mode, mocker):
-        output_size = [10, 12]
-        transform = transforms.RandomCrop(
-            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
-        )
-
-        h, w = size = (32, 32)
-        inpt = make_image(size)
-
-        if isinstance(padding, int):
-            new_size = (h + padding, w + padding)
-        elif isinstance(padding, list):
-            new_size = (h + sum(padding[0::2]), w + sum(padding[1::2]))
-        else:
-            new_size = size
-        expected = make_image(new_size)
-        _ = mocker.patch("torchvision.transforms.v2.functional.pad", return_value=expected)
-        fn_crop = mocker.patch("torchvision.transforms.v2.functional.crop")
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-        if padding is None and not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif not pad_if_needed:
-            fn_crop.assert_called_once_with(
-                expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
-            )
-        elif padding is None:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-        else:
-            # vfdev-5: I do not know how to mock and test this case
-            pass
-

 class TestGaussianBlur:
     def test_assertions(self):
@@ -675,62 +540,6 @@ def test__get_params(self, sigma):
         assert sigma[0] <= params["sigma"][0] <= sigma[1]
         assert sigma[0] <= params["sigma"][1] <= sigma[1]

-    @pytest.mark.parametrize("kernel_size", [3, [3, 5], (5, 3)])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 3.0]])
-    def test__transform(self, kernel_size, sigma, mocker):
-        transform = transforms.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
-
-        if isinstance(kernel_size, (tuple, list)):
-            assert transform.kernel_size == kernel_size
-        else:
-            kernel_size = (kernel_size, kernel_size)
-            assert transform.kernel_size == kernel_size
-
-        if isinstance(sigma, (tuple, list)):
-            assert transform.sigma == sigma
-        else:
-            assert transform.sigma == [sigma, sigma]
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.gaussian_blur")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        params = transform._get_params([inpt])
-
-        fn.assert_called_once_with(inpt, kernel_size, **params)
-
-
-class TestRandomColorOp:
-    @pytest.mark.parametrize("p", [0.0, 1.0])
-    @pytest.mark.parametrize(
-        "transform_cls, func_op_name, kwargs",
-        [
-            (transforms.RandomEqualize, "equalize", {}),
-            (transforms.RandomInvert, "invert", {}),
-            (transforms.RandomAutocontrast, "autocontrast", {}),
-            (transforms.RandomPosterize, "posterize", {"bits": 4}),
-            (transforms.RandomSolarize, "solarize", {"threshold": 0.5}),
-            (transforms.RandomAdjustSharpness, "adjust_sharpness", {"sharpness_factor": 0.5}),
-        ],
-    )
-    def test__transform(self, p, transform_cls, func_op_name, kwargs, mocker):
-        transform = transform_cls(p=p, **kwargs)
-
-        fn = mocker.patch(f"torchvision.transforms.v2.functional.{func_op_name}")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        _ = transform(inpt)
-        if p > 0.0:
-            fn.assert_called_once_with(inpt, **kwargs)
-        else:
-            assert fn.call_count == 0
-

 class TestRandomPerspective:
     def test_assertions(self):
@@ -751,28 +560,6 @@ def test__get_params(self):
         assert "coefficients" in params
         assert len(params["coefficients"]) == 8

-    @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
-    def test__transform(self, distortion_scale, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.perspective")
-
-        inpt = make_image((24, 32))
-
-        # vfdev-5, Feature Request: let's store params as Transform attribute
-        # This could be also helpful for users
-        # Otherwise, we can mock transform._get_params
-        torch.manual_seed(12)
-        _ = transform(inpt)
-        torch.manual_seed(12)
-        torch.rand(1)  # random apply changes random state
-        params = transform._get_params([inpt])
-
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, None, None, **params, fill=fill, interpolation=interpolation)
-

 class TestElasticTransform:
     def test_assertions(self):
@@ -813,35 +600,6 @@ def test__get_params(self):
         assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
         assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()

-    @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
-    @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
-    def test__transform(self, alpha, sigma, mocker):
-        interpolation = InterpolationMode.BILINEAR
-        fill = 12
-        transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
-
-        if isinstance(alpha, float):
-            assert transform.alpha == [alpha, alpha]
-        else:
-            assert transform.alpha == alpha
-
-        if isinstance(sigma, float):
-            assert transform.sigma == [sigma, sigma]
-        else:
-            assert transform.sigma == sigma
-
-        fn = mocker.patch("torchvision.transforms.v2.functional.elastic")
-        inpt = mocker.MagicMock(spec=datapoints.Image)
-        inpt.num_channels = 3
-        inpt.canvas_size = (24, 32)
-
-        # Let's mock transform._get_params to control the output:
-        transform._get_params = mocker.MagicMock()
-        _ = transform(inpt)
-        params = transform._get_params([inpt])
-        fill = transforms._utils._convert_fill_arg(fill)
-        fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
-

 class TestRandomErasing:
     def test_assertions(self):
@@ -889,40 +647,6 @@ def test__get_params(self, value):
         assert 0 <= i <= height - h
         assert 0 <= j <= width - w

-    @pytest.mark.parametrize("p", [0, 1])
-    def test__transform(self, mocker, p):
-        transform = transforms.RandomErasing(p=p)
-        transform._transformed_types = (mocker.MagicMock,)
-
-        i_sentinel = mocker.MagicMock()
-        j_sentinel = mocker.MagicMock()
-        h_sentinel = mocker.MagicMock()
-        w_sentinel = mocker.MagicMock()
-        v_sentinel = mocker.MagicMock()
-        mocker.patch(
-            "torchvision.transforms.v2._augment.RandomErasing._get_params",
-            return_value=dict(i=i_sentinel, j=j_sentinel, h=h_sentinel, w=w_sentinel, v=v_sentinel),
-        )
-
-        inpt_sentinel = mocker.MagicMock()
-
-        mock = mocker.patch("torchvision.transforms.v2._augment.F.erase")
-        output = transform(inpt_sentinel)
-
-        if p:
-            mock.assert_called_once_with(
-                inpt_sentinel,
-                i=i_sentinel,
-                j=j_sentinel,
-                h=h_sentinel,
-                w=w_sentinel,
-                v=v_sentinel,
-                inplace=transform.inplace,
-            )
-        else:
-            mock.assert_not_called()
-            assert output is inpt_sentinel
-

 class TestTransform:
     @pytest.mark.parametrize(
@@ -1111,23 +835,12 @@ def test__transform(self, mocker):

         sample = [image, bboxes, masks]

-        fn = mocker.patch("torchvision.transforms.v2.functional.crop", side_effect=lambda x, **params: x)
         is_within_crop_area = torch.tensor([0, 1, 0, 1, 0, 1], dtype=torch.bool)

         params = dict(top=1, left=2, height=12, width=12, is_within_crop_area=is_within_crop_area)
         transform._get_params = mocker.MagicMock(return_value=params)
         output = transform(sample)

-        assert fn.call_count == 3
-
-        expected_calls = [
-            mocker.call(image, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(bboxes, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-            mocker.call(masks, top=params["top"], left=params["left"], height=params["height"], width=params["width"]),
-        ]
-
-        fn.assert_has_calls(expected_calls)
-
         # check number of bboxes vs number of labels:
         output_bboxes = output[1]
         assert isinstance(output_bboxes, datapoints.BoundingBoxes)
@@ -1164,29 +877,6 @@ def test__get_params(self):
             assert int(canvas_size[0] * r_min) <= height <= int(canvas_size[0] * r_max)
             assert int(canvas_size[1] * r_min) <= width <= int(canvas_size[1] * r_max)

-    def test__transform(self, mocker):
-        interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
-        antialias_sentinel = mocker.MagicMock()
-
-        transform = transforms.ScaleJitter(
-            target_size=(16, 12), interpolation=interpolation_sentinel, antialias=antialias_sentinel
-        )
-        transform._transformed_types = (mocker.MagicMock,)
-
-        size_sentinel = mocker.MagicMock()
-        mocker.patch(
"torchvision.transforms.v2._geometry.ScaleJitter._get_params", return_value=dict(size=size_sentinel) - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestRandomShortestSize: @pytest.mark.parametrize("min_size,max_size", [([5, 9], 20), ([5, 9], None)]) @@ -1211,30 +901,6 @@ def test__get_params(self, min_size, max_size): else: assert shorter in min_size - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomShortestSize( - min_size=[3, 5, 7], max_size=12, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.transforms.v2._geometry.RandomShortestSize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock.assert_called_once_with( - inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestLinearTransformation: def test_assertions(self): @@ -1260,7 +926,7 @@ def test__transform(self, inpt): transform = transforms.LinearTransformation(m, v) if isinstance(inpt, PIL.Image.Image): - with pytest.raises(TypeError, match="LinearTransformation does not work on PIL Images"): + with pytest.raises(TypeError, match="does not support PIL images"): transform(inpt) else: output = transform(inpt) @@ -1284,30 +950,6 @@ def test__get_params(self): assert min_size <= size < max_size - def test__transform(self, mocker): - interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode) - antialias_sentinel = mocker.MagicMock() - - transform = transforms.RandomResize( - min_size=-1, max_size=-1, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - transform._transformed_types = (mocker.MagicMock,) - - size_sentinel = mocker.MagicMock() - mocker.patch( - "torchvision.transforms.v2._geometry.RandomResize._get_params", - return_value=dict(size=size_sentinel), - ) - - inpt_sentinel = mocker.MagicMock() - - mock_resize = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - transform(inpt_sentinel) - - mock_resize.assert_called_with( - inpt_sentinel, size_sentinel, interpolation=interpolation_sentinel, antialias=antialias_sentinel - ) - class TestUniformTemporalSubsample: @pytest.mark.parametrize( diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index f5ea69279a1..ef5815c056d 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1259,32 +1259,17 @@ def check(self, t, t_ref, data_kwargs=None): def test_common(self, t_ref, t, data_kwargs): self.check(t, t_ref, data_kwargs) - def check_resize(self, mocker, t_ref, t): - mock = mocker.patch("torchvision.transforms.v2._geometry.F.resize") - mock_ref = mocker.patch("torchvision.transforms.functional.resize") + def check_resize(self, t_ref, t): for dp, dp_ref in self.make_datapoints(): - mock.reset_mock() - mock_ref.reset_mock() - self.set_seed() - t(dp) - assert mock.call_count == 2 - assert all( - actual is expected - for actual, expected in 
zip([call_args[0][0] for call_args in mock.call_args_list], dp) - ) + actual_image, actual_mask = t(dp) self.set_seed() - t_ref(*dp_ref) - assert mock_ref.call_count == 2 - assert all( - actual is expected - for actual, expected in zip([call_args[0][0] for call_args in mock_ref.call_args_list], dp_ref) - ) + expected_image, expected_mask = t_ref(*dp_ref) - for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list): - assert args_kwargs[0][1] == [args_kwargs_ref[0][1]] + assert prototype_F.get_size(actual_image) == prototype_F.get_size(expected_image) + assert prototype_F.get_size(actual_mask) == prototype_F.get_size(expected_mask) def test_random_resize_train(self, mocker): base_size = 520 @@ -1309,9 +1294,9 @@ def patched_randint(a, b, *other_args, **kwargs): t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size) - self.check_resize(mocker, t_ref, t) + self.check_resize(t_ref, t) - def test_random_resize_eval(self, mocker): + def test_random_resize_eval(self): torch.manual_seed(0) base_size = 520 @@ -1319,7 +1304,7 @@ def test_random_resize_eval(self, mocker): t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) - self.check_resize(mocker, t_ref, t) + self.check_resize(t_ref, t) @pytest.mark.parametrize( diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 1a2802db0ac..fe2e8df47eb 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -101,7 +101,8 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_crop"]: - inpt = F.crop( + inpt = self._call_kernel( + F.crop, inpt, top=params["top"], left=params["left"], @@ -120,6 +121,6 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["needs_pad"]: fill = _get_fill(self._fill, type(inpt)) - inpt = F.pad(inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) + inpt = self._call_kernel(F.pad, inpt, params["padding"], fill=fill, padding_mode=self.padding_mode) return inpt diff --git a/torchvision/transforms/v2/_color.py b/torchvision/transforms/v2/_color.py index 26876822347..a3792797959 100644 --- a/torchvision/transforms/v2/_color.py +++ b/torchvision/transforms/v2/_color.py @@ -173,7 +173,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(permutation=torch.randperm(num_channels)) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - self._call_kernel(F.permute_channels, inpt, params["permutation"]) + return self._call_kernel(F.permute_channels, inpt, params["permutation"]) class RandomPhotometricDistort(Transform): diff --git a/torchvision/transforms/v2/_temporal.py b/torchvision/transforms/v2/_temporal.py index 490a1d31062..df39cde0ecd 100644 --- a/torchvision/transforms/v2/_temporal.py +++ b/torchvision/transforms/v2/_temporal.py @@ -25,4 +25,4 @@ def __init__(self, num_samples: int): self.num_samples = num_samples def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples) + return self._call_kernel(F.uniform_temporal_subsample, inpt, self.num_samples) From 84dc2bf0fb56c676bf883d1b3d9f9da717fa34f6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Aug 2023 14:00:49 +0200 Subject: [PATCH 10/11] remove obsolete test --- test/test_transforms_v2_consistency.py | 47 -------------------------- 1 file changed, 
47 deletions(-) diff --git a/test/test_transforms_v2_consistency.py b/test/test_transforms_v2_consistency.py index ef5815c056d..bcab4355c54 100644 --- a/test/test_transforms_v2_consistency.py +++ b/test/test_transforms_v2_consistency.py @@ -1259,53 +1259,6 @@ def check(self, t, t_ref, data_kwargs=None): def test_common(self, t_ref, t, data_kwargs): self.check(t, t_ref, data_kwargs) - def check_resize(self, t_ref, t): - - for dp, dp_ref in self.make_datapoints(): - self.set_seed() - actual_image, actual_mask = t(dp) - - self.set_seed() - expected_image, expected_mask = t_ref(*dp_ref) - - assert prototype_F.get_size(actual_image) == prototype_F.get_size(expected_image) - assert prototype_F.get_size(actual_mask) == prototype_F.get_size(expected_mask) - - def test_random_resize_train(self, mocker): - base_size = 520 - min_size = base_size // 2 - max_size = base_size * 2 - - randint = torch.randint - - def patched_randint(a, b, *other_args, **kwargs): - if kwargs or len(other_args) > 1 or other_args[0] != (): - return randint(a, b, *other_args, **kwargs) - - return random.randint(a, b) - - # We are patching torch.randint -> random.randint here, because we can't patch the modules that are not imported - # normally - t = v2_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True) - mocker.patch( - "torchvision.transforms.v2._geometry.torch.randint", - new=patched_randint, - ) - - t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size) - - self.check_resize(t_ref, t) - - def test_random_resize_eval(self): - torch.manual_seed(0) - base_size = 520 - - t = v2_transforms.Resize(size=base_size, antialias=True) - - t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size) - - self.check_resize(t_ref, t) - @pytest.mark.parametrize( ("legacy_dispatcher", "name_only_params"), From 127e3bdd32988454152cf2ae886ea8bd699b0cc4 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 10 Aug 2023 14:26:52 +0200 Subject: [PATCH 11/11] move passthrough warning to transforms --- torchvision/transforms/v2/_augment.py | 10 +++++++- torchvision/transforms/v2/_geometry.py | 18 ++++++++++++++- .../transforms/v2/functional/_augment.py | 3 +-- .../transforms/v2/functional/_geometry.py | 10 +------- .../transforms/v2/functional/_utils.py | 23 +------------------ 5 files changed, 29 insertions(+), 35 deletions(-) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index ca424adb6c9..9be7a40e8ca 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Dict, List, Tuple +from typing import Any, Callable, Dict, List, Tuple import PIL.Image import torch @@ -91,6 +91,14 @@ def __init__( self._log_ratio = torch.log(torch.tensor(self.ratio)) + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." 
+ ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: img_c, img_h, img_w = query_chw(flat_inputs) diff --git a/torchvision/transforms/v2/_geometry.py b/torchvision/transforms/v2/_geometry.py index d48c7cfaa6c..b209140614e 100644 --- a/torchvision/transforms/v2/_geometry.py +++ b/torchvision/transforms/v2/_geometry.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union +from typing import Any, Callable, cast, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union import PIL.Image import torch @@ -358,6 +358,14 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." + ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return self._call_kernel(F.five_crop, inpt, self.size) @@ -397,6 +405,14 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip + def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any: + if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)): + warnings.warn( + f"{type(self).__name__}() is currently passing through inputs of type " + f"datapoints.{type(inpt).__name__}. This will likely change in the future." 
+ ) + return super()._call_kernel(dispatcher, inpt, *args, **kwargs) + def _check_inputs(self, flat_inputs: List[Any]) -> None: if has_any(flat_inputs, datapoints.BoundingBoxes, datapoints.Mask): raise TypeError(f"BoundingBoxes'es and Mask's are not supported by {type(self).__name__}()") diff --git a/torchvision/transforms/v2/functional/_augment.py b/torchvision/transforms/v2/functional/_augment.py index a0230c45bf7..4a927be9777 100644 --- a/torchvision/transforms/v2/functional/_augment.py +++ b/torchvision/transforms/v2/functional/_augment.py @@ -5,10 +5,9 @@ from torchvision.transforms.functional import pil_to_tensor, to_pil_image from torchvision.utils import _log_api_usage_once -from ._utils import _get_kernel, _register_kernel_internal, _register_temporary_passthrough_kernels_internal +from ._utils import _get_kernel, _register_kernel_internal -@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def erase( inpt: torch.Tensor, i: int, diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index 4f080252a9c..f8f3b1da0b3 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -25,13 +25,7 @@ from ._meta import clamp_bounding_boxes, convert_format_bounding_boxes, get_size_image_pil -from ._utils import ( - _FillTypeJIT, - _get_kernel, - _register_five_ten_crop_kernel_internal, - _register_kernel_internal, - _register_temporary_passthrough_kernels_internal, -) +from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal def _check_interpolation(interpolation: Union[InterpolationMode, int]) -> InterpolationMode: @@ -2203,7 +2197,6 @@ def resized_crop_video( ) -@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def five_crop( inpt: torch.Tensor, size: List[int] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -2276,7 +2269,6 @@ def five_crop_video( return five_crop_image_tensor(video, size) -@_register_temporary_passthrough_kernels_internal(datapoints.BoundingBoxes, datapoints.Mask) def ten_crop( inpt: torch.Tensor, size: List[int], vertical_flip: bool = False ) -> Tuple[ diff --git a/torchvision/transforms/v2/functional/_utils.py b/torchvision/transforms/v2/functional/_utils.py index 8d6b8cc1387..8c95828ee4d 100644 --- a/torchvision/transforms/v2/functional/_utils.py +++ b/torchvision/transforms/v2/functional/_utils.py @@ -1,5 +1,4 @@ import functools -import warnings from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union import torch @@ -87,12 +86,6 @@ def register_kernel(dispatcher, datapoint_cls): return _register_kernel_internal(dispatcher, datapoint_cls, datapoint_wrapper=False) -def _passthrough(inpt, *args, __msg__=None, **kwargs): - if __msg__: - warnings.warn(__msg__, UserWarning, stacklevel=2) - return inpt - - def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): registry = _KERNEL_REGISTRY.get(dispatcher) if not registry: @@ -115,7 +108,7 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): return registry[cls] if allow_passthrough: - return _passthrough + return lambda inpt, *args, **kwargs: inpt raise TypeError( f"Dispatcher F.{dispatcher.__name__} supports inputs of type {registry.keys()}, " @@ -123,20 +116,6 @@ def _get_kernel(dispatcher, input_type, *, allow_passthrough=False): ) -def 
_register_temporary_passthrough_kernels_internal(*datapoints_classes): - def decorator(dispatcher): - for cls in datapoints_classes: - msg = ( - f"F.{dispatcher.__name__} is currently passing through inputs of type datapoints.{cls.__name__}. " - f"This will likely change in the future." - ) - kernel = functools.partial(_passthrough, __msg__=msg) - _register_kernel_internal(dispatcher, cls, datapoint_wrapper=False)(kernel) - return dispatcher - - return decorator - - # This basically replicates _register_kernel_internal, but with a specialized wrapper for five_crop / ten_crop # We could get rid of this by letting _register_kernel_internal take arbitrary dispatchers rather than wrap_kernel: bool def _register_five_ten_crop_kernel_internal(dispatcher, input_type):
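

Note on PATCH 11/11, appended after the series rather than inside any patch: the `_call_kernel` override introduced above is duplicated verbatim in RandomErasing, FiveCrop, and TenCrop. A possible follow-up is to factor the three copies into a shared mixin. The sketch below is a suggestion, not part of the patch; it assumes nothing else overrides `_call_kernel` in those classes, and `_PassthroughWarningMixin` is a hypothetical name, not an identifier that exists in torchvision.

import warnings
from typing import Any, Callable

from torchvision import datapoints


class _PassthroughWarningMixin:
    # Body copied from the overrides added in PATCH 11/11; only the mixin
    # wrapper is new. Must be mixed in *before* Transform so that super()
    # resolves to Transform._call_kernel at call time.
    def _call_kernel(self, dispatcher: Callable, inpt: Any, *args: Any, **kwargs: Any) -> Any:
        if isinstance(inpt, (datapoints.BoundingBoxes, datapoints.Mask)):
            warnings.warn(
                f"{type(self).__name__}() is currently passing through inputs of type "
                f"datapoints.{type(inpt).__name__}. This will likely change in the future."
            )
        return super()._call_kernel(dispatcher, inpt, *args, **kwargs)

The transforms would then opt in via their bases, e.g. `class FiveCrop(_PassthroughWarningMixin, Transform): ...`, keeping the warning message and the passthrough policy in one place.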
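
A second sketch, for the functional-layer counterpart: with `_passthrough` and `_register_temporary_passthrough_kernels_internal` removed, `_get_kernel` either raises a TypeError or, when the dispatcher opts in with `allow_passthrough=True`, returns a bare identity lambda that emits no warning — the warning now lives exclusively in the transforms. `_get_kernel` is a private helper, so this only illustrates the behavior of the code above and is not a supported API.

import torch

from torchvision import datapoints
from torchvision.transforms.v2 import functional as F
from torchvision.transforms.v2.functional._utils import _get_kernel


class MyDatapoint(datapoints.Datapoint):
    pass


# Strict lookup (the default): a datapoint type without a registered kernel
# now raises instead of silently becoming a no-op.
try:
    _get_kernel(F.resize, MyDatapoint)
except TypeError as exc:
    print(exc)  # "Dispatcher F.resize supports inputs of type ..."

# Opt-in passthrough: the returned kernel hands the input back unchanged and,
# unlike the old _passthrough helper, emits no warning at this layer.
kernel = _get_kernel(F.resize, MyDatapoint, allow_passthrough=True)
inpt = torch.rand(3, 4, 4)
assert kernel(inpt, size=[2, 2]) is inpt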