From 4c0a57440024ed0de233a59aeeb281f1e567e2b2 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Wed, 27 Jul 2022 00:13:14 +0200 Subject: [PATCH] Added erase_image_pil and eager/jit erase_image_tensor test --- test/test_prototype_transforms_functional.py | 7 +++++++ torchvision/prototype/transforms/_augment.py | 5 +---- .../prototype/transforms/functional/__init__.py | 2 +- .../prototype/transforms/functional/_augment.py | 17 ++++++++++------- 4 files changed, 19 insertions(+), 12 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index fb5f10459fe..a5bf45f5676 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -654,6 +654,13 @@ def adjust_sharpness_image_tensor(): yield SampleInput(image, sharpness_factor=sharpness_factor) +@register_kernel_info_from_sample_inputs_fn +def erase_image_tensor(): + for image in make_images(): + c = image.shape[-3] + yield SampleInput(image, i=1, j=2, h=6, w=7, v=torch.rand(c, 6, 7)) + + @pytest.mark.parametrize( "kernel", [ diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 2c71a5faf64..12e2cd3cc6d 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -7,7 +7,6 @@ import torch from torchvision.prototype import features from torchvision.prototype.transforms import functional as F, Transform -from torchvision.transforms.functional import pil_to_tensor, to_pil_image from ._transform import _RandomApplyTransform from ._utils import get_image_dimensions, has_all, has_any, query_image @@ -93,9 +92,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return features.Image.new_like(inpt, output) return output elif isinstance(inpt, PIL.Image.Image): - t_img = pil_to_tensor(inpt) - output = F.erase_image_tensor(t_img, **params) - return to_pil_image(output, mode=inpt.mode) + return F.erase_image_pil(inpt, **params) else: return inpt diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 1aef37600d6..82e3096821a 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -5,7 +5,7 @@ convert_image_color_space_pil, ) # usort: skip -from ._augment import erase_image_tensor +from ._augment import erase_image_pil, erase_image_tensor from ._color import ( adjust_brightness, adjust_brightness_image_pil, diff --git a/torchvision/prototype/transforms/functional/_augment.py b/torchvision/prototype/transforms/functional/_augment.py index 3920d1b3065..84b069cf396 100644 --- a/torchvision/prototype/transforms/functional/_augment.py +++ b/torchvision/prototype/transforms/functional/_augment.py @@ -1,13 +1,16 @@ +import PIL.Image + +import torch from torchvision.transforms import functional_tensor as _FT +from torchvision.transforms.functional import pil_to_tensor, to_pil_image erase_image_tensor = _FT.erase -# TODO: Don't forget to clean up from the primitives kernels those that shouldn't be kernels. -# Like the mixup and cutmix stuff - -# This function is copy-pasted to Image and OneHotLabel and may be refactored -# def _mixup_tensor(input: torch.Tensor, batch_dim: int, lam: float) -> torch.Tensor: -# input = input.clone() -# return input.roll(1, batch_dim).mul_(1 - lam).add_(input.mul_(lam)) +def erase_image_pil( + img: PIL.Image.Image, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False +) -> PIL.Image.Image: + t_img = pil_to_tensor(img) + output = erase_image_tensor(t_img, i=i, j=j, h=h, w=w, v=v, inplace=inplace) + return to_pil_image(output, mode=img.mode)