From 25b36671d91445982365b86a867daf765c2ad116 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 09:53:11 +0000 Subject: [PATCH 01/20] Added base tests for rotate_image_tensor --- test/test_prototype_transforms_functional.py | 16 ++++++++++++++++ .../prototype/transforms/functional/_geometry.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index be3932a8b7f..6ec912a4770 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -284,6 +284,22 @@ def affine_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def rotate_image_tensor(): + for image, angle, expand, center, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [True, False], # expand + [None, [12, 23]], # center + [None, [128]], # fill + ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + + yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill) + + @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ac0e8e0eb13..d71706dbb65 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -318,8 +318,8 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[List[float]] = None, center: Optional[List[float]] = None, + fill: Optional[List[float]] = None, ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: From 6b3483ddc90625bd780b4e9eca1c88592a9c63c3 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 11:07:45 +0000 Subject: [PATCH 02/20] Updated resize_image_tensor API and tests and fixed a bug with max_size --- test/test_prototype_transforms_functional.py | 17 +++++++++++------ .../transforms/functional/_geometry.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6ec912a4770..5550c99158a 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -201,19 +201,24 @@ def horizontal_flip_bounding_box(): @register_kernel_info_from_sample_inputs_fn def resize_image_tensor(): - for image, interpolation in itertools.product( + for image, interpolation, max_size, antialias in itertools.product( make_images(), - [ - F.InterpolationMode.BILINEAR, - F.InterpolationMode.NEAREST, - ], + [F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST], # interpolation + [None, 34], # max_size + [False, True], # antialias ): + + if antialias and interpolation == F.InterpolationMode.NEAREST: + continue + height, width = image.shape[-2:] for size in [ (height, width), (int(height * 0.75), int(width * 1.25)), ]: - yield SampleInput(image, size=size, interpolation=interpolation) + if max_size is not None: + size = [size[0]] + yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn diff --git a/torchvision/prototype/transforms/functional/_geometry.py 
b/torchvision/prototype/transforms/functional/_geometry.py index d71706dbb65..9b3e370dde8 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -42,16 +42,17 @@ def resize_image_tensor( max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> torch.Tensor: - new_height, new_width = size num_channels, old_height, old_width = get_dimensions_image_tensor(image) batch_shape = image.shape[:-3] - return _FT.resize( + output = _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias, - ).reshape(batch_shape + (num_channels, new_height, new_width)) + ) + num_channels, new_height, new_width = get_dimensions_image_tensor(output) + return output.reshape(batch_shape + (num_channels, new_height, new_width)) def resize_image_pil( From ea7c513ff69dedface1d5e4c1708ae5b6ebe9fdf Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 15:36:11 +0000 Subject: [PATCH 03/20] Refactored and modified private api for resize functional op --- test/test_transforms_tensor.py | 20 +++------ torchvision/transforms/functional.py | 46 +++++++++++++++++++- torchvision/transforms/functional_pil.py | 34 +-------------- torchvision/transforms/functional_tensor.py | 47 +-------------------- 4 files changed, 52 insertions(+), 95 deletions(-) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index ba2321ec455..f0cd3ba0021 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -394,9 +394,7 @@ def test_resize_int(self, size): @pytest.mark.parametrize( "size", [ - [ - 32, - ], + [32], [32, 32], (32, 32), [34, 35], @@ -412,7 +410,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): # This is a trivial cast to float of uint8 data to test all cases tensor = tensor.to(dt) if max_size is not None and len(size) != 1: - pytest.xfail("with max_size, size must be a sequence with 2 elements") + pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified") transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size) s_transform = torch.jit.script(transform) @@ -420,11 +418,7 @@ def test_resize_scripted(self, dt, size, max_size, interpolation, device): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) def test_resize_save(self, tmpdir): - transform = T.Resize( - size=[ - 32, - ] - ) + transform = T.Resize(size=[32]) s_transform = torch.jit.script(transform) s_transform.save(os.path.join(tmpdir, "t_resize.pt")) @@ -435,12 +429,8 @@ def test_resize_save(self, tmpdir): "size", [ (32,), - [ - 44, - ], - [ - 32, - ], + [44], + [32], [32, 32], (32, 32), [44, 55], diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index c40ae1eb92b..609c64ad4ff 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -360,6 +360,31 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace) +def _compute_output_size( + image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None +) -> Tuple[int, int]: + if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge + h, w = image_size + short, long = (w, h) if w <= h else (h, w) + requested_new_short = size if isinstance(size, int) else size[0] + + new_short, 
new_long = requested_new_short, int(requested_new_short * long / short) + + if max_size is not None: + if max_size <= requested_new_short: + raise ValueError( + f"max_size = {max_size} must be strictly greater than the requested " + f"size for the smaller edge size = {size}" + ) + if new_long > max_size: + new_short, new_long = int(max_size * new_short / new_long), max_size + + new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) + else: # specified both h and w + new_w, new_h = size[1], size[0] + return new_h, new_w + + def resize( img: Tensor, size: List[int], @@ -423,13 +448,30 @@ def resize( if not isinstance(interpolation, InterpolationMode): raise TypeError("Argument interpolation should be a InterpolationMode") + if isinstance(size, (list, tuple)): + if len(size) not in [1, 2]: + raise ValueError( + f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list" + ) + if max_size is not None and len(size) != 1: + raise ValueError( + "max_size should only be passed if size specifies the length of the smaller edge, " + "i.e. size should be an int or a sequence of length 1 in torchscript mode." + ) + + _, image_height, image_width = get_dimensions(img) + output_size = _compute_output_size((image_height, image_width), size, max_size) + + if (image_height, image_width) == output_size: + return img + if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") pil_interpolation = pil_modes_mapping[interpolation] - return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) + return F_pil.resize(img, size=output_size, interpolation=pil_interpolation) - return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias) + return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias) def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 0203ee4495b..3c1a911a5d4 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -242,44 +242,14 @@ def resize( img: Image.Image, size: Union[Sequence[int], int], interpolation: int = _pil_constants.BILINEAR, - max_size: Optional[int] = None, ) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. 
Got {type(img)}") - if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): + if not (isinstance(size, Sequence) and len(size) == 2): raise TypeError(f"Got inappropriate size arg: {size}") - if isinstance(size, Sequence) and len(size) == 1: - size = size[0] - if isinstance(size, int): - w, h = img.size - - short, long = (w, h) if w <= h else (h, w) - new_short, new_long = size, int(size * long / short) - - if max_size is not None: - if max_size <= size: - raise ValueError( - f"max_size = {max_size} must be strictly greater than the requested " - f"size for the smaller edge size = {size}" - ) - if new_long > max_size: - new_short, new_long = int(max_size * new_short / new_long), max_size - - new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) - - if (w, h) == (new_w, new_h): - return img - else: - return img.resize((new_w, new_h), interpolation) - else: - if max_size is not None: - raise ValueError( - "max_size should only be passed if size specifies the length of the smaller edge, " - "i.e. size should be an int or a sequence of length 1 in torchscript mode." - ) - return img.resize(size[::-1], interpolation) + return img.resize(size[::-1], interpolation) @torch.jit.unused diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 1899caebfc3..acc8d3ae3e1 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -430,70 +430,25 @@ def resize( img: Tensor, size: List[int], interpolation: str = "bilinear", - max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> Tensor: _assert_image_tensor(img) - if not isinstance(size, (int, tuple, list)): - raise TypeError("Got inappropriate size arg") - if not isinstance(interpolation, str): - raise TypeError("Got inappropriate interpolation arg") - - if interpolation not in ["nearest", "bilinear", "bicubic"]: - raise ValueError("This interpolation mode is unsupported with Tensor input") - if isinstance(size, tuple): size = list(size) - if isinstance(size, list): - if len(size) not in [1, 2]: - raise ValueError( - f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list" - ) - if max_size is not None and len(size) != 1: - raise ValueError( - "max_size should only be passed if size specifies the length of the smaller edge, " - "i.e. size should be an int or a sequence of length 1 in torchscript mode." 
- ) - if antialias is None: antialias = False if antialias and interpolation not in ["bilinear", "bicubic"]: raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") - _, h, w = get_dimensions(img) - - if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge - short, long = (w, h) if w <= h else (h, w) - requested_new_short = size if isinstance(size, int) else size[0] - - new_short, new_long = requested_new_short, int(requested_new_short * long / short) - - if max_size is not None: - if max_size <= requested_new_short: - raise ValueError( - f"max_size = {max_size} must be strictly greater than the requested " - f"size for the smaller edge size = {size}" - ) - if new_long > max_size: - new_short, new_long = int(max_size * new_short / new_long), max_size - - new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) - - if (w, h) == (new_w, new_h): - return img - - else: # specified both h and w - new_w, new_h = size[1], size[0] - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) # Define align_corners to avoid warnings align_corners = False if interpolation in ["bilinear", "bicubic"] else None - img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias) + img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias) if interpolation == "bicubic" and out_dtype == torch.uint8: img = img.clamp(min=0, max=255) From dc64e8a85cfd1e76647985a6e2aa8e852f91930c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 09:53:11 +0000 Subject: [PATCH 04/20] Added base tests for rotate_image_tensor --- test/test_prototype_transforms_functional.py | 16 ++++++++++++++++ .../prototype/transforms/functional/_geometry.py | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index be3932a8b7f..6ec912a4770 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -284,6 +284,22 @@ def affine_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def rotate_image_tensor(): + for image, angle, expand, center, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [-87, 15, 90], # angle + [True, False], # expand + [None, [12, 23]], # center + [None, [128]], # fill + ): + if center is not None and expand: + # Skip warning: The provided center argument is ignored if expand is True + continue + + yield SampleInput(image, angle=angle, expand=expand, center=center, fill=fill) + + @register_kernel_info_from_sample_inputs_fn def rotate_bounding_box(): for bounding_box, angle, expand, center in itertools.product( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ac0e8e0eb13..d71706dbb65 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -318,8 +318,8 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - fill: Optional[List[float]] = None, center: Optional[List[float]] = None, + fill: Optional[List[float]] = None, ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: From d341bebb9d17353603b94993e61e183b531d7f16 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: 
Wed, 22 Jun 2022 11:07:45 +0000 Subject: [PATCH 05/20] Updated resize_image_tensor API and tests and fixed a bug with max_size --- test/test_prototype_transforms_functional.py | 17 +++++++++++------ .../transforms/functional/_geometry.py | 7 ++++--- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6ec912a4770..5550c99158a 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -201,19 +201,24 @@ def horizontal_flip_bounding_box(): @register_kernel_info_from_sample_inputs_fn def resize_image_tensor(): - for image, interpolation in itertools.product( + for image, interpolation, max_size, antialias in itertools.product( make_images(), - [ - F.InterpolationMode.BILINEAR, - F.InterpolationMode.NEAREST, - ], + [F.InterpolationMode.BILINEAR, F.InterpolationMode.NEAREST], # interpolation + [None, 34], # max_size + [False, True], # antialias ): + + if antialias and interpolation == F.InterpolationMode.NEAREST: + continue + height, width = image.shape[-2:] for size in [ (height, width), (int(height * 0.75), int(width * 1.25)), ]: - yield SampleInput(image, size=size, interpolation=interpolation) + if max_size is not None: + size = [size[0]] + yield SampleInput(image, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d71706dbb65..9b3e370dde8 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -42,16 +42,17 @@ def resize_image_tensor( max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> torch.Tensor: - new_height, new_width = size num_channels, old_height, old_width = get_dimensions_image_tensor(image) batch_shape = image.shape[:-3] - return _FT.resize( + output = _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias, - ).reshape(batch_shape + (num_channels, new_height, new_width)) + ) + num_channels, new_height, new_width = get_dimensions_image_tensor(output) + return output.reshape(batch_shape + (num_channels, new_height, new_width)) def resize_image_pil( From aade78f8a7bf36dbe70aaca7afcd2abf546d3ccb Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 16:07:31 +0000 Subject: [PATCH 06/20] Fixed failures --- torchvision/prototype/transforms/functional/_geometry.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index ac0e8e0eb13..f1d51fded82 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -42,6 +42,8 @@ def resize_image_tensor( max_size: Optional[int] = None, antialias: Optional[bool] = None, ) -> torch.Tensor: + # TODO: use _compute_output_size to enable max_size option + max_size # ununsed right now new_height, new_width = size num_channels, old_height, old_width = get_dimensions_image_tensor(image) batch_shape = image.shape[:-3] @@ -49,7 +51,6 @@ def resize_image_tensor( image.reshape((-1, num_channels, old_height, old_width)), size=size, interpolation=interpolation.value, - max_size=max_size, 
antialias=antialias, ).reshape(batch_shape + (num_channels, new_height, new_width)) @@ -60,7 +61,9 @@ def resize_image_pil( interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, ) -> PIL.Image.Image: - return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size) + # TODO: use _compute_output_size to enable max_size option + max_size # ununsed right now + return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation]) def resize_segmentation_mask( From a812a3bcdcca6ca8d7af79330220f1344ae89aa9 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 20:57:14 +0000 Subject: [PATCH 07/20] More updates --- torchvision/transforms/functional.py | 12 ++++++------ torchvision/transforms/functional_pil.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 609c64ad4ff..77feadc51f1 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional +from typing import List, Tuple, Any, Optional, Union import numpy as np import torch @@ -360,10 +360,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace) -def _compute_output_size( - image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None -) -> Tuple[int, int]: - if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge +def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]: + if len(size) == 1: # specified size only for the smallest edge h, w = image_size short, long = (w, h) if w <= h else (h, w) requested_new_short = size if isinstance(size, int) else size[0] @@ -382,7 +380,7 @@ def _compute_output_size( new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short) else: # specified both h and w new_w, new_h = size[1], size[0] - return new_h, new_w + return [new_h, new_w] def resize( @@ -460,6 +458,8 @@ def resize( ) _, image_height, image_width = get_dimensions(img) + if isinstance(size, int): + size = [size] output_size = _compute_output_size((image_height, image_width), size, max_size) if (image_height, image_width) == output_size: diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 3c1a911a5d4..7ebd9f71588 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -240,13 +240,13 @@ def crop( @torch.jit.unused def resize( img: Image.Image, - size: Union[Sequence[int], int], + size: Union[List[int], int], interpolation: int = _pil_constants.BILINEAR, ) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. 
Got {type(img)}") - if not (isinstance(size, Sequence) and len(size) == 2): + if not (isinstance(size, list) and len(size) == 2): raise TypeError(f"Got inappropriate size arg: {size}") return img.resize(size[::-1], interpolation) From 6661d8d948a0535cc22e30fe84a81ffac8d48661 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 21:01:36 +0000 Subject: [PATCH 08/20] Updated proto functional op: resize_image_* --- .../prototype/transforms/functional/_geometry.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 2aaed3e4a2e..36015b5c25d 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -6,7 +6,12 @@ import torch from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP -from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix, InterpolationMode, _compute_output_size +from torchvision.transforms.functional import ( + pil_modes_mapping, + _get_inverse_affine_matrix, + InterpolationMode, + _compute_output_size, +) from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil @@ -43,7 +48,8 @@ def resize_image_tensor( antialias: Optional[bool] = None, ) -> torch.Tensor: num_channels, old_height, old_width = get_dimensions_image_tensor(image) - new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size) + size = _compute_output_size((old_height, old_width), size=size, max_size=max_size) + new_height, new_width = size batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), @@ -59,6 +65,10 @@ def resize_image_pil( interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, ) -> PIL.Image.Image: + if isinstance(size, int): + size = [size, size] + # Explicitly cast size to list otherwise mypy issue: incompatible type "Sequence[int]"; expected "List[int]" + size: List[int] = list(size) size = _compute_output_size(img.size[::-1], size=size, max_size=max_size) return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation]) From 09728221bbc2a1e59d143e620d5df86033ac677d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 21:03:02 +0000 Subject: [PATCH 09/20] Fixed flake8 --- torchvision/transforms/functional.py | 2 +- torchvision/transforms/functional_pil.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 77feadc51f1..80444c31204 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional, Union +from typing import List, Tuple, Any, Optional import numpy as np import torch diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index 7ebd9f71588..93bdeb8f308 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -1,5 +1,5 @@ import numbers -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch From 
f0c896ff1391dcac098539db79f14e2c50549d7a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 21:27:44 +0000 Subject: [PATCH 10/20] Added max_size arg to resize_bounding_box and updated basic tests --- test/test_prototype_transforms_functional.py | 23 ++++++++++++++++++- .../transforms/functional/_geometry.py | 6 +++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 5550c99158a..30d9b833ec8 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -223,15 +223,36 @@ def resize_image_tensor(): @register_kernel_info_from_sample_inputs_fn def resize_bounding_box(): - for bounding_box in make_bounding_boxes(): + for bounding_box, max_size in itertools.product( + make_bounding_boxes(), + [None, 34], # max_size + ): height, width = bounding_box.image_size for size in [ (height, width), (int(height * 0.75), int(width * 1.25)), ]: + if max_size is not None: + size = [size[0]] yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size) +@register_kernel_info_from_sample_inputs_fn +def resize_segmentation_mask(): + for mask, max_size in itertools.product( + make_segmentation_masks(), + [None, 34], # max_size + ): + height, width = mask.shape[-2:] + for size in [ + (height, width), + (int(height * 0.75), int(width * 1.25)), + ]: + if max_size is not None: + size = [size[0]] + yield SampleInput(mask, size=size, max_size=max_size) + + @register_kernel_info_from_sample_inputs_fn def affine_image_tensor(): for image, angle, translate, scale, shear in itertools.product( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 36015b5c25d..19085d2a974 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -79,9 +79,11 @@ def resize_segmentation_mask( return resize_image_tensor(segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size) -# TODO: handle max_size -def resize_bounding_box(bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int]) -> torch.Tensor: +def resize_bounding_box( + bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None +) -> torch.Tensor: old_height, old_width = image_size + size = _compute_output_size(image_size, size=size, max_size=max_size) new_height, new_width = size ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) From 1a3a7490db9b0c7ebb8b93b7468e30c1ea1d4fd8 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 22 Jun 2022 21:55:35 +0000 Subject: [PATCH 11/20] WIP Adding ops: - Added H/V flip ops - Added resize op - Added center_crop --- test/test_prototype_transforms.py | 135 +++++----- .../prototype/features/_bounding_box.py | 19 ++ torchvision/prototype/features/_feature.py | 37 ++- torchvision/prototype/features/_image.py | 18 ++ .../prototype/features/_segmentation_mask.py | 19 +- torchvision/prototype/transforms/_geometry.py | 237 ++++++++---------- 6 files changed, 272 insertions(+), 193 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index dc3de480d1f..de1ac950569 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -3,7 +3,12 @@ 
import pytest import torch from common_utils import assert_equal -from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels +from test_prototype_transforms_functional import ( + make_images, + make_bounding_boxes, + make_one_hot_labels, + make_segmentation_masks, +) from torchvision.prototype import transforms, features from torchvision.transforms.functional import to_pil_image, pil_to_tensor @@ -25,23 +30,23 @@ def make_vanilla_tensor_bounding_boxes(*args, **kwargs): yield bounding_box.data -def parametrize(transforms_with_inputs): +def parametrize(transforms_with_inpts): return pytest.mark.parametrize( - ("transform", "input"), + ("transform", "inpt"), [ pytest.param( transform, - input, - id=f"{type(transform).__name__}-{type(input).__module__}.{type(input).__name__}-{idx}", + inpt, + id=f"{type(transform).__name__}-{type(inpt).__module__}.{type(inpt).__name__}-{idx}", ) - for transform, inputs in transforms_with_inputs - for idx, input in enumerate(inputs) + for transform, inpts in transforms_with_inpts + for idx, inpt in enumerate(inpts) ], ) def parametrize_from_transforms(*transforms): - transforms_with_inputs = [] + transforms_with_inpts = [] for transform in transforms: for creation_fn in [ make_images, @@ -49,32 +54,34 @@ def parametrize_from_transforms(*transforms): make_one_hot_labels, make_vanilla_tensor_images, make_pil_images, + make_segmentation_masks, ]: - inputs = list(creation_fn()) - try: - output = transform(inputs[0]) - except Exception: - continue - else: - if output is inputs[0]: - continue + inpts = list(creation_fn()) + # try: + output = transform(inpts[0]) + # except TypeError: + # continue + # else: + # if output is inpts[0]: + # continue - transforms_with_inputs.append((transform, inputs)) + transforms_with_inpts.append((transform, inpts)) - return parametrize(transforms_with_inputs) + return parametrize(transforms_with_inpts) class TestSmoke: @parametrize_from_transforms( - transforms.RandomErasing(p=1.0), + # transforms.RandomErasing(p=1.0), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), - transforms.ConvertImageDtype(), - transforms.RandomHorizontalFlip(), - transforms.Pad(5), + # transforms.ConvertImageDtype(), + # transforms.RandomHorizontalFlip(), + # transforms.Pad(5), ) - def test_common(self, transform, input): - transform(input) + def test_common(self, transform, inpt): + output = transform(inpt) + assert type(output) == type(inpt) @parametrize( [ @@ -96,8 +103,8 @@ def test_common(self, transform, input): ] ] ) - def test_mixup_cutmix(self, transform, input): - transform(input) + def test_mixup_cutmix(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -127,8 +134,8 @@ def test_mixup_cutmix(self, transform, input): ) ] ) - def test_auto_augment(self, transform, input): - transform(input) + def test_auto_augment(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -144,8 +151,8 @@ def test_auto_augment(self, transform, input): ), ] ) - def test_normalize(self, transform, input): - transform(input) + def test_normalize(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -159,8 +166,8 @@ def test_normalize(self, transform, input): ) ] ) - def test_random_resized_crop(self, transform, input): - transform(input) + def test_random_resized_crop(self, transform, inpt): + transform(inpt) @parametrize( [ @@ -188,58 +195,58 @@ def test_random_resized_crop(self, transform, input): ) ] ) - def test_convert_image_color_space(self, transform, input): - transform(input) + def 
test_convert_image_color_space(self, transform, inpt): + transform(inpt) @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomHorizontalFlip: - def input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) + def inpt_expected_image_tensor(self, p, dtype=torch.float32): + inpt = torch.tensor([[[0, 1], [0, 1]], [[1, 0], [1, 0]]], dtype=dtype) expected = torch.tensor([[[1, 0], [1, 0]], [[0, 1], [0, 1]]], dtype=dtype) - return input, expected if p == 1 else input + return inpt, expected if p == 1 else inpt def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(input) + actual = transform(inpt) assert_equal(expected, actual) def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) + inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(to_pil_image(input)) + actual = transform(to_pil_image(inpt)) assert_equal(expected, pil_to_tensor(actual)) def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(features.Image(inpt)) assert_equal(features.Image(expected), actual) def test_features_segmentation_mask(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(features.SegmentationMask(input)) + actual = transform(features.SegmentationMask(inpt)) assert_equal(features.SegmentationMask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) transform = transforms.RandomHorizontalFlip(p=p) - actual = transform(input) + actual = transform(inpt) - expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected_image_tensor = torch.tensor([5, 0, 10, 5]) if p == 1.0 else inpt + expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size @@ -247,52 +254,52 @@ def test_features_bounding_box(self, p): @pytest.mark.parametrize("p", [0.0, 1.0]) class TestRandomVerticalFlip: - def input_expected_image_tensor(self, p, dtype=torch.float32): - input = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype) + def inpt_expected_image_tensor(self, p, dtype=torch.float32): + inpt = torch.tensor([[[1, 1], [0, 0]], [[1, 1], [0, 0]]], dtype=dtype) expected = torch.tensor([[[0, 0], [1, 1]], [[0, 0], [1, 1]]], dtype=dtype) - return input, expected if p == 1 else input + return inpt, expected if p == 1 else inpt def test_simple_tensor(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(input) + actual = transform(inpt) assert_equal(expected, actual) 
def test_pil_image(self, p): - input, expected = self.input_expected_image_tensor(p, dtype=torch.uint8) + inpt, expected = self.inpt_expected_image_tensor(p, dtype=torch.uint8) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(to_pil_image(input)) + actual = transform(to_pil_image(inpt)) assert_equal(expected, pil_to_tensor(actual)) def test_features_image(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.Image(input)) + actual = transform(features.Image(inpt)) assert_equal(features.Image(expected), actual) def test_features_segmentation_mask(self, p): - input, expected = self.input_expected_image_tensor(p) + inpt, expected = self.inpt_expected_image_tensor(p) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(features.SegmentationMask(input)) + actual = transform(features.SegmentationMask(inpt)) assert_equal(features.SegmentationMask(expected), actual) def test_features_bounding_box(self, p): - input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) + inpt = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10)) transform = transforms.RandomVerticalFlip(p=p) - actual = transform(input) + actual = transform(inpt) - expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else input - expected = features.BoundingBox.new_like(input, data=expected_image_tensor) + expected_image_tensor = torch.tensor([0, 5, 5, 10]) if p == 1.0 else inpt + expected = features.BoundingBox.new_like(inpt, data=expected_image_tensor) assert_equal(expected, actual) assert actual.format == expected.format assert actual.image_size == expected.image_size diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index cd5cdc69836..4d6043fb1ce 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -69,3 +69,22 @@ def to_format(self, format: Union[str, BoundingBoxFormat]) -> BoundingBox: return BoundingBox.new_like( self, convert_bounding_box_format(self, old_format=self.format, new_format=format), format=format ) + + def horizontal_flip(self) -> BoundingBox: + output = self._F.horizontal_flip_bounding_box(self, format=self.format, image_size=self.image_size) + return BoundingBox.new_like(self, output) + + def vertical_flip(self) -> BoundingBox: + output = self._F.vertical_flip_bounding_box(self, format=self.format, image_size=self.image_size) + return BoundingBox.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox: + interpolation, antialias # unused + output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) + return BoundingBox.new_like(self, output) + + def center_crop(self, output_size) -> BoundingBox: + output = self._F.center_crop_bounding_box( + self, format=self.format, output_size=output_size, image_size=self.image_size + ) + return BoundingBox.new_like(self, output) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index f8026b4d34d..3e1b28f1057 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -7,7 +7,42 @@ F = TypeVar("F", bound="_Feature") -class _Feature(torch.Tensor): +class _TransformsMixin: + def __init__(self, *args, 
**kwargs): + super().__init__() + + # To avoid circular dependency between features and transforms + from ..transforms import functional as F + + self._F = F + + def horizontal_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def vertical_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resize(self, size, *, interpolation, max_size, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def center_crop(self, output_size): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + +class _Feature(_TransformsMixin, torch.Tensor): def __new__( cls: Type[F], data: Any, diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 9206a844b6d..caa4f0f07a4 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -109,3 +109,21 @@ def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image: # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we # promote this out of the prototype state return Image.new_like(self, draw_bounding_boxes(self, bounding_box.to_format("xyxy").view(-1, 4), **kwargs)) + + def horizontal_flip(self) -> Image: + output = self._F.horizontal_flip_image_tensor(self) + return Image.new_like(self, output) + + def vertical_flip(self) -> Image: + output = self._F.vertical_flip_image_tensor(self) + return Image.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> Image: + output = self._F.resize_image_tensor( + self, size, interpolation=interpolation, max_size=max_size, antialias=antialias + ) + return Image.new_like(self, output) + + def center_crop(self, output_size) -> Image: + output = self._F.center_crop_image_tensor(self, output_size=output_size) + return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index dc41697ae9b..4c1e2f4b514 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -1,5 +1,22 @@ +from __future__ import annotations + from ._feature import _Feature class SegmentationMask(_Feature): - pass + def horizontal_flip(self) -> SegmentationMask: + output = self._F.horizontal_flip_segmentation_mask(self) + return SegmentationMask.new_like(self, output) + + def vertical_flip(self) -> SegmentationMask: + output = self._F.vertical_flip_segmentation_mask(self) + return SegmentationMask.new_like(self, output) + + def resize(self, size, *, interpolation, max_size, antialias) -> SegmentationMask: + interpolation, antialias # unused + output = self._F.resize_segmentation_mask(self, size, max_size=max_size) + return SegmentationMask.new_like(self, output) + + def center_crop(self, output_size) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, output_size=output_size) + return SegmentationMask.new_like(self, output) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 0487a71416e..b6951837046 100644 --- 
a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -2,14 +2,14 @@ import math import numbers import warnings -from typing import Any, Dict, List, Union, Sequence, Tuple, cast +from typing import Any, Dict, List, Optional, Union, Sequence, Tuple, cast import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F from torchvision.transforms.functional import pil_to_tensor, InterpolationMode -from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int +from torchvision.transforms.transforms import _setup_size from typing_extensions import Literal from ._transform import _RandomApplyTransform @@ -17,41 +17,27 @@ class RandomHorizontalFlip(_RandomApplyTransform): - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.horizontal_flip_image_tensor(input) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.horizontal_flip_segmentation_mask(input) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.horizontal_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return features.BoundingBox.new_like(input, output) - elif isinstance(input, PIL.Image.Image): - return F.horizontal_flip_image_pil(input) - elif is_simple_tensor(input): - return F.horizontal_flip_image_tensor(input) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.horizontal_flip() + elif isinstance(inpt, PIL.Image.Image): + return F.horizontal_flip_image_pil(inpt) + elif is_simple_tensor(inpt): + return F.horizontal_flip_image_tensor(inpt) else: - return input + return inpt class RandomVerticalFlip(_RandomApplyTransform): - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.vertical_flip_image_tensor(input) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.vertical_flip_segmentation_mask(input) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.vertical_flip_bounding_box(input, format=input.format, image_size=input.image_size) - return features.BoundingBox.new_like(input, output) - elif isinstance(input, PIL.Image.Image): - return F.vertical_flip_image_pil(input) - elif is_simple_tensor(input): - return F.vertical_flip_image_tensor(input) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.vertical_flip() + elif isinstance(inpt, PIL.Image.Image): + return F.vertical_flip_image_pil(inpt) + elif is_simple_tensor(inpt): + return F.vertical_flip_image_tensor(inpt) else: - return input + return inpt class Resize(Transform): @@ -59,27 +45,40 @@ def __init__( self, size: Union[int, Sequence[int]], interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, ) -> None: super().__init__() self.size = [size] if isinstance(size, int) else list(size) self.interpolation = interpolation - - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.resize_image_tensor(input, self.size, 
interpolation=self.interpolation) - return features.Image.new_like(input, output) - elif isinstance(input, features.SegmentationMask): - output = F.resize_segmentation_mask(input, self.size) - return features.SegmentationMask.new_like(input, output) - elif isinstance(input, features.BoundingBox): - output = F.resize_bounding_box(input, self.size, image_size=input.image_size) - return features.BoundingBox.new_like(input, output, image_size=cast(Tuple[int, int], tuple(self.size))) - elif isinstance(input, PIL.Image.Image): - return F.resize_image_pil(input, self.size, interpolation=self.interpolation) - elif is_simple_tensor(input): - return F.resize_image_tensor(input, self.size, interpolation=self.interpolation) + self.max_size = max_size + self.antialias = antialias + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.resize( + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) + elif isinstance(inpt, PIL.Image.Image): + return F.resize_image_pil( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + ) + elif is_simple_tensor(inpt): + return F.resize_image_tensor( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) else: - return input + return inpt class CenterCrop(Transform): @@ -87,22 +86,15 @@ def __init__(self, output_size: List[int]): super().__init__() self.output_size = output_size - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.center_crop_image_tensor(input, self.output_size) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.center_crop_image_tensor(input, self.output_size) - elif isinstance(input, PIL.Image.Image): - return F.center_crop_image_pil(input, self.output_size) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.center_crop(self.output_size) + elif is_simple_tensor(inpt): + return F.center_crop_image_tensor(inpt, self.output_size) + elif isinstance(inpt, PIL.Image.Image): + return F.center_crop_image_pil(inpt, self.output_size) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) + return inpt class RandomResizedCrop(Transform): @@ -125,14 +117,6 @@ def __init__( if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") - # Backward compatibility with integer value - if isinstance(interpolation, int): - warnings.warn( - "Argument interpolation should be of type InterpolationMode instead of int. " - "Please, use InterpolationMode enum." 
- ) - interpolation = _interpolation_modes_from_int(interpolation) - self.size = size self.scale = scale self.ratio = ratio @@ -177,21 +161,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.resized_crop_image_tensor( - input, **params, size=list(self.size), interpolation=self.interpolation - ) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.resized_crop_image_tensor(input, **params, size=list(self.size), interpolation=self.interpolation) - elif isinstance(input, PIL.Image.Image): - return F.resized_crop_image_pil(input, **params, size=list(self.size), interpolation=self.interpolation) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image): + output = F.resized_crop_image_tensor(inpt, **params, size=list(self.size), interpolation=self.interpolation) + return features.Image.new_like(inpt, output) + elif is_simple_tensor(inpt): + return F.resized_crop_image_tensor(inpt, **params, size=list(self.size), interpolation=self.interpolation) + elif isinstance(inpt, PIL.Image.Image): + return F.resized_crop_image_pil(inpt, **params, size=list(self.size), interpolation=self.interpolation) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) @@ -213,19 +195,19 @@ def __init__(self, size: Union[int, Sequence[int]]) -> None: super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.five_crop_image_tensor(input, self.size) - return MultiCropResult(features.Image.new_like(input, o) for o in output) - elif is_simple_tensor(input): - return MultiCropResult(F.five_crop_image_tensor(input, self.size)) - elif isinstance(input, PIL.Image.Image): - return MultiCropResult(F.five_crop_image_pil(input, self.size)) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image): + output = F.five_crop_image_tensor(inpt, self.size) + return MultiCropResult(features.Image.new_like(inpt, o) for o in output) + elif is_simple_tensor(inpt): + return MultiCropResult(F.five_crop_image_tensor(inpt, self.size)) + elif isinstance(inpt, PIL.Image.Image): + return MultiCropResult(F.five_crop_image_pil(inpt, self.size)) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) @@ -237,26 +219,26 @@ def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") self.vertical_flip = vertical_flip - def _transform(self, input: 
Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip) - return MultiCropResult(features.Image.new_like(input, o) for o in output) - elif is_simple_tensor(input): - return MultiCropResult(F.ten_crop_image_tensor(input, self.size)) - elif isinstance(input, PIL.Image.Image): - return MultiCropResult(F.ten_crop_image_pil(input, self.size)) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image): + output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip) + return MultiCropResult(features.Image.new_like(inpt, o) for o in output) + elif is_simple_tensor(inpt): + return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size)) + elif isinstance(inpt, PIL.Image.Image): + return MultiCropResult(F.ten_crop_image_pil(inpt, self.size)) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] if has_any(sample, features.BoundingBox, features.SegmentationMask): raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") return super().forward(sample) class BatchMultiCrop(Transform): - def forward(self, *inputs: Any) -> Any: + def forward(self, *inpts: Any) -> Any: # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one # significant difference: # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from @@ -280,7 +262,7 @@ def apply_recursively(obj: Any) -> Any: else: return obj - return apply_recursively(inputs if len(inputs) > 1 else inputs[0]) + return apply_recursively(inpts if len(inpts) > 1 else inpts[0]) class Pad(Transform): @@ -309,13 +291,14 @@ def __init__( self.fill = fill self.padding_mode = padding_mode - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image) or is_simple_tensor(input): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features.Image) or is_simple_tensor(inpt): # PyTorch's pad supports only integers on fill. 
So we need to overwrite the colour - output = F.pad_image_tensor(input, params["padding"], fill=0, padding_mode="constant") - left, top, right, bottom = params["padding"] - fill = torch.tensor(params["fill"], dtype=input.dtype, device=input.device).to().view(-1, 1, 1) + output = F.pad_image_tensor(inpt, self.padding, fill=0, padding_mode="constant") + + left, top, right, bottom = self.padding + fill = torch.tensor(self.fill, dtype=inpt.dtype, device=inpt.device).to().view(-1, 1, 1) if top > 0: output[..., :top, :] = fill @@ -326,28 +309,28 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any: if right > 0: output[..., :, -right:] = fill - if isinstance(input, features.Image): - output = features.Image.new_like(input, output) + if isinstance(inpt, features.Image): + output = features.Image.new_like(inpt, output) return output - elif isinstance(input, PIL.Image.Image): + elif isinstance(inpt, PIL.Image.Image): return F.pad_image_pil( - input, - params["padding"], - fill=tuple(int(v) if input.mode != "F" else v for v in params["fill"]), + inpt, + self.padding, + fill=tuple(int(v) if inpt.mode != "F" else v for v in self.fill), padding_mode="constant", ) - elif isinstance(input, features.BoundingBox): - output = F.pad_bounding_box(input, params["padding"], format=input.format) + elif isinstance(inpt, features.BoundingBox): + output = F.pad_bounding_box(inpt, self.padding, format=inpt.format) - left, top, right, bottom = params["padding"] - height, width = input.image_size + left, top, right, bottom = self.padding + height, width = inpt.image_size height += top + bottom width += left + right - return features.BoundingBox.new_like(input, output, image_size=(height, width)) + return features.BoundingBox.new_like(inpt, output, image_size=(height, width)) else: - return input + return inpt class RandomZoomOut(_RandomApplyTransform): @@ -385,6 +368,6 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding, fill=fill) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: transform = Pad(**params, padding_mode="constant") - return transform(input) + return transform(inpt) From 6a5e5ab19de2a020a5797f2e3fd5606b78a1ae3d Mon Sep 17 00:00:00 2001 From: vfdev Date: Thu, 23 Jun 2022 12:16:51 +0200 Subject: [PATCH 12/20] Update functional.py --- torchvision/transforms/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 691eff84426..2a4a7f1b6dd 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -2,7 +2,7 @@ import numbers import warnings from enum import Enum -from typing import List, Tuple, Any, Optional, Union +from typing import List, Tuple, Any, Optional import numpy as np import torch From b2ada459b27b8e9875c0305cf970cf3fbb76096b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 23 Jun 2022 10:36:08 +0000 Subject: [PATCH 13/20] Reverted fill/center order for rotate Other nits --- .../prototype/transforms/functional/_geometry.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 19085d2a974..95e094ad798 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -48,12 +48,11 @@ def resize_image_tensor( antialias: 
Optional[bool] = None, ) -> torch.Tensor: num_channels, old_height, old_width = get_dimensions_image_tensor(image) - size = _compute_output_size((old_height, old_width), size=size, max_size=max_size) - new_height, new_width = size + new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size) batch_shape = image.shape[:-3] return _FT.resize( image.reshape((-1, num_channels, old_height, old_width)), - size=size, + size=[new_height, new_width], interpolation=interpolation.value, antialias=antialias, ).reshape(batch_shape + (num_channels, new_height, new_width)) @@ -83,8 +82,7 @@ def resize_bounding_box( bounding_box: torch.Tensor, size: List[int], image_size: Tuple[int, int], max_size: Optional[int] = None ) -> torch.Tensor: old_height, old_width = image_size - size = _compute_output_size(image_size, size=size, max_size=max_size) - new_height, new_width = size + new_height, new_width = _compute_output_size(image_size, size=size, max_size=max_size) ratios = torch.tensor((new_width / old_width, new_height / old_height), device=bounding_box.device) return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) @@ -330,8 +328,8 @@ def rotate_image_tensor( angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, - center: Optional[List[float]] = None, fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, ) -> torch.Tensor: center_f = [0.0, 0.0] if center is not None: From 5bc6a503863d59a10fb3b23653ac1f9e6d15fa14 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 27 Jun 2022 08:28:57 +0000 Subject: [PATCH 14/20] - Added Pad and WIP on Rotate --- test/test_prototype_transforms.py | 23 +-- .../prototype/features/_bounding_box.py | 29 +++- torchvision/prototype/features/_feature.py | 10 ++ torchvision/prototype/features/_image.py | 21 +++ .../prototype/features/_segmentation_mask.py | 11 ++ torchvision/prototype/transforms/_geometry.py | 140 ++++++++++++------ .../transforms/functional/_geometry.py | 5 +- torchvision/transforms/functional_tensor.py | 5 +- 8 files changed, 181 insertions(+), 63 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index de1ac950569..c4653fe2136 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -57,13 +57,13 @@ def parametrize_from_transforms(*transforms): make_segmentation_masks, ]: inpts = list(creation_fn()) - # try: - output = transform(inpts[0]) - # except TypeError: - # continue - # else: - # if output is inpts[0]: - # continue + try: + output = transform(inpts[0]) + except TypeError: + continue + else: + if output is inpts[0]: + continue transforms_with_inpts.append((transform, inpts)) @@ -72,12 +72,13 @@ def parametrize_from_transforms(*transforms): class TestSmoke: @parametrize_from_transforms( - # transforms.RandomErasing(p=1.0), + transforms.RandomErasing(p=1.0), transforms.Resize([16, 16]), transforms.CenterCrop([16, 16]), - # transforms.ConvertImageDtype(), - # transforms.RandomHorizontalFlip(), - # transforms.Pad(5), + transforms.RandomResizedCrop([16, 16]), + transforms.ConvertImageDtype(), + transforms.RandomHorizontalFlip(), + transforms.Pad(5), ) def test_common(self, transform, inpt): output = transform(inpt) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 4d6043fb1ce..56d400c51f5 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py 
@@ -81,10 +81,37 @@ def vertical_flip(self) -> BoundingBox: def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox: interpolation, antialias # unused output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) - return BoundingBox.new_like(self, output) + return BoundingBox.new_like(self, output, image_size=size) def center_crop(self, output_size) -> BoundingBox: output = self._F.center_crop_bounding_box( self, format=self.format, output_size=output_size, image_size=self.image_size ) + return BoundingBox.new_like(self, output, image_size=output_size) + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> BoundingBox: + # TODO: untested right now + interpolation, antialias # unused + output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) + return BoundingBox.new_like(self, output, image_size=size) + + def pad(self, padding, *, fill, padding_mode) -> BoundingBox: + fill # unused + if padding_mode not in ["constant"]: + raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") + + output = self._F.pad_bounding_box(self, padding, fill=fill, padding_mode=padding_mode) + + # Update output image size: + left, top, right, bottom = padding + height, width = self.image_size + height += top + bottom + width += left + right + + return BoundingBox.new_like(self, output, image_size=(height, width)) + + def rotate(self, angle, *, interpolation, expand, fill, center) -> BoundingBox: + output = self._F.rotate_bounding_box( + self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center + ) return BoundingBox.new_like(self, output) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 3e1b28f1057..e6fcf54ac3a 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -41,6 +41,16 @@ def resized_crop(self, top, left, height, width, *, size, interpolation, antiali # How dangerous to do this instead of raising an error ? return self + def pad(self, padding, *, fill, padding_mode): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def rotate(self, angle, *, interpolation, expand, fill, center): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + class _Feature(_TransformsMixin, torch.Tensor): def __new__( diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index caa4f0f07a4..a1eaf03207f 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -127,3 +127,24 @@ def resize(self, size, *, interpolation, max_size, antialias) -> Image: def center_crop(self, output_size) -> Image: output = self._F.center_crop_image_tensor(self, output_size=output_size) return Image.new_like(self, output) + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> Image: + output = self._F.resized_crop_image_tensor( + self, top, left, height, width, size=list(size), interpolation=interpolation, antialias=antialias + ) + return Image.new_like(self, output) + + def pad(self, padding, *, fill, padding_mode) -> Image: + # Previous message from previous implementation: + # PyTorch's pad supports only integers on fill. 
So we need to overwrite the colour
+        # vfdev-5: pytorch pad supports both int and float fill values but keeps the original dtype;
+        # if a user pads an int image with a float fill value, they need to cast the image to float
+        # before padding. Let's remove the previous manual float fill support.
+        output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode)
+        return Image.new_like(self, output)
+
+    def rotate(self, angle, *, interpolation, expand, fill, center) -> Image:
+        output = self._F.rotate_image_tensor(
+            self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center
+        )
+        return Image.new_like(self, output)
diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py
index 4c1e2f4b514..d2406464a0f 100644
--- a/torchvision/prototype/features/_segmentation_mask.py
+++ b/torchvision/prototype/features/_segmentation_mask.py
@@ -20,3 +20,14 @@ def resize(self, size, *, interpolation, max_size, antialias) -> SegmentationMas
     def center_crop(self, output_size) -> SegmentationMask:
         output = self._F.center_crop_segmentation_mask(self, output_size=output_size)
         return SegmentationMask.new_like(self, output)
+
+    def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> SegmentationMask:
+        # TODO: untested right now
+        interpolation, antialias  # unused
+        output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=list(size))
+        return SegmentationMask.new_like(self, output)
+
+    def pad(self, padding, *, fill, padding_mode) -> SegmentationMask:
+        fill  # unused
+        output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
+        return SegmentationMask.new_like(self, output)
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index b6951837046..11adc5a5a1b 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -9,7 +9,12 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F
 from torchvision.transforms.functional import pil_to_tensor, InterpolationMode
-from torchvision.transforms.transforms import _setup_size
+
+# TODO: refactor _parse_pad_padding into
+# torchvision.transforms.functional and update F_t.pad and F_pil.pad
+# and remove redundancy
+from torchvision.transforms.functional_tensor import _parse_pad_padding
+from torchvision.transforms.transforms import _setup_size, _setup_angle, _check_sequence_input
 from typing_extensions import Literal

 from ._transform import _RandomApplyTransform
@@ -104,6 +109,7 @@ def __init__(
         scale: Tuple[float, float] = (0.08, 1.0),
         ratio: Tuple[float, float] = (3.0 / 4.0, 4.0 / 3.0),
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
         self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
@@ -121,8 +127,12 @@ def __init__(
         self.scale = scale
         self.ratio = ratio
         self.interpolation = interpolation
+        self.antialias = antialias

     def _get_params(self, sample: Any) -> Dict[str, Any]:
+        # vfdev-5: technically, this op can work on samples that contain only bboxes/segm masks and no image
+        # What if we have multiple images/bboxes/masks of different sizes ?
+ # TODO: let's support bbox or mask in samples without image image = query_image(sample) _, height, width = get_image_dimensions(image) area = height * width @@ -162,22 +172,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features.Image): - output = F.resized_crop_image_tensor(inpt, **params, size=list(self.size), interpolation=self.interpolation) - return features.Image.new_like(inpt, output) + if isinstance(inpt, features._Feature): + antialias = False if self.antialias is None else self.antialias + return inpt.resized_crop(**params, size=self.size, interpolation=self.interpolation, antialias=antialias) elif is_simple_tensor(inpt): - return F.resized_crop_image_tensor(inpt, **params, size=list(self.size), interpolation=self.interpolation) + antialias = False if self.antialias is None else self.antialias + return F.resized_crop_image_tensor( + inpt, **params, size=list(self.size), interpolation=self.interpolation, antialias=antialias + ) elif isinstance(inpt, PIL.Image.Image): return F.resized_crop_image_pil(inpt, **params, size=list(self.size), interpolation=self.interpolation) else: return inpt - def forward(self, *inpts: Any) -> Any: - sample = inpts if len(inpts) > 1 else inpts[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - return super().forward(sample) - class MultiCropResult(list): """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`. @@ -287,48 +294,32 @@ def __init__( f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple" ) + padding = _parse_pad_padding(padding) self.padding = padding self.fill = fill self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features.Image) or is_simple_tensor(inpt): - # PyTorch's pad supports only integers on fill. 
So we need to overwrite the colour - - output = F.pad_image_tensor(inpt, self.padding, fill=0, padding_mode="constant") - - left, top, right, bottom = self.padding - fill = torch.tensor(self.fill, dtype=inpt.dtype, device=inpt.device).to().view(-1, 1, 1) - - if top > 0: - output[..., :top, :] = fill - if left > 0: - output[..., :, :left] = fill - if bottom > 0: - output[..., -bottom:, :] = fill - if right > 0: - output[..., :, -right:] = fill - - if isinstance(inpt, features.Image): - output = features.Image.new_like(inpt, output) - - return output + if isinstance(inpt, features._Feature): + return inpt.pad( + self.padding, + fill=self.fill, + padding_mode=self.padding_mode, + ) elif isinstance(inpt, PIL.Image.Image): return F.pad_image_pil( inpt, self.padding, - fill=tuple(int(v) if inpt.mode != "F" else v for v in self.fill), - padding_mode="constant", + fill=self.fill, + padding_mode=self.padding_mode, + ) + elif is_simple_tensor(inpt): + return F.pad_image_tensor( + inpt, + self.padding, + fill=self.fill, + padding_mode=self.padding_mode, ) - elif isinstance(inpt, features.BoundingBox): - output = F.pad_bounding_box(inpt, self.padding, format=inpt.format) - - left, top, right, bottom = self.padding - height, width = inpt.image_size - height += top + bottom - width += left + right - - return features.BoundingBox.new_like(inpt, output, image_size=(height, width)) else: return inpt @@ -347,6 +338,8 @@ def __init__( if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid canvas side range provided {side_range}.") + self.pad_op = Pad(0, padding_mode="constant") + def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) orig_c, orig_h, orig_w = get_image_dimensions(image) @@ -369,5 +362,62 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding, fill=fill) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - transform = Pad(**params, padding_mode="constant") - return transform(inpt) + self.pad_op.padding = params["padding"] + self.pad_op.fill = params["fill"] + return self.pad_op(inpt) + + +class RandomRotation(Transform): + def __init__( + self, + degrees, + interpolation=InterpolationMode.NEAREST, + expand=False, + fill=0, + center=None, + ): + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + self.interpolation = interpolation + self.expand = expand + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, sample: Any) -> Dict[str, Any]: + angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + return dict(angle=angle) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.rotate( + **params, + interpolation=self.interpolation, + expand=self.expand, + fill=self.fill, + center=self.center, + ) + elif is_simple_tensor(inpt): + return F.rotate_image_tensor( + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + fill=self.fill, + center=self.center, + ) + elif isinstance(inpt, PIL.Image.Image): + return F.rotate_image_pil( + inpt, **params, interpolation=self.interpolation, expand=self.expand, fill=self.fill, center=self.center + ) + else: + return inpt diff --git 
a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 95e094ad798..3df10d4239d 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -45,7 +45,7 @@ def resize_image_tensor( size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, max_size: Optional[int] = None, - antialias: Optional[bool] = None, + antialias: bool = False, ) -> torch.Tensor: num_channels, old_height, old_width = get_dimensions_image_tensor(image) new_height, new_width = _compute_output_size((old_height, old_width), size=size, max_size=max_size) @@ -651,9 +651,10 @@ def resized_crop_image_tensor( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: bool = False, ) -> torch.Tensor: img = crop_image_tensor(img, top, left, height, width) - return resize_image_tensor(img, size, interpolation=interpolation) + return resize_image_tensor(img, size, interpolation=interpolation, antialias=antialias) def resized_crop_image_pil( diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index acc8d3ae3e1..35618da9339 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -430,16 +430,13 @@ def resize( img: Tensor, size: List[int], interpolation: str = "bilinear", - antialias: Optional[bool] = None, + antialias: bool = False, ) -> Tensor: _assert_image_tensor(img) if isinstance(size, tuple): size = list(size) - if antialias is None: - antialias = False - if antialias and interpolation not in ["bilinear", "bicubic"]: raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") From 6a5201ac32a63bd06c2d103d405bee1e321ffab5 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 27 Jun 2022 15:14:22 +0000 Subject: [PATCH 15/20] Added more ops and mid-level functional API --- test/test_prototype_transforms.py | 2 +- .../prototype/features/_bounding_box.py | 27 +++- torchvision/prototype/features/_feature.py | 126 +++++++++++------- torchvision/prototype/features/_image.py | 46 +++++++ torchvision/prototype/features/_label.py | 13 +- .../prototype/features/_segmentation_mask.py | 26 ++++ torchvision/prototype/transforms/_augment.py | 68 ++++------ .../prototype/transforms/_auto_augment.py | 30 ++--- torchvision/prototype/transforms/_geometry.py | 71 +++------- .../transforms/functional/__init__.py | 11 +- .../transforms/functional/_augment.py | 44 +----- .../prototype/transforms/functional/_color.py | 42 ++++++ .../transforms/functional/_geometry.py | 116 ++++++++++++---- 13 files changed, 391 insertions(+), 231 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index c4653fe2136..fb7c7341992 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -59,7 +59,7 @@ def parametrize_from_transforms(*transforms): inpts = list(creation_fn()) try: output = transform(inpts[0]) - except TypeError: + except (TypeError, RuntimeError): continue else: if output is inpts[0]: diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index 56d400c51f5..bd6f04995cc 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -111,7 +111,32 @@ def pad(self, padding, *, fill, padding_mode) -> BoundingBox: return 
BoundingBox.new_like(self, output, image_size=(height, width)) def rotate(self, angle, *, interpolation, expand, fill, center) -> BoundingBox: + interpolation, fill # unused output = self._F.rotate_bounding_box( - self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center + self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center ) + # TODO: update output image size if expand is True + if expand: + raise RuntimeError("Not yet implemented") return BoundingBox.new_like(self, output) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> BoundingBox: + interpolation, fill # unused + output = self._F.affine_bounding_box( + self, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return BoundingBox.new_like(self, output) + + def erase(self, *args) -> BoundingBox: + raise TypeError(f"Erase transformation does not support bounding boxes") + + def mixup(self, *args) -> BoundingBox: + raise TypeError(f"Mixup transformation does not support bounding boxes") + + def cutmix(self, *args) -> BoundingBox: + raise TypeError(f"Cutmix transformation does not support bounding boxes") diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index e6fcf54ac3a..64a62fdabea 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -7,52 +7,7 @@ F = TypeVar("F", bound="_Feature") -class _TransformsMixin: - def __init__(self, *args, **kwargs): - super().__init__() - - # To avoid circular dependency between features and transforms - from ..transforms import functional as F - - self._F = F - - def horizontal_flip(self): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def vertical_flip(self): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def resize(self, size, *, interpolation, max_size, antialias): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def center_crop(self, output_size): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def resized_crop(self, top, left, height, width, *, size, interpolation, antialias): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def pad(self, padding, *, fill, padding_mode): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - - def rotate(self, angle, *, interpolation, expand, fill, center): - # Just output itself - # How dangerous to do this instead of raising an error ? 
- return self - - -class _Feature(_TransformsMixin, torch.Tensor): +class _Feature(torch.Tensor): def __new__( cls: Type[F], data: Any, @@ -61,7 +16,7 @@ def __new__( device: Optional[Union[torch.device, str, int]] = None, requires_grad: bool = False, ) -> F: - return cast( + feature = cast( F, torch.Tensor._make_subclass( cast(_TensorBase, cls), @@ -70,6 +25,13 @@ def __new__( ), ) + # To avoid circular dependency between features and transforms + from ..transforms import functional + + feature._F = functional + + return feature + @classmethod def new_like( cls: Type[F], @@ -128,3 +90,73 @@ def __torch_function__( return cls.new_like(args[0], output, dtype=output.dtype, device=output.device) else: return output + + def horizontal_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def vertical_flip(self): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resize(self, size, *, interpolation, max_size, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def center_crop(self, output_size): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def resized_crop(self, top, left, height, width, *, size, interpolation, antialias): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def pad(self, padding, *, fill, padding_mode): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def rotate(self, angle, *, interpolation, expand, fill, center): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_brightness(self, brightness_factor): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_saturation(self, saturation_factor): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def adjust_contrast(self, contrast_factor): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def erase(self, i, j, h, w, v): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def mixup(self, lam): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + + def cutmix(self, *, box, lam_adjusted): + # Just output itself + # How dangerous to do this instead of raising an error ? 
+ return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index a1eaf03207f..eba86b02c89 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -148,3 +148,49 @@ def rotate(self, angle, *, interpolation, expand, fill, center) -> Image: self, angle, interpolation=interpolation, expand=expand, fill=fill, center=center ) return Image.new_like(self, output) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> Image: + output = self._F.affine_image_tensor( + self, + angle, + translate=translate, + scale=scale, + shear=shear, + interpolation=interpolation, + fill=fill, + center=center, + ) + return Image.new_like(self, output) + + def adjust_brightness(self, brightness_factor) -> Image: + output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) + return Image.new_like(self, output) + + def adjust_saturation(self, saturation_factor) -> Image: + output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) + return Image.new_like(self, output) + + def adjust_contrast(self, contrast_factor) -> Image: + output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) + return Image.new_like(self, output) + + def erase(self, i, j, h, w, v) -> Image: + output = self._F.erase_image_tensor(self, i, j, h, w, v) + return Image.new_like(self, output) + + def mixup(self, lam: float) -> Image: + if self.ndim < 4: + raise ValueError("Need a batch of images") + output = self.clone() + output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) + return Image.new_like(self, output) + + def cutmix(self, *, box: Tuple[int, int, int, int], lam_adjusted: float) -> Image: + lam_adjusted # unused + if self.ndim < 4: + raise ValueError("Need a batch of images") + x1, y1, x2, y2 = box + image_rolled = self.roll(1, -4) + output = self.clone() + output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_label.py b/torchvision/prototype/features/_label.py index e3433b7bb08..1b1ded61c1c 100644 --- a/torchvision/prototype/features/_label.py +++ b/torchvision/prototype/features/_label.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Optional, Sequence, cast, Union +from typing import Any, Optional, Sequence, cast, Union, Tuple import torch from torchvision.prototype.utils._internal import apply_recursively @@ -77,3 +77,14 @@ def new_like( return super().new_like( other, data, categories=categories if categories is not None else other.categories, **kwargs ) + + def mixup(self, lam) -> OneHotLabel: + if self.ndim < 2: + raise ValueError("Need a batch of one hot labels") + output = self.clone() + output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam)) + return OneHotLabel.new_like(self, output) + + def cutmix(self, *, box: Tuple[int, int, int, int], lam_adjusted: float) -> OneHotLabel: + box # unused + return self.mixup(lam_adjusted) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index d2406464a0f..41fba31d505 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -31,3 +31,29 @@ def pad(self, padding, *, fill, padding_mode) -> SegmentationMask: fill # unused output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode) 
return SegmentationMask.new_like(self, output) + + def rotate(self, angle, *, interpolation, expand, fill, center) -> SegmentationMask: + interpolation, fill # unused + output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center) + return SegmentationMask.new_like(self, output) + + def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> SegmentationMask: + interpolation, fill # unused + output = self._F.affine_segmentation_mask( + self, + angle, + translate=translate, + scale=scale, + shear=shear, + center=center, + ) + return SegmentationMask.new_like(self, output) + + def erase(self, *args) -> SegmentationMask: + raise TypeError(f"Erase transformation does not support segmentation masks") + + def mixup(self, *args) -> SegmentationMask: + raise TypeError(f"Mixup transformation does not support segmentation masks") + + def cutmix(self, *args) -> SegmentationMask: + raise TypeError(f"Cutmix transformation does not support segmentation masks") diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 82c5f52f1dc..6088ef4bfac 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -3,6 +3,7 @@ import warnings from typing import Any, Dict, Tuple +import PIL.Image import torch from torchvision.prototype import features from torchvision.prototype.transforms import Transform, functional as F @@ -51,7 +52,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: if value is not None and not (len(value) in (1, img_c)): raise ValueError( - f"If value is a sequence, it should have either a single value or {img_c} (number of input channels)" + f"If value is a sequence, it should have either a single value or {img_c} (number of inpt channels)" ) area = img_h * img_w @@ -82,23 +83,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: else: i, j, h, w, v = 0, 0, img_h, img_w, image - return dict(zip("ijhwv", (i, j, h, w, v))) + return dict(i=i, j=j, h=h, w=w, v=v) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.erase_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.erase_image_tensor(input, **params) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.erase(**params) + elif isinstance(inpt, PIL.Image.Image): + # Shouldn't we implement a fallback to tensor ? 
+ raise RuntimeError("Not implemented") + elif isinstance(inpt, torch.Tensor): + return F.erase_image_tensor(inpt, **params) else: - return input - - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - - return super().forward(sample) + return inpt class RandomMixup(Transform): @@ -110,21 +106,15 @@ def __init__(self, *, alpha: float) -> None: def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.mixup_image_tensor(input, **params) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.mixup_one_hot_label(input, **params) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.mixup(**params) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] + if not has_all(sample, features.Image, features.OneHotLabel): raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") return super().forward(sample) @@ -158,20 +148,14 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.cutmix_image_tensor(input, box=params["box"]) - return features.Image.new_like(input, output) - elif isinstance(input, features.OneHotLabel): - output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"]) - return features.OneHotLabel.new_like(input, output) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.cutmix(**params) else: - return input + return inpt - def forward(self, *inputs: Any) -> Any: - sample = inputs if len(inputs) > 1 else inputs[0] - if has_any(sample, features.BoundingBox, features.SegmentationMask): - raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") - elif not has_all(sample, features.Image, features.OneHotLabel): + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] + if not has_all(sample, features.Image, features.OneHotLabel): raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") return super().forward(sample) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 7fc62423ab8..2b64b9cd517 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -106,9 +106,7 @@ def _apply_image_transform( if transform_id == "Identity": return image elif transform_id == "ShearX": - return self._dispatch_image_kernels( - 
F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, 0], @@ -118,9 +116,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "ShearY": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, 0], @@ -130,9 +126,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "TranslateX": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[int(magnitude), 0], @@ -142,9 +136,7 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "TranslateY": - return self._dispatch_image_kernels( - F.affine_image_tensor, - F.affine_image_pil, + return F.affine( image, angle=0.0, translate=[0, int(magnitude)], @@ -154,25 +146,19 @@ def _apply_image_transform( fill=fill, ) elif transform_id == "Rotate": - return self._dispatch_image_kernels(F.rotate_image_tensor, F.rotate_image_pil, image, angle=magnitude) + return F.rotate(image, angle=magnitude) elif transform_id == "Brightness": - return self._dispatch_image_kernels( - F.adjust_brightness_image_tensor, - F.adjust_brightness_image_pil, + return F.adjust_brightness( image, brightness_factor=1.0 + magnitude, ) elif transform_id == "Color": - return self._dispatch_image_kernels( - F.adjust_saturation_image_tensor, - F.adjust_saturation_image_pil, + return F.adjust_saturation( image, saturation_factor=1.0 + magnitude, ) elif transform_id == "Contrast": - return self._dispatch_image_kernels( - F.adjust_contrast_image_tensor, F.adjust_contrast_image_pil, image, contrast_factor=1.0 + magnitude - ) + return F.adjust_contrast(image, contrast_factor=1.0 + magnitude) elif transform_id == "Sharpness": return self._dispatch_image_kernels( F.adjust_sharpness_image_tensor, diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 11adc5a5a1b..f2810d9df24 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -23,26 +23,12 @@ class RandomHorizontalFlip(_RandomApplyTransform): def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.horizontal_flip() - elif isinstance(inpt, PIL.Image.Image): - return F.horizontal_flip_image_pil(inpt) - elif is_simple_tensor(inpt): - return F.horizontal_flip_image_tensor(inpt) - else: - return inpt + return F.horizontal_flip(inpt) class RandomVerticalFlip(_RandomApplyTransform): def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.vertical_flip() - elif isinstance(inpt, PIL.Image.Image): - return F.vertical_flip_image_pil(inpt) - elif is_simple_tensor(inpt): - return F.vertical_flip_image_tensor(inpt) - else: - return inpt + return F.vertical_flip(inpt) class Resize(Transform): @@ -74,7 +60,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: interpolation=self.interpolation, max_size=self.max_size, ) - elif is_simple_tensor(inpt): + elif isinstance(inpt, torch.Tensor): return F.resize_image_tensor( inpt, self.size, @@ -94,10 +80,10 @@ def __init__(self, output_size: List[int]): def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features._Feature): return inpt.center_crop(self.output_size) - elif is_simple_tensor(inpt): - return F.center_crop_image_tensor(inpt, self.output_size) elif isinstance(inpt, PIL.Image.Image): return 
F.center_crop_image_pil(inpt, self.output_size) + elif isinstance(inpt, torch.Tensor): + return F.center_crop_image_tensor(inpt, self.output_size) else: return inpt @@ -175,13 +161,13 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features._Feature): antialias = False if self.antialias is None else self.antialias return inpt.resized_crop(**params, size=self.size, interpolation=self.interpolation, antialias=antialias) - elif is_simple_tensor(inpt): + elif isinstance(inpt, PIL.Image.Image): + return F.resized_crop_image_pil(inpt, **params, size=list(self.size), interpolation=self.interpolation) + elif isinstance(inpt, torch.Tensor): antialias = False if self.antialias is None else self.antialias return F.resized_crop_image_tensor( inpt, **params, size=list(self.size), interpolation=self.interpolation, antialias=antialias ) - elif isinstance(inpt, PIL.Image.Image): - return F.resized_crop_image_pil(inpt, **params, size=list(self.size), interpolation=self.interpolation) else: return inpt @@ -206,10 +192,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features.Image): output = F.five_crop_image_tensor(inpt, self.size) return MultiCropResult(features.Image.new_like(inpt, o) for o in output) - elif is_simple_tensor(inpt): - return MultiCropResult(F.five_crop_image_tensor(inpt, self.size)) elif isinstance(inpt, PIL.Image.Image): return MultiCropResult(F.five_crop_image_pil(inpt, self.size)) + elif isinstance(inpt, torch.Tensor): + return MultiCropResult(F.five_crop_image_tensor(inpt, self.size)) else: return inpt @@ -230,10 +216,10 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if isinstance(inpt, features.Image): output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip) return MultiCropResult(features.Image.new_like(inpt, o) for o in output) - elif is_simple_tensor(inpt): - return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size)) elif isinstance(inpt, PIL.Image.Image): return MultiCropResult(F.ten_crop_image_pil(inpt, self.size)) + elif isinstance(inpt, torch.Tensor): + return MultiCropResult(F.ten_crop_image_tensor(inpt, self.size)) else: return inpt @@ -313,7 +299,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, padding_mode=self.padding_mode, ) - elif is_simple_tensor(inpt): + elif isinstance(inpt, torch.Tensor): return F.pad_image_tensor( inpt, self.padding, @@ -398,26 +384,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(angle=angle) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.rotate( - **params, - interpolation=self.interpolation, - expand=self.expand, - fill=self.fill, - center=self.center, - ) - elif is_simple_tensor(inpt): - return F.rotate_image_tensor( - inpt, - **params, - interpolation=self.interpolation, - expand=self.expand, - fill=self.fill, - center=self.center, - ) - elif isinstance(inpt, PIL.Image.Image): - return F.rotate_image_pil( - inpt, **params, interpolation=self.interpolation, expand=self.expand, fill=self.fill, center=self.center - ) - else: - return inpt + return F.rotate( + inpt, + **params, + interpolation=self.interpolation, + expand=self.expand, + fill=self.fill, + center=self.center, + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2a6c7dce516..4569a01c91f 100644 --- 
a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -7,16 +7,15 @@ from ._augment import ( erase_image_tensor, - mixup_image_tensor, - mixup_one_hot_label, - cutmix_image_tensor, - cutmix_one_hot_label, ) from ._color import ( + adjust_brightness, adjust_brightness_image_tensor, adjust_brightness_image_pil, + adjust_contrast, adjust_contrast_image_tensor, adjust_contrast_image_pil, + adjust_saturation, adjust_saturation_image_tensor, adjust_saturation_image_pil, adjust_sharpness_image_tensor, @@ -37,6 +36,7 @@ adjust_gamma_image_pil, ) from ._geometry import ( + horizontal_flip, horizontal_flip_bounding_box, horizontal_flip_image_tensor, horizontal_flip_image_pil, @@ -53,10 +53,12 @@ resized_crop_image_tensor, resized_crop_image_pil, resized_crop_segmentation_mask, + affine, affine_bounding_box, affine_image_tensor, affine_image_pil, affine_segmentation_mask, + rotate, rotate_bounding_box, rotate_image_tensor, rotate_image_pil, @@ -73,6 +75,7 @@ perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, vertical_flip_bounding_box, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 5004ac550dd..3920d1b3065 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,45 +1,13 @@ -from typing import Tuple - -import torch from torchvision.transforms import functional_tensor as _FT erase_image_tensor = _FT.erase -def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: - input = input.clone() - return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) - - -def mixup_image_tensor(image_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - return _mixup_tensor(image_batch, -4, lam) - - -def mixup_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") - - return _mixup_tensor(one_hot_label_batch, -2, lam) - - -def cutmix_image_tensor(image_batch: torch.Tensor, *, box: Tuple[int, int, int, int]) -> torch.Tensor: - if image_batch.ndim < 4: - raise ValueError("Need a batch of images") - - x1, y1, x2, y2 = box - image_rolled = image_batch.roll(1, -4) - - image_batch = image_batch.clone() - image_batch[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] - return image_batch - - -def cutmix_one_hot_label(one_hot_label_batch: torch.Tensor, *, lam_adjusted: float) -> torch.Tensor: - if one_hot_label_batch.ndim < 2: - raise ValueError("Need a batch of one hot labels") +# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels. 
+# Like the mixup and cutmix stuff - return _mixup_tensor(one_hot_label_batch, -2, lam_adjusted) +# This function is copy-pasted to Image and OneHotLabel and may be refactored +# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: +# input = input.clone() +# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index fa632d7df58..8046934c678 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,14 +1,56 @@ +from typing import Any + +import PIL.Image +import torch +from torchvision.prototype import features from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP + adjust_brightness_image_tensor = _FT.adjust_brightness adjust_brightness_image_pil = _FP.adjust_brightness + +def adjust_brightness(inpt: Any, brightness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_brightness(brightness_factor=brightness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_brightness_image_tensor(inpt, brightness_factor=brightness_factor) + else: + return inpt + + adjust_saturation_image_tensor = _FT.adjust_saturation adjust_saturation_image_pil = _FP.adjust_saturation + +def adjust_saturation(inpt: Any, saturation_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_saturation(saturation_factor=saturation_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_saturation_image_tensor(inpt, saturation_factor=saturation_factor) + else: + return inpt + + adjust_contrast_image_tensor = _FT.adjust_contrast adjust_contrast_image_pil = _FP.adjust_contrast + +def adjust_contrast(inpt: Any, contrast_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_contrast(contrast_factor=contrast_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_contrast_image_pil(inpt, contrast_factor=contrast_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_contrast_image_tensor(inpt, contrast_factor=contrast_factor) + else: + return inpt + + adjust_sharpness_image_tensor = _FT.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 3df10d4239d..f0ceec6d96e 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -1,6 +1,6 @@ import numbers import warnings -from typing import Tuple, List, Optional, Sequence, Union +from typing import Any, Tuple, List, Optional, Sequence, Union import PIL.Image import torch @@ -40,6 +40,52 @@ def horizontal_flip_bounding_box( ).view(shape) +def horizontal_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.horizontal_flip() + elif isinstance(inpt, PIL.Image.Image): + return horizontal_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return horizontal_flip_image_tensor(inpt) + else: + return inpt + + +vertical_flip_image_tensor = _FT.vflip +vertical_flip_image_pil = _FP.vflip + + +def 
vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: + return vertical_flip_image_tensor(segmentation_mask) + + +def vertical_flip_bounding_box( + bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] +) -> torch.Tensor: + shape = bounding_box.shape + + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] + + return convert_bounding_box_format( + bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(shape) + + +def vertical_flip(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.vertical_flip() + elif isinstance(inpt, PIL.Image.Image): + return vertical_flip_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return vertical_flip_image_tensor(inpt) + else: + return inpt + + def resize_image_tensor( image: torch.Tensor, size: List[int], @@ -87,30 +133,6 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) -vertical_flip_image_tensor = _FT.vflip -vertical_flip_image_pil = _FP.vflip - - -def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: - return vertical_flip_image_tensor(segmentation_mask) - - -def vertical_flip_bounding_box( - bounding_box: torch.Tensor, format: features.BoundingBoxFormat, image_size: Tuple[int, int] -) -> torch.Tensor: - shape = bounding_box.shape - - bounding_box = convert_bounding_box_format( - bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY - ).view(-1, 4) - - bounding_box[:, [1, 3]] = image_size[0] - bounding_box[:, [3, 1]] - - return convert_bounding_box_format( - bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False - ).view(shape) - - def _affine_parse_args( angle: float, translate: List[float], @@ -323,6 +345,27 @@ def affine_segmentation_mask( ) +def affine( + inpt: Any, + angle: float, *, + translate: List[float], + scale: float, + shear: List[float], + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + kwargs = dict(translate=translate, scale=scale, shear=shear, interpolation=interpolation, fill=fill, center=center) + if isinstance(inpt, features._Feature): + return inpt.affine(angle, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return affine_image_pil(inpt, angle, **kwargs) + elif isinstance(inpt, torch.Tensor): + return affine_image_tensor(inpt, angle, **kwargs) + else: + return inpt + + def rotate_image_tensor( img: torch.Tensor, angle: float, @@ -402,6 +445,29 @@ def rotate_segmentation_mask( ) +def rotate(inpt: Any, + angle: float, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + expand: bool = False, + fill: Optional[List[float]] = None, + center: Optional[List[float]] = None, +) -> Any: + kwargs = dict( + interpolation=interpolation, + expand=expand, + fill=fill, + center=center, + ) + if isinstance(inpt, features._Feature): + return inpt.rotate(angle, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return rotate_image_pil(inpt, angle, **kwargs) + elif isinstance(inpt, torch.Tensor): + return rotate_image_tensor(inpt, angle, **kwargs) + else: + return inpt + + pad_image_tensor = _FT.pad pad_image_pil = _FP.pad From b659a08d7a1e602b07dcd9ccc72c98aaaae609bc Mon Sep 17 00:00:00 2001 From: vfdev-5 
Date: Mon, 27 Jun 2022 16:09:58 +0000
Subject: [PATCH 16/20] - Added more color ops

---
 torchvision/prototype/features/_feature.py    | 46 +++++++-
 torchvision/prototype/features/_image.py      | 38 ++++++-
 torchvision/prototype/transforms/_augment.py  | 29 ++---
 .../prototype/transforms/_auto_augment.py     | 31 ++----
 torchvision/prototype/transforms/_geometry.py | 86 +++++++++++++++
 .../transforms/functional/__init__.py         | 16 ++-
 .../prototype/transforms/functional/_color.py | 103 +++++++++++++++++-
 .../transforms/functional/_geometry.py        | 6 +-
 8 files changed, 297 insertions(+), 58 deletions(-)

diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py
index 64a62fdabea..1cebbff83cb 100644
--- a/torchvision/prototype/features/_feature.py
+++ b/torchvision/prototype/features/_feature.py
@@ -131,17 +131,57 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center)
         # How dangerous to do this instead of raising an error ?
         return self

-    def adjust_brightness(self, brightness_factor):
+    def adjust_brightness(self, brightness_factor: float):
         # Just output itself
         # How dangerous to do this instead of raising an error ?
         return self

-    def adjust_saturation(self, saturation_factor):
+    def adjust_saturation(self, saturation_factor: float):
         # Just output itself
         # How dangerous to do this instead of raising an error ?
         return self

-    def adjust_contrast(self, contrast_factor):
+    def adjust_contrast(self, contrast_factor: float):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def adjust_sharpness(self, sharpness_factor: float):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def adjust_hue(self, hue_factor: float):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def adjust_gamma(self, gamma: float, gain: float = 1):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def posterize(self, bits: int):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def solarize(self, threshold: float):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def autocontrast(self):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def equalize(self):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
+        return self
+
+    def invert(self):
+        # Just output itself
+        # How dangerous to do this instead of raising an error ?
return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index eba86b02c89..668882ec639 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -162,18 +162,50 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) ) return Image.new_like(self, output) - def adjust_brightness(self, brightness_factor) -> Image: + def adjust_brightness(self, brightness_factor: float) -> Image: output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) return Image.new_like(self, output) - def adjust_saturation(self, saturation_factor) -> Image: + def adjust_saturation(self, saturation_factor: float) -> Image: output = self._F.adjust_saturation_image_tensor(self, saturation_factor=saturation_factor) return Image.new_like(self, output) - def adjust_contrast(self, contrast_factor) -> Image: + def adjust_contrast(self, contrast_factor: float) -> Image: output = self._F.adjust_contrast_image_tensor(self, contrast_factor=contrast_factor) return Image.new_like(self, output) + def adjust_sharpness(self, sharpness_factor: float) -> Image: + output = self._F.adjust_sharpness_image_tensor(self, sharpness_factor=sharpness_factor) + return Image.new_like(self, output) + + def adjust_hue(self, hue_factor: float) -> Image: + output = self._F.adjust_hue_image_tensor(self, hue_factor=hue_factor) + return Image.new_like(self, output) + + def adjust_gamma(self, gamma: float, gain: float = 1) -> Image: + output = self._F.adjust_gamma_image_tensor(self, gamma=gamma, gain=gain) + return Image.new_like(self, output) + + def posterize(self, bits: int) -> Image: + output = self._F.posterize_image_tensor(self, bits=bits) + return Image.new_like(self, output) + + def solarize(self, threshold: float) -> Image: + output = self._F.solarize_image_tensor(self, threshold=threshold) + return Image.new_like(self, output) + + def autocontrast(self) -> Image: + output = self._F.autocontrast_image_tensor(self) + return Image.new_like(self, output) + + def equalize(self) -> Image: + output = self._F.equalize_image_tensor(self) + return Image.new_like(self, output) + + def invert(self) -> Image: + output = self._F.invert_image_tensor(self) + return Image.new_like(self, output) + def erase(self, i, j, h, w, v) -> Image: output = self._F.erase_image_tensor(self, i, j, h, w, v) return Image.new_like(self, output) diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 6088ef4bfac..a32ba5c9f26 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -97,12 +97,20 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt -class RandomMixup(Transform): +class _BaseMixupCutmix(Transform): def __init__(self, *, alpha: float) -> None: super().__init__() self.alpha = alpha self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) + def forward(self, *inpts: Any) -> Any: + sample = inpts if len(inpts) > 1 else inpts[0] + if not has_all(sample, features.Image, features.OneHotLabel): + raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") + return super().forward(sample) + + +class RandomMixup(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(lam=float(self._dist.sample(()))) @@ -112,19 +120,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: else: return 
inpt - def forward(self, *inpts: Any) -> Any: - sample = inpts if len(inpts) > 1 else inpts[0] - if not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") - return super().forward(sample) - - -class RandomCutmix(Transform): - def __init__(self, *, alpha: float) -> None: - super().__init__() - self.alpha = alpha - self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) +class RandomCutmix(_BaseMixupCutmix): def _get_params(self, sample: Any) -> Dict[str, Any]: lam = float(self._dist.sample(())) @@ -153,9 +150,3 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return inpt.cutmix(**params) else: return inpt - - def forward(self, *inpts: Any) -> Any: - sample = inpts if len(inpts) > 1 else inpts[0] - if not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") - return super().forward(sample) diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index 2b64b9cd517..3c43166ad96 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -148,38 +148,23 @@ def _apply_image_transform( elif transform_id == "Rotate": return F.rotate(image, angle=magnitude) elif transform_id == "Brightness": - return F.adjust_brightness( - image, - brightness_factor=1.0 + magnitude, - ) + return F.adjust_brightness(image, brightness_factor=1.0 + magnitude) elif transform_id == "Color": - return F.adjust_saturation( - image, - saturation_factor=1.0 + magnitude, - ) + return F.adjust_saturation(image, saturation_factor=1.0 + magnitude) elif transform_id == "Contrast": return F.adjust_contrast(image, contrast_factor=1.0 + magnitude) elif transform_id == "Sharpness": - return self._dispatch_image_kernels( - F.adjust_sharpness_image_tensor, - F.adjust_sharpness_image_pil, - image, - sharpness_factor=1.0 + magnitude, - ) + return F.adjust_sharpness(image, sharpness_factor=1.0 + magnitude) elif transform_id == "Posterize": - return self._dispatch_image_kernels( - F.posterize_image_tensor, F.posterize_image_pil, image, bits=int(magnitude) - ) + return F.posterize(image, bits=int(magnitude)) elif transform_id == "Solarize": - return self._dispatch_image_kernels( - F.solarize_image_tensor, F.solarize_image_pil, image, threshold=magnitude - ) + return F.solarize(image, threshold=magnitude) elif transform_id == "AutoContrast": - return self._dispatch_image_kernels(F.autocontrast_image_tensor, F.autocontrast_image_pil, image) + return F.autocontrast(image) elif transform_id == "Equalize": - return self._dispatch_image_kernels(F.equalize_image_tensor, F.equalize_image_pil, image) + return F.equalize(image) elif transform_id == "Invert": - return self._dispatch_image_kernels(F.invert_image_tensor, F.invert_image_pil, image) + return F.invert(image) else: raise ValueError(f"No transform available for {transform_id}") diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index f2810d9df24..288e88b3aed 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -392,3 +392,89 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, center=self.center, ) + + +class RandomAffine(Transform): + def __init__( + self, + degrees, + translate=None, + 
scale=None, + shear=None, + interpolation=InterpolationMode.NEAREST, + fill=0, + center=None, + ): + super().__init__() + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) + if translate is not None: + _check_sequence_input(translate, "translate", req_sizes=(2,)) + for t in translate: + if not (0.0 <= t <= 1.0): + raise ValueError("translation values should be between 0 and 1") + self.translate = translate + if scale is not None: + _check_sequence_input(scale, "scale", req_sizes=(2,)) + for s in scale: + if s <= 0: + raise ValueError("scale values should be positive") + self.scale = scale + + if shear is not None: + self.shear = _setup_angle(shear, name="shear", req_sizes=(2, 4)) + else: + self.shear = shear + + self.interpolation = interpolation + + if fill is None: + fill = 0 + elif not isinstance(fill, (Sequence, numbers.Number)): + raise TypeError("Fill should be either a sequence or a number.") + + self.fill = fill + + if center is not None: + _check_sequence_input(center, "center", req_sizes=(2,)) + + self.center = center + + def _get_params(self, sample: Any) -> Dict[str, Any]: + + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, height, width = get_image_dimensions(image) + + angle = float(torch.empty(1).uniform_(float(self.degrees[0]), float(self.degrees[1])).item()) + if self.translate is not None: + max_dx = float(self.translate[0] * width) + max_dy = float(self.translate[1] * height) + tx = int(round(torch.empty(1).uniform_(-max_dx, max_dx).item())) + ty = int(round(torch.empty(1).uniform_(-max_dy, max_dy).item())) + translations = (tx, ty) + else: + translations = (0, 0) + + if self.scale is not None: + scale = float(torch.empty(1).uniform_(self.scale[0], self.scale[1]).item()) + else: + scale = 1.0 + + shear_x = shear_y = 0.0 + if self.shear is not None: + shear_x = float(torch.empty(1).uniform_(self.shear[0], self.shear[1]).item()) + if len(self.shear) == 4: + shear_y = float(torch.empty(1).uniform_(self.shear[2], self.shear[3]).item()) + + shear = (shear_x, shear_y) + return dict(angle=angle, translations=translations, scale=scale, shear=shear) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.affine( + inpt, + **params, + interpolation=self.interpolation, + fill=self.fill, + center=self.center, + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 4569a01c91f..4ff938ec684 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -18,22 +18,30 @@ adjust_saturation, adjust_saturation_image_tensor, adjust_saturation_image_pil, + adjust_sharpness, adjust_sharpness_image_tensor, adjust_sharpness_image_pil, + adjust_hue, + adjust_hue_image_tensor, + adjust_hue_image_pil, + adjust_gamma, + adjust_gamma_image_tensor, + adjust_gamma_image_pil, + posterize, posterize_image_tensor, posterize_image_pil, + solarize, solarize_image_tensor, solarize_image_pil, + autocontrast, autocontrast_image_tensor, autocontrast_image_pil, + equalize, equalize_image_tensor, equalize_image_pil, + invert, invert_image_tensor, invert_image_pil, - adjust_hue_image_tensor, - adjust_hue_image_pil, - adjust_gamma_image_tensor, - adjust_gamma_image_pil, ) from ._geometry import ( horizontal_flip, diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 8046934c678..f8016b43a36 100644 
--- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -54,23 +54,118 @@ def adjust_contrast(inpt: Any, contrast_factor: float) -> Any: adjust_sharpness_image_tensor = _FT.adjust_sharpness adjust_sharpness_image_pil = _FP.adjust_sharpness + +def adjust_sharpness(inpt: Any, sharpness_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_sharpness(sharpness_factor=sharpness_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_sharpness_image_tensor(inpt, sharpness_factor=sharpness_factor) + else: + return inpt + + +adjust_hue_image_tensor = _FT.adjust_hue +adjust_hue_image_pil = _FP.adjust_hue + + +def adjust_hue(inpt: Any, hue_factor: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_hue(hue_factor=hue_factor) + elif isinstance(inpt, PIL.Image.Image): + return adjust_hue_image_pil(inpt, hue_factor=hue_factor) + elif isinstance(inpt, torch.Tensor): + return adjust_hue_image_tensor(inpt, hue_factor=hue_factor) + else: + return inpt + + +adjust_gamma_image_tensor = _FT.adjust_gamma +adjust_gamma_image_pil = _FP.adjust_gamma + + +def adjust_gamma(inpt: Any, gamma: float, gain: float = 1) -> Any: + if isinstance(inpt, features._Feature): + return inpt.adjust_gamma(gamma=gamma, gain=gain) + elif isinstance(inpt, PIL.Image.Image): + return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) + elif isinstance(inpt, torch.Tensor): + return adjust_gamma_image_tensor(inpt, gamma=gamma, gain=gain) + else: + return inpt + + posterize_image_tensor = _FT.posterize posterize_image_pil = _FP.posterize + +def posterize(inpt: Any, bits: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.posterize(bits=bits) + elif isinstance(inpt, PIL.Image.Image): + return posterize_image_pil(inpt, bits=bits) + elif isinstance(inpt, torch.Tensor): + return posterize_image_tensor(inpt, bits=bits) + else: + return inpt + + solarize_image_tensor = _FT.solarize solarize_image_pil = _FP.solarize + +def solarize(inpt: Any, threshold: float) -> Any: + if isinstance(inpt, features._Feature): + return inpt.solarize(threshold=threshold) + elif isinstance(inpt, PIL.Image.Image): + return solarize_image_pil(inpt, threshold=threshold) + elif isinstance(inpt, torch.Tensor): + return solarize_image_tensor(inpt, threshold=threshold) + else: + return inpt + + autocontrast_image_tensor = _FT.autocontrast autocontrast_image_pil = _FP.autocontrast + +def autocontrast(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.autocontrast() + elif isinstance(inpt, PIL.Image.Image): + return autocontrast_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return autocontrast_image_tensor(inpt) + else: + return inpt + + equalize_image_tensor = _FT.equalize equalize_image_pil = _FP.equalize + +def equalize(inpt: Any) -> Any: + if isinstance(inpt, features._Feature): + return inpt.equalize() + elif isinstance(inpt, PIL.Image.Image): + return equalize_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return equalize_image_tensor(inpt) + else: + return inpt + + invert_image_tensor = _FT.invert invert_image_pil = _FP.invert -adjust_hue_image_tensor = _FT.adjust_hue -adjust_hue_image_pil = _FP.adjust_hue -adjust_gamma_image_tensor = _FT.adjust_gamma -adjust_gamma_image_pil = _FP.adjust_gamma +def invert(inpt: Any) -> Any: + if isinstance(inpt, 
features._Feature): + return inpt.invert() + elif isinstance(inpt, PIL.Image.Image): + return invert_image_pil(inpt) + elif isinstance(inpt, torch.Tensor): + return invert_image_tensor(inpt) + else: + return inpt diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index f0ceec6d96e..9cad4106bc2 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -347,7 +347,8 @@ def affine_segmentation_mask( def affine( inpt: Any, - angle: float, *, + angle: float, + *, translate: List[float], scale: float, shear: List[float], @@ -445,7 +446,8 @@ def rotate_segmentation_mask( ) -def rotate(inpt: Any, +def rotate( + inpt: Any, angle: float, interpolation: InterpolationMode = InterpolationMode.NEAREST, expand: bool = False, From 7a8f9501e1356a8a83c1c83ed83941e595da0e3d Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 28 Jun 2022 12:27:05 +0000 Subject: [PATCH 17/20] Added more f-mid-level ops --- .../prototype/features/_bounding_box.py | 17 +- torchvision/prototype/features/_feature.py | 15 +- torchvision/prototype/features/_image.py | 8 + .../prototype/features/_segmentation_mask.py | 15 +- torchvision/prototype/transforms/_augment.py | 2 +- torchvision/prototype/transforms/_color.py | 151 +++++++----------- torchvision/prototype/transforms/_geometry.py | 84 ++-------- .../transforms/functional/__init__.py | 6 + .../transforms/functional/_geometry.py | 98 ++++++++++++ 9 files changed, 223 insertions(+), 173 deletions(-) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index bd6f04995cc..f18d7efae6c 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -83,6 +83,10 @@ def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox: output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) return BoundingBox.new_like(self, output, image_size=size) + def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: + output = self._F.crop_bounding_box(self, self.format, top, left) + return BoundingBox.new_like(self, output, image_size=(height, width)) + def center_crop(self, output_size) -> BoundingBox: output = self._F.center_crop_bounding_box( self, format=self.format, output_size=output_size, image_size=self.image_size @@ -100,7 +104,7 @@ def pad(self, padding, *, fill, padding_mode) -> BoundingBox: if padding_mode not in ["constant"]: raise ValueError(f"Padding mode '{padding_mode}' is not supported with bounding boxes") - output = self._F.pad_bounding_box(self, padding, fill=fill, padding_mode=padding_mode) + output = self._F.pad_bounding_box(self, padding, format=self.format) # Update output image size: left, top, right, bottom = padding @@ -132,11 +136,16 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) ) return BoundingBox.new_like(self, output) + def perspective(self, perspective_coeffs, *, interpolation, fill) -> BoundingBox: + interpolation, fill # unused + output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) + return BoundingBox.new_like(self, output) + def erase(self, *args) -> BoundingBox: - raise TypeError(f"Erase transformation does not support bounding boxes") + raise TypeError("Erase transformation does not support bounding boxes") def mixup(self, *args) -> BoundingBox: - raise 
TypeError(f"Mixup transformation does not support bounding boxes") + raise TypeError("Mixup transformation does not support bounding boxes") def cutmix(self, *args) -> BoundingBox: - raise TypeError(f"Cutmix transformation does not support bounding boxes") + raise TypeError("Cutmix transformation does not support bounding boxes") diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 1cebbff83cb..0d4e9321977 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -106,6 +106,11 @@ def resize(self, size, *, interpolation, max_size, antialias): # How dangerous to do this instead of raising an error ? return self + def crop(self, top: int, left: int, height: int, width: int): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + def center_crop(self, output_size): # Just output itself # How dangerous to do this instead of raising an error ? @@ -131,6 +136,11 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) # How dangerous to do this instead of raising an error ? return self + def perspective(self, perspective_coeffs, *, interpolation, fill): + # Just output itself + # How dangerous to do this instead of raising an error ? + return self + def adjust_brightness(self, brightness_factor: float): # Just output itself # How dangerous to do this instead of raising an error ? @@ -181,11 +191,6 @@ def equalize(self): # How dangerous to do this instead of raising an error ? return self - def equalize(self): - # Just output itself - # How dangerous to do this instead of raising an error ? - return self - def erase(self, i, j, h, w, v): # Just output itself # How dangerous to do this instead of raising an error ? 
diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 668882ec639..407e8fc07e9 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -124,6 +124,10 @@ def resize(self, size, *, interpolation, max_size, antialias) -> Image: ) return Image.new_like(self, output) + def crop(self, top: int, left: int, height: int, width: int) -> Image: + output = self._F.crop_image_tensor(self, top, left, height, width) + return Image.new_like(self, output) + def center_crop(self, output_size) -> Image: output = self._F.center_crop_image_tensor(self, output_size=output_size) return Image.new_like(self, output) @@ -162,6 +166,10 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) ) return Image.new_like(self, output) + def perspective(self, perspective_coeffs, *, interpolation, fill) -> Image: + output = self._F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) + return Image.new_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Image: output = self._F.adjust_brightness_image_tensor(self, brightness_factor=brightness_factor) return Image.new_like(self, output) diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index 41fba31d505..9a90b6ccc53 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -17,6 +17,10 @@ def resize(self, size, *, interpolation, max_size, antialias) -> SegmentationMas output = self._F.resize_segmentation_mask(self, size, max_size=max_size) return SegmentationMask.new_like(self, output) + def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask: + output = self._F.center_crop_segmentation_mask(self, top, left, height, width) + return SegmentationMask.new_like(self, output) + def center_crop(self, output_size) -> SegmentationMask: output = self._F.center_crop_segmentation_mask(self, output_size=output_size) return SegmentationMask.new_like(self, output) @@ -49,11 +53,16 @@ def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) ) return SegmentationMask.new_like(self, output) + def perspective(self, perspective_coeffs, *, interpolation, fill) -> SegmentationMask: + interpolation, fill # unused + output = self._F.perspective_segmentation_mask(self, perspective_coeffs) + return SegmentationMask.new_like(self, output) + def erase(self, *args) -> SegmentationMask: - raise TypeError(f"Erase transformation does not support segmentation masks") + raise TypeError("Erase transformation does not support segmentation masks") def mixup(self, *args) -> SegmentationMask: - raise TypeError(f"Mixup transformation does not support segmentation masks") + raise TypeError("Mixup transformation does not support segmentation masks") def cutmix(self, *args) -> SegmentationMask: - raise TypeError(f"Cutmix transformation does not support segmentation masks") + raise TypeError("Cutmix transformation does not support segmentation masks") diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index a32ba5c9f26..4ad9c7302b7 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import Transform, functional as F from ._transform import _RandomApplyTransform -from ._utils import 
query_image, get_image_dimensions, has_all, has_any, is_simple_tensor +from ._utils import query_image, get_image_dimensions, has_all class RandomErasing(_RandomApplyTransform): diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 960020baff8..e9587e053ab 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,5 +1,4 @@ import collections.abc -import functools from typing import Any, Dict, Union, Tuple, Optional, Sequence, Callable, TypeVar import PIL.Image @@ -23,12 +22,12 @@ def __init__( hue: Optional[Union[float, Sequence[float]]] = None, ) -> None: super().__init__() - self.brightness = self._check_input(brightness, "brightness") - self.contrast = self._check_input(contrast, "contrast") - self.saturation = self._check_input(saturation, "saturation") - self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.brightness = self._check_inpt(brightness, "brightness") + self.contrast = self._check_inpt(contrast, "contrast") + self.saturation = self._check_inpt(saturation, "saturation") + self.hue = self._check_inpt(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - def _check_input( + def _check_inpt( self, value: Optional[Union[float, Sequence[float]]], name: str, @@ -55,74 +54,52 @@ def _check_input( def _image_transform( self, - input: T, + inpt: T, *, kernel_tensor: Callable[..., torch.Tensor], kernel_pil: Callable[..., PIL.Image.Image], **kwargs: Any, ) -> T: - if isinstance(input, features.Image): - output = kernel_tensor(input, **kwargs) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return kernel_tensor(input, **kwargs) - elif isinstance(input, PIL.Image.Image): - return kernel_pil(input, **kwargs) # type: ignore[no-any-return] + if isinstance(inpt, features.Image): + output = kernel_tensor(inpt, **kwargs) + return features.Image.new_like(inpt, output) + elif is_simple_tensor(inpt): + return kernel_tensor(inpt, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return kernel_pil(inpt, **kwargs) # type: ignore[no-any-return] else: raise RuntimeError + @staticmethod + def _generate_value(left: float, right: float) -> float: + return float(torch.distributions.Uniform(left, right).sample()) + def _get_params(self, sample: Any) -> Dict[str, Any]: - image_transforms = [] - if self.brightness is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_brightness_image_tensor, - kernel_pil=F.adjust_brightness_image_pil, - brightness_factor=float( - torch.distributions.Uniform(self.brightness[0], self.brightness[1]).sample() - ), - ) - ) - if self.contrast is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_contrast_image_tensor, - kernel_pil=F.adjust_contrast_image_pil, - contrast_factor=float(torch.distributions.Uniform(self.contrast[0], self.contrast[1]).sample()), - ) - ) - if self.saturation is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_saturation_image_tensor, - kernel_pil=F.adjust_saturation_image_pil, - saturation_factor=float( - torch.distributions.Uniform(self.saturation[0], self.saturation[1]).sample() - ), - ) - ) - if self.hue is not None: - image_transforms.append( - functools.partial( - self._image_transform, - kernel_tensor=F.adjust_hue_image_tensor, - kernel_pil=F.adjust_hue_image_pil, - 
hue_factor=float(torch.distributions.Uniform(self.hue[0], self.hue[1]).sample()), - ) - ) - - return dict(image_transforms=[image_transforms[idx] for idx in torch.randperm(len(image_transforms))]) - - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)): - return input - - for transform in params["image_transforms"]: - input = transform(input) - return input + fn_idx = torch.randperm(4) + + b = None if self.brightness is None else self._generate_value(self.brightness[0], self.brightness[1]) + c = None if self.contrast is None else self._generate_value(self.contrast[0], self.contrast[1]) + s = None if self.saturation is None else self._generate_value(self.saturation[0], self.saturation[1]) + h = None if self.hue is None else self._generate_value(self.hue[0], self.hue[1]) + + return dict(fn_idx=fn_idx, brightness_factor=b, contrast_factor=c, saturation_factor=s, hue_factor=h) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + output = inpt + brightness_factor = params["brightness_factor"] + contrast_factor = params["contrast_factor"] + saturation_factor = params["saturation_factor"] + hue_factor = params["hue_factor"] + for fn_id in params["fn_idx"]: + if fn_id == 0 and brightness_factor is not None: + output = F.adjust_brightness(output, brightness_factor=brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + output = F.adjust_contrast(output, contrast_factor=contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + output = F.adjust_saturation(output, saturation_factor=saturation_factor) + elif fn_id == 3 and hue_factor is not None: + output = F.adjust_hue(output, hue_factor=hue_factor) + return output class _RandomChannelShuffle(Transform): @@ -131,19 +108,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: num_channels, _, _ = get_image_dimensions(image) return dict(permutation=torch.randperm(num_channels)) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if not (isinstance(input, (features.Image, PIL.Image.Image)) or is_simple_tensor(input)): - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)): + return inpt - image = input - if isinstance(input, PIL.Image.Image): + image = inpt + if isinstance(inpt, PIL.Image.Image): image = _F.pil_to_tensor(image) output = image[..., params["permutation"], :, :] - if isinstance(input, features.Image): - output = features.Image.new_like(input, output, color_space=features.ColorSpace.OTHER) - elif isinstance(input, PIL.Image.Image): + if isinstance(inpt, features.Image): + output = features.Image.new_like(inpt, output, color_space=features.ColorSpace.OTHER) + elif isinstance(inpt, PIL.Image.Image): output = _F.to_pil_image(output) return output @@ -175,33 +152,25 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: contrast_before=torch.rand(()) < 0.5, ) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: if params["brightness"]: - input = self._brightness(input) + inpt = self._brightness(inpt) if params["contrast1"] and params["contrast_before"]: - input = self._contrast(input) + inpt = self._contrast(inpt) if params["saturation"]: - input = self._saturation(input) + inpt = self._saturation(inpt) if params["saturation"]: - input = self._saturation(input) + inpt = 
self._saturation(inpt) if params["contrast2"] and not params["contrast_before"]: - input = self._contrast(input) + inpt = self._contrast(inpt) if params["channel_shuffle"]: - input = self._channel_shuffle(input) - return input + inpt = self._channel_shuffle(inpt) + return inpt class RandomEqualize(_RandomApplyTransform): def __init__(self, p: float = 0.5): super().__init__(p=p) - def _transform(self, input: Any, params: Dict[str, Any]) -> Any: - if isinstance(input, features.Image): - output = F.equalize_image_tensor(input) - return features.Image.new_like(input, output) - elif is_simple_tensor(input): - return F.equalize_image_tensor(input) - elif isinstance(input, PIL.Image.Image): - return F.equalize_image_pil(input) - else: - return input + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.equalize(inpt) diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 288e88b3aed..e112ce744b1 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -18,7 +18,7 @@ from typing_extensions import Literal from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor +from ._utils import query_image, get_image_dimensions, has_any class RandomHorizontalFlip(_RandomApplyTransform): @@ -46,30 +46,13 @@ def __init__( self.antialias = antialias def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.resize( - self.size, - interpolation=self.interpolation, - max_size=self.max_size, - antialias=self.antialias, - ) - elif isinstance(inpt, PIL.Image.Image): - return F.resize_image_pil( - inpt, - self.size, - interpolation=self.interpolation, - max_size=self.max_size, - ) - elif isinstance(inpt, torch.Tensor): - return F.resize_image_tensor( - inpt, - self.size, - interpolation=self.interpolation, - max_size=self.max_size, - antialias=self.antialias, - ) - else: - return inpt + return F.resize( + inpt, + self.size, + interpolation=self.interpolation, + max_size=self.max_size, + antialias=self.antialias, + ) class CenterCrop(Transform): @@ -78,14 +61,7 @@ def __init__(self, output_size: List[int]): self.output_size = output_size def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.center_crop(self.output_size) - elif isinstance(inpt, PIL.Image.Image): - return F.center_crop_image_pil(inpt, self.output_size) - elif isinstance(inpt, torch.Tensor): - return F.center_crop_image_tensor(inpt, self.output_size) - else: - return inpt + return F.center_crop(inpt, output_size=self.output_size) class RandomResizedCrop(Transform): @@ -158,18 +134,9 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(top=i, left=j, height=h, width=w) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - antialias = False if self.antialias is None else self.antialias - return inpt.resized_crop(**params, size=self.size, interpolation=self.interpolation, antialias=antialias) - elif isinstance(inpt, PIL.Image.Image): - return F.resized_crop_image_pil(inpt, **params, size=list(self.size), interpolation=self.interpolation) - elif isinstance(inpt, torch.Tensor): - antialias = False if self.antialias is None else self.antialias - return F.resized_crop_image_tensor( - inpt, **params, size=list(self.size), interpolation=self.interpolation, 
antialias=antialias - ) - else: - return inpt + return F.resized_crop( + inpt, **params, size=self.size, interpolation=self.interpolation, antialias=self.antialias + ) class MultiCropResult(list): @@ -286,28 +253,7 @@ def __init__( self.padding_mode = padding_mode def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features._Feature): - return inpt.pad( - self.padding, - fill=self.fill, - padding_mode=self.padding_mode, - ) - elif isinstance(inpt, PIL.Image.Image): - return F.pad_image_pil( - inpt, - self.padding, - fill=self.fill, - padding_mode=self.padding_mode, - ) - elif isinstance(inpt, torch.Tensor): - return F.pad_image_tensor( - inpt, - self.padding, - fill=self.fill, - padding_mode=self.padding_mode, - ) - else: - return inpt + return F.pad(inpt, padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) class RandomZoomOut(_RandomApplyTransform): @@ -361,7 +307,7 @@ def __init__( expand=False, fill=0, center=None, - ): + ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) self.interpolation = interpolation @@ -404,7 +350,7 @@ def __init__( interpolation=InterpolationMode.NEAREST, fill=0, center=None, - ): + ) -> None: super().__init__() self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 4ff938ec684..a8c17577a56 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -49,14 +49,17 @@ horizontal_flip_image_tensor, horizontal_flip_image_pil, horizontal_flip_segmentation_mask, + resize, resize_bounding_box, resize_image_tensor, resize_image_pil, resize_segmentation_mask, + center_crop, center_crop_bounding_box, center_crop_segmentation_mask, center_crop_image_tensor, center_crop_image_pil, + resized_crop, resized_crop_bounding_box, resized_crop_image_tensor, resized_crop_image_pil, @@ -71,14 +74,17 @@ rotate_image_tensor, rotate_image_pil, rotate_segmentation_mask, + pad, pad_bounding_box, pad_image_tensor, pad_image_pil, pad_segmentation_mask, + crop, crop_bounding_box, crop_image_tensor, crop_image_pil, crop_segmentation_mask, + perspective, perspective_bounding_box, perspective_image_tensor, perspective_image_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 9cad4106bc2..beafcb42136 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -133,6 +133,27 @@ def resize_bounding_box( return bounding_box.view(-1, 2, 2).mul(ratios).view(bounding_box.shape) +def resize( + inpt: Any, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + max_size: Optional[int] = None, + antialias: Optional[bool] = None, +) -> Any: + if isinstance(inpt, features._Feature): + antialias = False if antialias is None else antialias + return inpt.resize(size, interpolation=interpolation, max_size=max_size, antialias=antialias) + elif isinstance(inpt, PIL.Image.Image): + if antialias is not None and not antialias: + warnings.warn("Anti-alias option is always applied for PIL Image input. 
Argument antialias is ignored.") + return resize_image_pil(inpt, size, interpolation=interpolation, max_size=max_size) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resize_image_tensor(inpt, size, interpolation=interpolation, max_size=max_size, antialias=antialias) + else: + return inpt + + def _affine_parse_args( angle: float, translate: List[float], @@ -504,6 +525,21 @@ def pad_bounding_box( return bounding_box +def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Any: + kwargs = dict( + fill=fill, + padding_mode=padding_mode, + ) + if isinstance(inpt, features._Feature): + return inpt.pad(padding, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return rotate_image_pil(inpt, padding, **kwargs) + elif isinstance(inpt, torch.Tensor): + return rotate_image_tensor(inpt, padding, **kwargs) + else: + return inpt + + crop_image_tensor = _FT.crop crop_image_pil = _FP.crop @@ -531,6 +567,17 @@ def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, return crop_image_tensor(img, top, left, height, width) +def crop(inpt: Any, top: int, left: int, height: int, width: int) -> Any: + if isinstance(inpt, features._Feature): + return inpt.crop(top, left, height, width) + elif isinstance(inpt, PIL.Image.Image): + return crop_image_pil(inpt, top, left, height, width) + elif isinstance(inpt, torch.Tensor): + return crop_image_tensor(inpt, top, left, height, width) + else: + return inpt + + def perspective_image_tensor( img: torch.Tensor, perspective_coeffs: List[float], @@ -638,6 +685,23 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective( + inpt: Any, + perspective_coeffs: List[float], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> Any: + kwargs = dict(interpolation=interpolation, fill=fill) + if isinstance(inpt, features._Feature): + return inpt.perspective(perspective_coeffs, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return perspective_image_pil(inpt, perspective_coeffs, **kwargs) + elif isinstance(inpt, torch.Tensor): + return perspective_image_tensor(inpt, perspective_coeffs, **kwargs) + else: + return inpt + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] @@ -711,6 +775,17 @@ def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: return center_crop_image_tensor(img=segmentation_mask, output_size=output_size) +def center_crop(inpt: Any, output_size: List[int]) -> Any: + if isinstance(inpt, features._Feature): + return inpt.center_crop(output_size) + elif isinstance(inpt, PIL.Image.Image): + return center_crop_image_pil(inpt, output_size) + elif isinstance(inpt, torch.Tensor): + return center_crop_image_tensor(inpt, output_size) + else: + return inpt + + def resized_crop_image_tensor( img: torch.Tensor, top: int, @@ -763,6 +838,29 @@ def resized_crop_segmentation_mask( return resize_segmentation_mask(mask, size) +def resized_crop( + inpt: Any, + top: int, + left: int, + height: int, + width: int, + size: List[int], + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + antialias: Optional[bool] = None, +) -> Any: + kwargs = dict(size=size, interpolation=interpolation) + if isinstance(inpt, 
features._Feature): + antialias = False if antialias is None else antialias + return inpt.resized_crop(top, left, height, width, antialias=antialias, **kwargs) + elif isinstance(inpt, PIL.Image.Image): + return resized_crop_image_pil(inpt, top, left, height, width, **kwargs) + elif isinstance(inpt, torch.Tensor): + antialias = False if antialias is None else antialias + return resized_crop_image_tensor(inpt, top, left, height, width, antialias=antialias, **kwargs) + else: + return inpt + + def _parse_five_crop_size(size: List[int]) -> List[int]: if isinstance(size, numbers.Number): size = [int(size), int(size)] From 7917a17eda535fe3904bdc3c3a31af552423a4c1 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 28 Jun 2022 12:32:27 +0000 Subject: [PATCH 18/20] _check_inpt -> _check_input --- torchvision/prototype/transforms/_color.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index e9587e053ab..60fe46ed9ea 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -22,12 +22,12 @@ def __init__( hue: Optional[Union[float, Sequence[float]]] = None, ) -> None: super().__init__() - self.brightness = self._check_inpt(brightness, "brightness") - self.contrast = self._check_inpt(contrast, "contrast") - self.saturation = self._check_inpt(saturation, "saturation") - self.hue = self._check_inpt(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) + self.brightness = self._check_input(brightness, "brightness") + self.contrast = self._check_input(contrast, "contrast") + self.saturation = self._check_input(saturation, "saturation") + self.hue = self._check_input(hue, "hue", center=0, bound=(-0.5, 0.5), clip_first_on_zero=False) - def _check_inpt( + def _check_input( self, value: Optional[Union[float, Sequence[float]]], name: str, From 99bfad9eabd61c7d6ff1130255e3f32019671aea Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 28 Jun 2022 15:10:10 +0000 Subject: [PATCH 19/20] Fixed broken code, added a test for mid-level ops --- test/test_prototype_transforms_functional.py | 32 ++++++++++++++++--- .../prototype/features/_bounding_box.py | 18 +++++++---- .../transforms/functional/_geometry.py | 1 - 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 30d9b833ec8..9a26f9a225b 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -489,16 +489,40 @@ def center_crop_segmentation_mask(): and callable(kernel) and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"}) and "pil" not in name - and name - not in { - "to_image_tensor", - } + and name not in {"to_image_tensor"} ], ) def test_scriptable(kernel): jit.script(kernel) +@pytest.mark.parametrize( + "func", + [ + pytest.param(func, id=name) + for name, func in F.__dict__.items() + if not name.startswith("_") + and callable(func) + and all( + feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} + ) + and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av"} + ], +) +def test_functional_mid_level(func): + finfos = [finfo for finfo in FUNCTIONAL_INFOS if f"{func.__name__}_" in finfo.name] + for finfo in finfos: + for sample_input in finfo.sample_inputs(): + expected = finfo(sample_input) + kwargs = 
dict(sample_input.kwargs) + for key in ["format", "image_size"]: + if key in kwargs: + del kwargs[key] + output = func(*sample_input.args, **kwargs) + torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + break + + @pytest.mark.parametrize( ("functional_info", "sample_input"), [ diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index f18d7efae6c..4e119e7bb25 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -81,7 +81,7 @@ def vertical_flip(self) -> BoundingBox: def resize(self, size, *, interpolation, max_size, antialias) -> BoundingBox: interpolation, antialias # unused output = self._F.resize_bounding_box(self, size, image_size=self.image_size, max_size=max_size) - return BoundingBox.new_like(self, output, image_size=size) + return BoundingBox.new_like(self, output, image_size=size, dtype=output.dtype) def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox: output = self._F.crop_bounding_box(self, self.format, top, left) @@ -94,10 +94,9 @@ def center_crop(self, output_size) -> BoundingBox: return BoundingBox.new_like(self, output, image_size=output_size) def resized_crop(self, top, left, height, width, *, size, interpolation, antialias) -> BoundingBox: - # TODO: untested right now interpolation, antialias # unused output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size) - return BoundingBox.new_like(self, output, image_size=size) + return BoundingBox.new_like(self, output, image_size=size, dtype=output.dtype) def pad(self, padding, *, fill, padding_mode) -> BoundingBox: fill # unused @@ -107,7 +106,10 @@ def pad(self, padding, *, fill, padding_mode) -> BoundingBox: output = self._F.pad_bounding_box(self, padding, format=self.format) # Update output image size: - left, top, right, bottom = padding + # TODO: remove the import below and make _parse_pad_padding available + from torchvision.transforms.functional_tensor import _parse_pad_padding + + left, top, right, bottom = _parse_pad_padding(padding) height, width = self.image_size height += top + bottom width += left + right @@ -122,24 +124,26 @@ def rotate(self, angle, *, interpolation, expand, fill, center) -> BoundingBox: # TODO: update output image size if expand is True if expand: raise RuntimeError("Not yet implemented") - return BoundingBox.new_like(self, output) + return BoundingBox.new_like(self, output, dtype=output.dtype) def affine(self, angle, *, translate, scale, shear, interpolation, fill, center) -> BoundingBox: interpolation, fill # unused output = self._F.affine_bounding_box( self, + self.format, + self.image_size, angle, translate=translate, scale=scale, shear=shear, center=center, ) - return BoundingBox.new_like(self, output) + return BoundingBox.new_like(self, output, dtype=output.dtype) def perspective(self, perspective_coeffs, *, interpolation, fill) -> BoundingBox: interpolation, fill # unused output = self._F.perspective_bounding_box(self, self.format, perspective_coeffs) - return BoundingBox.new_like(self, output) + return BoundingBox.new_like(self, output, dtype=output.dtype) def erase(self, *args) -> BoundingBox: raise TypeError("Erase transformation does not support bounding boxes") diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index beafcb42136..9d4fe748bd0 100644 --- 
a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -369,7 +369,6 @@ def affine_segmentation_mask( def affine( inpt: Any, angle: float, - *, translate: List[float], scale: float, shear: List[float], From ae5eef949e5f9eaa72fdbb7c7f6fcac6c84f1aaa Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 29 Jun 2022 12:26:42 +0000 Subject: [PATCH 20/20] Fixed bugs and started porting transforms --- torchvision/prototype/features/_image.py | 5 -- torchvision/prototype/transforms/__init__.py | 1 + torchvision/prototype/transforms/_geometry.py | 89 +++++++++++++++++-- .../transforms/functional/_geometry.py | 4 +- 4 files changed, 87 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 407e8fc07e9..f7bb24fb427 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -139,11 +139,6 @@ def resized_crop(self, top, left, height, width, *, size, interpolation, antiali return Image.new_like(self, output) def pad(self, padding, *, fill, padding_mode) -> Image: - # Previous message from previous implementation: - # PyTorch's pad supports only integers on fill. So we need to overwrite the colour - # vfdev-5: pytorch pad support both int and floats but keeps original dtyp - # if user pads int image with float pad, they need to cast the image first to float - # before padding. Let's remove previous manual float fill support. output = self._F.pad_image_tensor(self, padding, fill=fill, padding_mode=padding_mode) return Image.new_like(self, output) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index 5edd18890a8..d5777560089 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -8,6 +8,7 @@ from ._container import Compose, RandomApply, RandomChoice, RandomOrder from ._geometry import ( Resize, + RandomCrop, CenterCrop, RandomResizedCrop, FiveCrop, diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index e112ce744b1..e7e07701cd4 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -64,6 +64,84 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.center_crop(inpt, output_size=self.output_size) +class RandomCrop(Transform): + def __init__( + self, + size: Union[int, Sequence[int]], + padding: Optional[Union[int, Sequence[int]]] = None, + pad_if_needed: bool = False, + fill: Union[float, Sequence[float]] = 0.0, + padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant", + ): + super().__init__() + self.size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.padding = padding + self.pad_if_needed = pad_if_needed + self.fill = fill + self.padding_mode = padding_mode + + self._pad_op = Pad(padding=self.padding, fill=self.fill, padding_mode=self.padding_mode) + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples + # What if we have multiple images/bboxes/masks of different sizes ? 
+ # TODO: let's support bbox or mask in samples without image + image = query_image(sample) + _, height, width = get_image_dimensions(image) + out_height, out_width = self.size + + if height + 1 < out_height or width + 1 < out_width: + raise ValueError( + f"Required crop size {(out_height, out_width)} is larger then input image size {(height, width)}" + ) + + if height == out_height and width == out_width: + return dict(top=0, left=0, height=height, width=width) + + i = torch.randint(0, height - out_height + 1, size=(1,)).item() + j = torch.randint(0, width - out_width + 1, size=(1,)).item() + return dict(top=i, left=j, height=out_height, width=out_width) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.crop(inpt, **params) + + def forward(self, *inputs: Any) -> Any: + # TODO: main difficulties implementing this op: + # 1) unstructured inputs and why we need to call: sample = inputs if len(inputs) > 1 else inputs[0] ? + # 2) how to call F.op efficiently on inputs ? + # + # We can make inputs flatten using from torch.utils._pytree import tree_flatten, tree_unflatten + # Such that inputs -> flat_inputs = [obj1, obj2, obj3, ...] + + raise RuntimeError("Not yet implemented") + + params = self._get_params(inputs) + + # sample = inputs if len(inputs) > 1 else inputs[0] + # return apply_recursively(functools.partial(self._transform, params=self._get_params(sample)), sample) + + if self.padding is not None: + self._pad_op.padding = self.padding + inputs = self._pad_op(*inputs) + + # vfdev-5: techically, this op can work on bboxes/segm masks only inputs without image in samples + # What if we have multiple images/bboxes/masks of different sizes ? + # TODO: let's support bbox or mask in samples without image + image = query_image(inputs) + _, height, width = get_image_dimensions(image) + + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding=padding, fill=self.fill, padding_mode=self.padding_mode) + + return ... 
+ + class RandomResizedCrop(Transform): def __init__( self, @@ -270,7 +348,7 @@ def __init__( if side_range[0] < 1.0 or side_range[0] > side_range[1]: raise ValueError(f"Invalid canvas side range provided {side_range}.") - self.pad_op = Pad(0, padding_mode="constant") + self._pad_op = Pad(0, padding_mode="constant") def _get_params(self, sample: Any) -> Dict[str, Any]: image = query_image(sample) @@ -293,10 +371,11 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(padding=padding, fill=fill) - def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - self.pad_op.padding = params["padding"] - self.pad_op.fill = params["fill"] - return self.pad_op(inpt) + def forward(self, *inputs: Any) -> Any: + params = self._get_params(inputs) + self._pad_op.padding = params["padding"] + self._pad_op.fill = params["fill"] + return self._pad_op(*inputs) class RandomRotation(Transform): diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 9d4fe748bd0..b6704b96328 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -532,9 +532,9 @@ def pad(inpt: Any, padding: List[int], fill: int = 0, padding_mode: str = "const if isinstance(inpt, features._Feature): return inpt.pad(padding, **kwargs) elif isinstance(inpt, PIL.Image.Image): - return rotate_image_pil(inpt, padding, **kwargs) + return pad_image_pil(inpt, padding, **kwargs) elif isinstance(inpt, torch.Tensor): - return rotate_image_tensor(inpt, padding, **kwargs) + return pad_image_tensor(inpt, padding, **kwargs) else: return inpt