From 1b22ddc642334b54a08c2a111ebc6d111713b46b Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 3 Nov 2020 00:17:53 +0000 Subject: [PATCH 1/8] Deprecated arguments: resample and fillcolor Replaced by interpolation and fill --- test/test_functional_tensor.py | 26 +++++------ test/test_transforms.py | 2 +- test/test_transforms_tensor.py | 4 +- torchvision/transforms/functional.py | 33 +++++++++---- torchvision/transforms/functional_pil.py | 16 +++---- torchvision/transforms/functional_tensor.py | 34 +++++++------- torchvision/transforms/transforms.py | 51 +++++++++++++-------- 7 files changed, 97 insertions(+), 69 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index e91e9321107..7d105c0c362 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -438,12 +438,12 @@ def test_resized_crop(self): def _test_affine_identity_map(self, tensor, scripted_affine): # 1) identity map - out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) - out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) @@ -461,13 +461,13 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine): ] for a, true_tensor in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 ) if true_tensor is not None: self.assertTrue( @@ -496,13 +496,13 @@ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine): for a in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 ).cpu() if out_tensor.dtype != torch.uint8: @@ -526,10 +526,10 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine): ] for t in test_configs: - out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0) + out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0) if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -552,11 +552,11 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): ] for r in [0, ]: for a, t, s, sh in test_configs: - out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r) + out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r).cpu() + out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, interpolation=r).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -613,10 +613,10 @@ def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): for e in [True, False]: for c in centers: - out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c) + out_pil_img = F.rotate(pil_img, angle=a, interpolation=r, expand=e, center=c) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.rotate, scripted_rotate]: - out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c).cpu() + out_tensor = fn(tensor, angle=a, interpolation=r, expand=e, center=c).cpu() if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -673,7 +673,7 @@ def test_rotate(self): center = (20, 22) self._test_fn_on_batch( - batch_tensors, F.rotate, angle=32, resample=0, expand=True, center=center + batch_tensors, F.rotate, angle=32, interpolation=0, expand=True, center=center ) def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): diff --git a/test/test_transforms.py b/test/test_transforms.py index f9add6d1b57..f7eedd802e1 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1529,7 +1529,7 @@ def test_random_affine(self): # Checking if RandomAffine can be printed as string t.__repr__() - t = transforms.RandomAffine(10, resample=Image.BILINEAR) + t = transforms.RandomAffine(10, interpolation=Image.BILINEAR) self.assertIn("Image.BILINEAR", t.__repr__()) def test_to_grayscale(self): diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 1fc0ab61ec4..e6070bd427c 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -349,7 +349,7 @@ def test_random_affine(self): for interpolation in [NEAREST, BILINEAR]: transform = T.RandomAffine( degrees=degrees, translate=translate, - scale=scale, shear=shear, resample=interpolation + scale=scale, shear=shear, interpolation=interpolation ) s_transform = torch.jit.script(transform) @@ -368,7 +368,7 @@ def test_random_rotate(self): for degrees in [45, 35.0, (-45, 45), [-90.0, 90.0]]: for interpolation in [NEAREST, BILINEAR]: transform = T.RandomRotation( - degrees=degrees, resample=interpolation, expand=expand, center=center + degrees=degrees, interpolation=interpolation, expand=expand, center=center ) s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 13135807091..813502f4db3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -793,8 +793,8 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, angle: float, resample: int = 0, expand: bool = False, - center: Optional[List[int]] = None, fill: Optional[int] = None + img: Tensor, angle: float, interpolation: int = 0, expand: bool = False, + center: Optional[List[int]] = None, fill: Optional[int] = None, resample: Optional[int] = None ) -> Tensor: """Rotate the image by angle. The image can be a PIL Image or a Tensor, in which case it is expected @@ -803,7 +803,7 @@ def rotate( Args: img (PIL Image or Tensor): image to be rotated. angle (float or int): rotation angle value in degrees, counter-clockwise. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. expand (bool, optional): Optional expansion flag. @@ -817,6 +817,7 @@ def rotate( Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Rotated image. @@ -824,6 +825,9 @@ def rotate( .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ + if resample is not None: + warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -831,7 +835,7 @@ def rotate( raise TypeError("Argument center should be a sequence") if not isinstance(img, torch.Tensor): - return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill) + return F_pil.rotate(img, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill) center_f = [0.0, 0.0] if center is not None: @@ -842,12 +846,13 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) - return F_t.rotate(img, matrix=matrix, resample=resample, expand=expand, fill=fill) + return F_t.rotate(img, matrix=matrix, interpolation=interpolation, expand=expand, fill=fill) def affine( img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], - resample: int = 0, fillcolor: Optional[int] = None + interpolation: int = 0, fill: Optional[int] = None, resample: Optional[int] = None, + fillcolor: Optional[int] = None ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. The image can be a PIL Image or a Tensor, in which case it is expected @@ -861,17 +866,25 @@ def affine( shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - fillcolor (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). + fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + fillcolor (tuple or int, optional): deprecated argument, please use `arg`:fill: instead. + resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Transformed image. """ + if resample is not None: + warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + + if fillcolor is not None: + warnings.warn("Argument fillcolor is deprecated. Please, use fill instead") + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -913,11 +926,11 @@ def affine( center = [img_size[0] * 0.5, img_size[1] * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) - return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) + return F_pil.affine(img, matrix=matrix, interpolation=interpolation, fill=fill) translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) - return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor) + return F_t.affine(img, matrix=matrix, interpolation=interpolation, fill=fill) @torch.jit.unused diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index d76bc7a0027..7e3989f0288 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -474,7 +474,7 @@ def _parse_fill(fill, img, min_pil_version, name="fillcolor"): @torch.jit.unused -def affine(img, matrix, resample=0, fillcolor=None): +def affine(img, matrix, interpolation=0, fill=None): """PRIVATE METHOD. Apply affine transformation on the PIL Image keeping image center invariant. .. warning:: @@ -485,11 +485,11 @@ def affine(img, matrix, resample=0, fillcolor=None): Args: img (PIL Image): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) + fill (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0) Returns: PIL Image: Transformed image. @@ -498,12 +498,12 @@ def affine(img, matrix, resample=0, fillcolor=None): raise TypeError('img should be PIL Image. Got {}'.format(type(img))) output_size = img.size - opts = _parse_fill(fillcolor, img, '5.0.0') - return img.transform(output_size, Image.AFFINE, matrix, resample, **opts) + opts = _parse_fill(fill, img, '5.0.0') + return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) @torch.jit.unused -def rotate(img, angle, resample=0, expand=False, center=None, fill=None): +def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None): """PRIVATE METHOD. Rotate PIL image by angle. .. warning:: @@ -514,7 +514,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): Args: img (PIL Image): image to be rotated. angle (float or int): rotation angle value in degrees, counter-clockwise. - resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): + interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): An optional resampling filter. See `filters`_ for more information. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. expand (bool, optional): Optional expansion flag. @@ -538,7 +538,7 @@ def rotate(img, angle, resample=0, expand=False, center=None, fill=None): raise TypeError("img should be PIL Image. Got {}".format(type(img))) opts = _parse_fill(fill, img, '5.2.0') - return img.rotate(angle, resample, expand, center, **opts) + return img.rotate(angle, interpolation, expand, center, **opts) @torch.jit.unused diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index d27e6066fe3..25e0a82cd7d 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -855,8 +855,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def _assert_grid_transform_inputs( img: Tensor, matrix: Optional[List[float]], - resample: int, - fillcolor: Optional[int], + interpolation: int, + fill: Optional[int], _interpolation_modes: Dict[int, str], coeffs: Optional[List[float]] = None, ): @@ -872,11 +872,11 @@ def _assert_grid_transform_inputs( if coeffs is not None and len(coeffs) != 8: raise ValueError("Argument coeffs should have 8 float values") - if fillcolor is not None: - warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero") + if fill is not None: + warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero") - if resample not in _interpolation_modes: - raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample)) + if interpolation not in _interpolation_modes: + raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation)) def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]: @@ -941,7 +941,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: int = 0, fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. @@ -953,9 +953,9 @@ def affine( Args: img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: + interpolation (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: bilinear(=2). - fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the + fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. Returns: @@ -966,14 +966,14 @@ def affine( 2: "bilinear", } - _assert_grid_transform_inputs(img, matrix, resample, fillcolor, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, _interpolation_modes) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) - mode = _interpolation_modes[resample] + mode = _interpolation_modes[interpolation] return _apply_grid_transform(img, grid, mode) @@ -1003,7 +1003,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( - img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: int = 0, expand: bool = False, fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Rotate the Tensor image by angle. @@ -1016,7 +1016,7 @@ def rotate( img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation. Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates. - resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: + interpolation (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: bilinear(=2). expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. @@ -1036,14 +1036,14 @@ def rotate( 2: "bilinear", } - _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, _interpolation_modes) w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) - mode = _interpolation_modes[resample] + mode = _interpolation_modes[interpolation] return _apply_grid_transform(img, grid, mode) @@ -1112,8 +1112,8 @@ def perspective( _assert_grid_transform_inputs( img, matrix=None, - resample=interpolation, - fillcolor=fill, + interpolation=interpolation, + fill=fill, _interpolation_modes=_interpolation_modes, coeffs=perspective_coeffs ) diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index be835bdd213..67d974409f5 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -242,7 +242,7 @@ class Resize(torch.nn.Module): (size * height / width, size). In torchscript mode padding as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. + interpolation (int): Interpolation type defined by `filters`_. Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` and ``PIL.Image.BICUBIC`` are supported. """ @@ -744,7 +744,7 @@ class RandomResizedCrop(torch.nn.Module): made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). scale (tuple of float): scale range of the cropped image before resizing, relatively to the origin image. ratio (tuple of float): aspect ratio range of the cropped image before resizing. - interpolation (int): Desired interpolation enum defined by `filters`_. + interpolation (int): Interpolation type defined by `filters`_. Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` and ``PIL.Image.BICUBIC`` are supported. """ @@ -1134,7 +1134,7 @@ class RandomRotation(torch.nn.Module): degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). - resample (int, optional): An optional resampling filter. See `filters`_ for more information. + interpolation (int): Interpolation type defined by `filters`_. If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. expand (bool, optional): Optional expansion flag. @@ -1148,13 +1148,17 @@ class RandomRotation(torch.nn.Module): Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - def __init__(self, degrees, resample=False, expand=False, center=None, fill=None): + def __init__(self, degrees, interpolation=0, expand=False, center=None, fill=None, resample=None): super().__init__() + if resample is not None: + warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if center is not None: @@ -1162,7 +1166,7 @@ def __init__(self, degrees, resample=False, expand=False, center=None, fill=None self.center = center - self.resample = resample + self.resample = self.interpolation = interpolation self.expand = expand self.fill = fill @@ -1185,11 +1189,12 @@ def forward(self, img): PIL Image or Tensor: Rotated image. """ angle = self.get_params(self.degrees) - return F.rotate(img, angle, self.resample, self.expand, self.center, self.fill) + return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill) def __repr__(self): + interpolate_str = _pil_interpolation_to_str[self.interpolation] format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) - format_string += ', resample={0}'.format(self.resample) + format_string += ', interpolation={0}'.format(interpolate_str) format_string += ', expand={0}'.format(self.expand) if self.center is not None: format_string += ', center={0}'.format(self.center) @@ -1220,19 +1225,29 @@ class RandomAffine(torch.nn.Module): range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. Will not apply shear by default. - resample (int, optional): An optional resampling filter. See `filters`_ for more information. + interpolation (int): Interpolation type defined by `filters`_. If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. - fillcolor (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area + fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. + fillcolor (tuple or int, optional): deprecated argument, please use `arg`:fill: instead. + resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, fillcolor=0): + def __init__( + self, degrees, translate=None, scale=None, shear=None, interpolation=0, fill=0, fillcolor=None, resample=None + ): super().__init__() + if resample is not None: + warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + + if fillcolor is not None: + warnings.warn("Argument fillcolor is deprecated. Please, use fill instead") + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) if translate is not None: @@ -1254,8 +1269,8 @@ def __init__(self, degrees, translate=None, scale=None, shear=None, resample=0, else: self.shear = shear - self.resample = resample - self.fillcolor = fillcolor + self.resample = self.interpolation = interpolation + self.fillcolor = self.fill = fill @staticmethod def get_params( @@ -1306,7 +1321,7 @@ def forward(self, img): img_size = F._get_image_size(img) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size) - return F.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) + return F.affine(img, *ret, interpolation=self.interpolation, fill=self.fill) def __repr__(self): s = '{name}(degrees={degrees}' @@ -1316,13 +1331,13 @@ def __repr__(self): s += ', scale={scale}' if self.shear is not None: s += ', shear={shear}' - if self.resample > 0: - s += ', resample={resample}' - if self.fillcolor != 0: - s += ', fillcolor={fillcolor}' + if self.interpolation > 0: + s += ', interpolation={interpolation}' + if self.fill != 0: + s += ', fill={fill}' s += ')' d = dict(self.__dict__) - d['resample'] = _pil_interpolation_to_str[d['resample']] + d['interpolation'] = _pil_interpolation_to_str[d['interpolation']] return s.format(name=self.__class__.__name__, **d) From 20c71774213d907685eee946e1db2d6b953fa382 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 3 Nov 2020 16:30:54 +0000 Subject: [PATCH 2/8] Updates according to the review --- torchvision/transforms/functional.py | 24 ++++++++++++++++++------ torchvision/transforms/transforms.py | 24 ++++++++++++++++++------ 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 813502f4db3..96fdd364ab6 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -817,7 +817,8 @@ def rotate( Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Rotated image. @@ -826,7 +827,10 @@ def rotate( """ if resample is not None: - warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = resample if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -873,17 +877,25 @@ def affine( fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - fillcolor (tuple or int, optional): deprecated argument, please use `arg`:fill: instead. - resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. + fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:fill: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. Returns: PIL Image or Tensor: Transformed image. """ if resample is not None: - warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = resample if fillcolor is not None: - warnings.warn("Argument fillcolor is deprecated. Please, use fill instead") + warnings.warn( + "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" + ) + fill = fillcolor if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 67d974409f5..bad0ebd571d 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1148,7 +1148,8 @@ class RandomRotation(torch.nn.Module): Defaults to 0 for all bands. This option is only available for Pillow>=5.2.0. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters @@ -1157,7 +1158,10 @@ class RandomRotation(torch.nn.Module): def __init__(self, degrees, interpolation=0, expand=False, center=None, fill=None, resample=None): super().__init__() if resample is not None: - warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = resample self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) @@ -1231,8 +1235,10 @@ class RandomAffine(torch.nn.Module): fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - fillcolor (tuple or int, optional): deprecated argument, please use `arg`:fill: instead. - resample (int, optional): deprecated argument, please use `arg`:interpolation: instead. + fillcolor (tuple or int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:fill: instead. + resample (int, optional): deprecated argument and will be removed since v0.10.0. + Please use `arg`:interpolation: instead. .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters @@ -1243,10 +1249,16 @@ def __init__( ): super().__init__() if resample is not None: - warnings.warn("Argument resample is deprecated. Please, use interpolation instead") + warnings.warn( + "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" + ) + interpolation = resample if fillcolor is not None: - warnings.warn("Argument fillcolor is deprecated. Please, use fill instead") + warnings.warn( + "Argument fillcolor is deprecated and will be removed since v0.10.0. Please, use fill instead" + ) + fill = fillcolor self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) From 1bc499ee000afac76a53448e6c1aafc4f80b0a74 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 3 Nov 2020 17:58:30 +0000 Subject: [PATCH 3/8] Added tests to check warnings and asserted BC --- test/test_functional_tensor.py | 18 ++++++++++++++++++ test/test_transforms.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 7d105c0c362..e7a9d995f22 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -605,6 +605,18 @@ def test_affine(self): batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0] ) + tensor, pil_img = data[0] + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + self.assertTrue(res1.equal(res2)) + + with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): + res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10) + res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10) + self.assertEqual(res1, res2) + def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): img_size = pil_img.size dt = tensor.dtype @@ -675,6 +687,12 @@ def test_rotate(self): self._test_fn_on_batch( batch_tensors, F.rotate, angle=32, interpolation=0, expand=True, center=center ) + tensor, pil_img = data[0] + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + res1 = F.rotate(tensor, 45, resample=2) + res2 = F.rotate(tensor, 45, interpolation=2) + self.assertTrue(res1.equal(res2)) def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): dt = tensor.dtype diff --git a/test/test_transforms.py b/test/test_transforms.py index f7eedd802e1..457b450dc95 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1489,6 +1489,11 @@ def test_random_rotation(self): # Checking if RandomRotation can be printed as string t.__repr__() + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + t = transforms.RandomRotation((-10, 10), resample=2) + self.assertEqual(t.interpolation, 2) + def test_random_affine(self): with self.assertRaises(ValueError): @@ -1532,6 +1537,15 @@ def test_random_affine(self): t = transforms.RandomAffine(10, interpolation=Image.BILINEAR) self.assertIn("Image.BILINEAR", t.__repr__()) + # assert deprecation warning and non-BC + with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): + t = transforms.RandomAffine(10, resample=2) + self.assertEqual(t.interpolation, 2) + + with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): + t = transforms.RandomAffine(10, fillcolor=10) + self.assertEqual(t.fill, 10) + def test_to_grayscale(self): """Unit tests for grayscale transform""" From 665350ee19cef072f0c24613b2b30e51a1dab212 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 16 Nov 2020 10:04:26 +0000 Subject: [PATCH 4/8] [WIP] Interpolation modes --- torchvision/transforms/functional.py | 37 +++++++++++++++++++++++----- torchvision/transforms/transforms.py | 4 ++- 2 files changed, 34 insertions(+), 7 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 96fdd364ab6..3700d974957 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1,6 +1,7 @@ import math import numbers import warnings +from enum import Enum from typing import Any, Optional import numpy as np @@ -19,6 +20,27 @@ from . import functional_tensor as F_t +class InterpolationModes(Enum): + """Interpolation modes + """ + NEAREST = "nearest" + BILINEAR = "bilinear" + BICUBIC = "bicubic" + # For PIL compatibility + BOX = "box" + HAMMING = "hamming" + LANCZOS = "lanczos" + + +pil_modes_mapping = { + InterpolationModes.NEAREST: 0, + InterpolationModes.BILINEAR: 2, + InterpolationModes.BICUBIC: 3, + InterpolationModes.BOX: 4, + InterpolationModes.HAMMING: 5, + InterpolationModes.LANCZOS: 1, +} + _is_pil_image = F_pil._is_pil_image _parse_fill = F_pil._parse_fill @@ -285,7 +307,7 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool return tensor -def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> Tensor: +def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = InterpolationModes.BILINEAR) -> Tensor: r"""Resize the input image to the given size. The image can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions @@ -299,17 +321,20 @@ def resize(img: Tensor, size: List[int], interpolation: int = Image.BILINEAR) -> :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode size as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes, optional): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. + Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, + ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. Returns: PIL Image or Tensor: Resized image. """ if not isinstance(img, torch.Tensor): - return F_pil.resize(img, size=size, interpolation=interpolation) + # TODO: Check and convert to PIL value + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.resize(img, size=size, interpolation=pil_interpolation) - return F_t.resize(img, size=size, interpolation=interpolation) + return F_t.resize(img, size=size, interpolation=interpolation.value) def scale(*args, **kwargs): diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index bad0ebd571d..20bc17e727b 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,12 +15,14 @@ accimage = None from . import functional as F +from .functional import InterpolationModes + __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", - "RandomPerspective", "RandomErasing", "GaussianBlur"] + "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationModes"] _pil_interpolation_to_str = { Image.NEAREST: 'PIL.Image.NEAREST', From ceadb659b5168acbf2b77e8a00691f9efcbd7f35 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 16 Nov 2020 16:22:43 +0000 Subject: [PATCH 5/8] Added InterpolationModes enum --- test/test_functional_tensor.py | 51 ++++++------ test/test_transforms.py | 10 +-- test/test_transforms_tensor.py | 6 +- torchvision/transforms/functional.py | 91 ++++++++++++++------- torchvision/transforms/functional_tensor.py | 77 ++++++----------- torchvision/transforms/transforms.py | 86 +++++++++++-------- 6 files changed, 173 insertions(+), 148 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index cc8dafd3b8f..c01865b4539 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -4,16 +4,19 @@ import math import numpy as np -from PIL.Image import NEAREST, BILINEAR, BICUBIC import torch import torchvision.transforms.functional_tensor as F_t import torchvision.transforms.functional_pil as F_pil import torchvision.transforms.functional as F +from torchvision.transforms import InterpolationModes from common_utils import TransformsTester +NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC + + class Tester(TransformsTester): def setUp(self): @@ -365,7 +368,7 @@ def test_adjust_gamma(self): ) def test_resize(self): - script_fn = torch.jit.script(F_t.resize) + script_fn = torch.jit.script(F.resize) tensor, pil_img = self._create_data(26, 36, device=self.device) batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) @@ -382,8 +385,8 @@ def test_resize(self): for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]: for interpolation in [BILINEAR, BICUBIC, NEAREST]: - resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation) - resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation) + resized_tensor = F.resize(tensor, size=size, interpolation=interpolation) + resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation) self.assertEqual( resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation) @@ -418,13 +421,13 @@ def test_resized_crop(self): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity tensor, _ = self._create_data(26, 36, device=self.device) - for i in [0, 2, 3]: - out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=i) + for mode in [NEAREST, BILINEAR, BICUBIC]: + out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) # 2) resize by half and crop a TL corner tensor, _ = self._create_data(26, 36, device=self.device) - out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=0) + out_tensor = F.resized_crop(tensor, top=0, left=0, height=20, width=30, size=[10, 15], interpolation=NEAREST) expected_out_tensor = tensor[:, :20:2, :30:2] self.assertTrue( expected_out_tensor.equal(out_tensor), @@ -433,17 +436,19 @@ def test_resized_crop(self): batch_tensors = self._create_data_batch(26, 36, num_samples=4, device=self.device) self._test_fn_on_batch( - batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=0 + batch_tensors, F.resized_crop, top=1, left=2, height=20, width=30, size=[10, 15], interpolation=NEAREST ) def _test_affine_identity_map(self, tensor, scripted_affine): # 1) identity map - out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0) + out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) - out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0) + out_tensor = scripted_affine( + tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST + ) self.assertTrue( tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5]) ) @@ -461,13 +466,13 @@ def _test_affine_square_rotations(self, tensor, pil_img, scripted_affine): ] for a, true_tensor in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))).to(self.device) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) if true_tensor is not None: self.assertTrue( @@ -496,13 +501,13 @@ def _test_affine_rect_rotations(self, tensor, pil_img, scripted_affine): for a in test_configs: out_pil_img = F.affine( - pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 + pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) for fn in [F.affine, scripted_affine]: out_tensor = fn( - tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=0 + tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST ).cpu() if out_tensor.dtype != torch.uint8: @@ -526,10 +531,10 @@ def _test_affine_translations(self, tensor, pil_img, scripted_affine): ] for t in test_configs: - out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0) + out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) for fn in [F.affine, scripted_affine]: - out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=0) + out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST) if out_tensor.dtype != torch.uint8: out_tensor = out_tensor.to(torch.uint8) @@ -550,7 +555,7 @@ def _test_affine_all_ops(self, tensor, pil_img, scripted_affine): (-45, [-10, -10], 1.2, [4.0, 5.0]), (-90, [0, 0], 1.0, [0.0, 0.0]), ] - for r in [0, ]: + for r in [NEAREST, ]: for a, t, s, sh in test_configs: out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, interpolation=r) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) @@ -609,7 +614,7 @@ def test_affine(self): # assert deprecation warning and non-BC with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=2) - res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) self.assertTrue(res1.equal(res2)) with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): @@ -620,7 +625,7 @@ def test_affine(self): def _test_rotate_all_options(self, tensor, pil_img, scripted_rotate, centers): img_size = pil_img.size dt = tensor.dtype - for r in [0, ]: + for r in [NEAREST, ]: for a in range(-180, 180, 17): for e in [True, False]: for c in centers: @@ -685,18 +690,18 @@ def test_rotate(self): center = (20, 22) self._test_fn_on_batch( - batch_tensors, F.rotate, angle=32, interpolation=0, expand=True, center=center + batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center ) tensor, pil_img = data[0] # assert deprecation warning and non-BC with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): res1 = F.rotate(tensor, 45, resample=2) - res2 = F.rotate(tensor, 45, interpolation=2) + res2 = F.rotate(tensor, 45, interpolation=BILINEAR) self.assertTrue(res1.equal(res2)) def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): dt = tensor.dtype - for r in [0, ]: + for r in [NEAREST, ]: for spoints, epoints in test_configs: out_pil_img = F.perspective(pil_img, startpoints=spoints, endpoints=epoints, interpolation=r) out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1))) @@ -757,7 +762,7 @@ def test_perspective(self): for spoints, epoints in test_configs: self._test_fn_on_batch( - batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0 + batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST ) def test_gaussian_blur(self): diff --git a/test/test_transforms.py b/test/test_transforms.py index 457b450dc95..8f533ba4b33 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1484,7 +1484,7 @@ def test_random_rotation(self): t = transforms.RandomRotation((-10, 10)) angle = t.get_params(t.degrees) - self.assertTrue(angle > -10 and angle < 10) + self.assertTrue(-10 < angle < 10) # Checking if RandomRotation can be printed as string t.__repr__() @@ -1492,7 +1492,7 @@ def test_random_rotation(self): # assert deprecation warning and non-BC with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): t = transforms.RandomRotation((-10, 10), resample=2) - self.assertEqual(t.interpolation, 2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) def test_random_affine(self): @@ -1534,13 +1534,13 @@ def test_random_affine(self): # Checking if RandomAffine can be printed as string t.__repr__() - t = transforms.RandomAffine(10, interpolation=Image.BILINEAR) - self.assertIn("Image.BILINEAR", t.__repr__()) + t = transforms.RandomAffine(10, interpolation=transforms.InterpolationModes.BILINEAR) + self.assertIn("bilinear", t.__repr__()) # assert deprecation warning and non-BC with self.assertWarnsRegex(UserWarning, r"Argument resample is deprecated and will be removed"): t = transforms.RandomAffine(10, resample=2) - self.assertEqual(t.interpolation, 2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): t = transforms.RandomAffine(10, fillcolor=10) diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 9e45fd0d7fb..ad5d303d36d 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -2,8 +2,7 @@ import torch from torchvision import transforms as T from torchvision.transforms import functional as F - -from PIL.Image import NEAREST, BILINEAR, BICUBIC +from torchvision.transforms import InterpolationModes import numpy as np @@ -12,6 +11,9 @@ from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes +NEAREST, BILINEAR, BICUBIC = InterpolationModes.NEAREST, InterpolationModes.BILINEAR, InterpolationModes.BICUBIC + + class Tester(TransformsTester): def setUp(self): diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 3700d974957..f0f931a73f3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -321,7 +321,7 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode size as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (InterpolationModes, optional): Desired interpolation enum defined by + interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. @@ -329,8 +329,10 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int Returns: PIL Image or Tensor: Resized image. """ + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - # TODO: Check and convert to PIL value pil_interpolation = pil_modes_mapping[interpolation] return F_pil.resize(img, size=size, interpolation=pil_interpolation) @@ -441,7 +443,8 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor: def resized_crop( - img: Tensor, top: int, left: int, height: int, width: int, size: List[int], interpolation: int = Image.BILINEAR + img: Tensor, top: int, left: int, height: int, width: int, size: List[int], + interpolation: InterpolationModes = InterpolationModes.BILINEAR ) -> Tensor: """Crop the given image and resize it to desired size. The image can be a PIL Image or a Tensor, in which case it is expected @@ -456,9 +459,11 @@ def resized_crop( height (int): Height of the crop box. width (int): Width of the crop box. size (sequence or int): Desired output size. Same semantics as ``resize``. - interpolation (int, optional): Desired interpolation enum defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. + Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, + ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + Returns: PIL Image or Tensor: Cropped image. """ @@ -519,7 +524,7 @@ def perspective( img: Tensor, startpoints: List[List[int]], endpoints: List[List[int]], - interpolation: int = 2, + interpolation: InterpolationModes = InterpolationModes.BILINEAR, fill: Optional[int] = None ) -> Tensor: """Perform perspective transform of the given image. @@ -532,8 +537,9 @@ def perspective( ``[top-left, top-right, bottom-right, bottom-left]`` of the original image. endpoints (list of list of ints): List containing four lists of two integers corresponding to four corners ``[top-left, top-right, bottom-right, bottom-left]`` of the transformed image. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor @@ -545,10 +551,14 @@ def perspective( coeffs = _get_perspective_coeffs(startpoints, endpoints) + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - return F_pil.perspective(img, coeffs, interpolation=interpolation, fill=fill) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.perspective(img, coeffs, interpolation=pil_interpolation, fill=fill) - return F_t.perspective(img, coeffs, interpolation=interpolation, fill=fill) + return F_t.perspective(img, coeffs, interpolation=interpolation.value, fill=fill) def vflip(img: Tensor) -> Tensor: @@ -818,8 +828,9 @@ def _get_inverse_affine_matrix( def rotate( - img: Tensor, angle: float, interpolation: int = 0, expand: bool = False, - center: Optional[List[int]] = None, fill: Optional[int] = None, resample: Optional[int] = None + img: Tensor, angle: float, interpolation: InterpolationModes = InterpolationModes.NEAREST, + expand: bool = False, center: Optional[List[int]] = None, + fill: Optional[int] = None, resample: Optional[int] = None ) -> Tensor: """Rotate the image by angle. The image can be a PIL Image or a Tensor, in which case it is expected @@ -828,9 +839,9 @@ def rotate( Args: img (PIL Image or Tensor): image to be rotated. angle (float or int): rotation angle value in degrees, counter-clockwise. - interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -855,7 +866,15 @@ def rotate( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - interpolation = resample + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + interpolation = inverse_modes_mapping[resample] if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -863,8 +882,12 @@ def rotate( if center is not None and not isinstance(center, (list, tuple)): raise TypeError("Argument center should be a sequence") + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if not isinstance(img, torch.Tensor): - return F_pil.rotate(img, angle=angle, interpolation=interpolation, expand=expand, center=center, fill=fill) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) center_f = [0.0, 0.0] if center is not None: @@ -875,13 +898,13 @@ def rotate( # due to current incoherence of rotation angle direction between affine and rotate implementations # we need to set -angle. matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0]) - return F_t.rotate(img, matrix=matrix, interpolation=interpolation, expand=expand, fill=fill) + return F_t.rotate(img, matrix=matrix, interpolation=interpolation.value, expand=expand, fill=fill) def affine( img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float], - interpolation: int = 0, fill: Optional[int] = None, resample: Optional[int] = None, - fillcolor: Optional[int] = None + interpolation: InterpolationModes = InterpolationModes.NEAREST, fill: Optional[int] = None, + resample: Optional[int] = None, fillcolor: Optional[int] = None ) -> Tensor: """Apply affine transformation on the image keeping image center invariant. The image can be a PIL Image or a Tensor, in which case it is expected @@ -895,10 +918,9 @@ def affine( shear (float or tuple or list): shear angle value in degrees between -180 to 180, clockwise direction. If a tuple of list is specified, the first value corresponds to a shear parallel to the x axis, while the second value corresponds to a shear parallel to the y axis. - interpolation (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional): - An optional resampling filter. See `filters`_ for more information. - If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -914,7 +936,15 @@ def affine( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - interpolation = resample + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + interpolation = inverse_modes_mapping[resample] if fillcolor is not None: warnings.warn( @@ -937,6 +967,9 @@ def affine( if not isinstance(shear, (numbers.Number, (list, tuple))): raise TypeError("Shear should be either a single value or a sequence of two values") + if not isinstance(interpolation, InterpolationModes): + raise TypeError("Argument interpolation should be a InterpolationModes") + if isinstance(angle, int): angle = float(angle) @@ -962,12 +995,12 @@ def affine( # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine center = [img_size[0] * 0.5, img_size[1] * 0.5] matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear) - - return F_pil.affine(img, matrix=matrix, interpolation=interpolation, fill=fill) + pil_interpolation = pil_modes_mapping[interpolation] + return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill) translate_f = [1.0 * t for t in translate] matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear) - return F_t.affine(img, matrix=matrix, interpolation=interpolation, fill=fill) + return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill) @torch.jit.unused diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 991bfbdf816..da8fdf53bf7 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -757,7 +757,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con return img -def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: +def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor: r"""PRIVATE METHOD. Resize the input Tensor to the given size. .. warning:: @@ -774,8 +774,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`. In torchscript mode padding as a single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values: - nearest(=0) and bicubic(=3). + interpolation (str): Desired interpolation. Default is "bilinear". Other supported values: + "nearest" and "bicubic". Returns: Tensor: Resized image. @@ -785,16 +785,10 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: if not isinstance(size, (int, tuple, list)): raise TypeError("Got inappropriate size arg") - if not isinstance(interpolation, int): + if not isinstance(interpolation, str): raise TypeError("Got inappropriate interpolation arg") - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - 3: "bicubic", - } - - if interpolation not in _interpolation_modes: + if interpolation not in ["nearest", "bilinear", "bicubic"]: raise ValueError("This interpolation mode is unsupported with Tensor input") if isinstance(size, tuple): @@ -822,16 +816,14 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: if (w <= h and w == size_w) or (h <= w and h == size_h): return img - mode = _interpolation_modes[interpolation] - img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) # Define align_corners to avoid warnings - align_corners = False if mode in ["bilinear", "bicubic"] else None + align_corners = False if interpolation in ["bilinear", "bicubic"] else None - img = interpolate(img, size=[size_h, size_w], mode=mode, align_corners=align_corners) + img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners) - if mode == "bicubic" and out_dtype == torch.uint8: + if interpolation == "bicubic" and out_dtype == torch.uint8: img = img.clamp(min=0, max=255) img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) @@ -842,9 +834,9 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor: def _assert_grid_transform_inputs( img: Tensor, matrix: Optional[List[float]], - interpolation: int, + interpolation: str, fill: Optional[int], - _interpolation_modes: Dict[int, str], + supported_interpolation_modes: List[str], coeffs: Optional[List[float]] = None, ): if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): @@ -862,7 +854,7 @@ def _assert_grid_transform_inputs( if fill is not None: warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero") - if interpolation not in _interpolation_modes: + if interpolation not in supported_interpolation_modes: raise ValueError("Interpolation mode '{}' is unsupported with Tensor input".format(interpolation)) @@ -931,7 +923,7 @@ def _gen_affine_grid( def affine( - img: Tensor, matrix: List[float], interpolation: int = 0, fill: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Apply affine transformation on the Tensor image keeping image center invariant. @@ -943,28 +935,21 @@ def affine( Args: img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation. - interpolation (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: - bilinear(=2). + interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". fill (int, optional): this option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. Returns: Tensor: Transformed image. """ - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - - _assert_grid_transform_inputs(img, matrix, interpolation, fill, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) shape = img.shape # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=shape[-1], h=shape[-2], ow=shape[-1], oh=shape[-2]) - mode = _interpolation_modes[interpolation] - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]: @@ -993,7 +978,8 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int] def rotate( - img: Tensor, matrix: List[float], interpolation: int = 0, expand: bool = False, fill: Optional[int] = None + img: Tensor, matrix: List[float], interpolation: str = "nearest", + expand: bool = False, fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Rotate the Tensor image by angle. @@ -1006,8 +992,7 @@ def rotate( img (Tensor): image to be rotated. matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation. Translation part (``matrix[2]`` and ``matrix[5]``) should be in pixel coordinates. - interpolation (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values: - bilinear(=2). + interpolation (str): An optional resampling filter. Default is "nearest". Other supported values: "bilinear". expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1021,21 +1006,14 @@ def rotate( .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters """ - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - - _assert_grid_transform_inputs(img, matrix, interpolation, fill, _interpolation_modes) + _assert_grid_transform_inputs(img, matrix, interpolation, fill, ["nearest", "bilinear"]) w, h = img.shape[-1], img.shape[-2] ow, oh = _compute_output_size(matrix, w, h) if expand else (w, h) dtype = img.dtype if torch.is_floating_point(img) else torch.float32 theta = torch.tensor(matrix, dtype=dtype, device=img.device).reshape(1, 2, 3) # grid will be generated on the same device as theta and img grid = _gen_affine_grid(theta, w=w, h=h, ow=ow, oh=oh) - mode = _interpolation_modes[interpolation] - - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, device: torch.device): @@ -1072,7 +1050,7 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype, def perspective( - img: Tensor, perspective_coeffs: List[float], interpolation: int = 2, fill: Optional[int] = None + img: Tensor, perspective_coeffs: List[float], interpolation: str = "bilinear", fill: Optional[int] = None ) -> Tensor: """PRIVATE METHOD. Perform perspective transform of the given Tensor image. @@ -1084,7 +1062,7 @@ def perspective( Args: img (Tensor): Image to be transformed. perspective_coeffs (list of float): perspective transformation coefficients. - interpolation (int): Interpolation type. Default, ``PIL.Image.BILINEAR``. + interpolation (str): Interpolation type. Default, "bilinear". fill (n-tuple or int or float): this option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -1094,26 +1072,19 @@ def perspective( if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)): raise TypeError('Input img should be Tensor Image') - _interpolation_modes = { - 0: "nearest", - 2: "bilinear", - } - _assert_grid_transform_inputs( img, matrix=None, interpolation=interpolation, fill=fill, - _interpolation_modes=_interpolation_modes, + supported_interpolation_modes=["nearest", "bilinear"], coeffs=perspective_coeffs ) ow, oh = img.shape[-1], img.shape[-2] dtype = img.dtype if torch.is_floating_point(img) else torch.float32 grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=img.device) - mode = _interpolation_modes[interpolation] - - return _apply_grid_transform(img, grid, mode) + return _apply_grid_transform(img, grid, interpolation) def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 20bc17e727b..ebbfcbfcfc1 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -24,15 +24,6 @@ "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationModes"] -_pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', -} - class Compose: """Composes several transforms together. This transform does not support torchscript. @@ -244,12 +235,14 @@ class Resize(torch.nn.Module): (size * height / width, size). In torchscript mode padding as single int is not supported, use a tuple or list of length 1: ``[size, ]``. - interpolation (int): Interpolation type defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and + ``InterpolationModes.BICUBIC`` are supported. + """ - def __init__(self, size, interpolation=Image.BILINEAR): + def __init__(self, size, interpolation=InterpolationModes.BILINEAR): super().__init__() if not isinstance(size, (int, Sequence)): raise TypeError("Size should be int or sequence. Got {}".format(type(size))) @@ -269,7 +262,7 @@ def forward(self, img): return F.resize(img, self.size, self.interpolation) def __repr__(self): - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) @@ -661,16 +654,16 @@ class RandomPerspective(torch.nn.Module): distortion_scale (float): argument to control the degree of distortion and ranges from 0 to 1. Default is 0.5. p (float): probability of the image being transformed. Default is 0.5. - interpolation (int): Interpolation type. If input is Tensor, only ``PIL.Image.NEAREST`` and - ``PIL.Image.BILINEAR`` are supported. Default, ``PIL.Image.BILINEAR`` for PIL images and Tensors. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. Default is 0. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. - """ - def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BILINEAR, fill=0): + def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationModes.BILINEAR, fill=0): super().__init__() self.p = p self.interpolation = interpolation @@ -746,12 +739,14 @@ class RandomResizedCrop(torch.nn.Module): made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]). scale (tuple of float): scale range of the cropped image before resizing, relatively to the origin image. ratio (tuple of float): aspect ratio range of the cropped image before resizing. - interpolation (int): Interpolation type defined by `filters`_. - Default is ``PIL.Image.BILINEAR``. If input is Tensor, only ``PIL.Image.NEAREST``, ``PIL.Image.BILINEAR`` - and ``PIL.Image.BICUBIC`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and + ``InterpolationModes.BICUBIC`` are supported. + """ - def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): + def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=InterpolationModes.BILINEAR): super().__init__() self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") @@ -826,7 +821,7 @@ def forward(self, img): return F.resized_crop(img, i, j, h, w, self.size, self.interpolation) def __repr__(self): - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) @@ -1136,9 +1131,9 @@ class RandomRotation(torch.nn.Module): degrees (sequence or float or int): Range of degrees to select from. If degrees is a number instead of sequence like (min, max), the range of degrees will be (-degrees, +degrees). - interpolation (int): Interpolation type defined by `filters`_. - If omitted, or if the image has mode "1" or "P", it is set to PIL.Image.NEAREST. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1157,13 +1152,23 @@ class RandomRotation(torch.nn.Module): """ - def __init__(self, degrees, interpolation=0, expand=False, center=None, fill=None, resample=None): + def __init__( + self, degrees, interpolation=InterpolationModes.NEAREST, expand=False, center=None, fill=None, resample=None + ): super().__init__() if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - interpolation = resample + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + interpolation = inverse_modes_mapping[resample] self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) @@ -1198,7 +1203,7 @@ def forward(self, img): return F.rotate(img, angle, self.interpolation, self.expand, self.center, self.fill) def __repr__(self): - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees) format_string += ', interpolation={0}'.format(interpolate_str) format_string += ', expand={0}'.format(self.expand) @@ -1231,9 +1236,9 @@ class RandomAffine(torch.nn.Module): range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values, a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied. Will not apply shear by default. - interpolation (int): Interpolation type defined by `filters`_. - If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``. - If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported. + interpolation (InterpolationModes): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. + If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -1247,14 +1252,23 @@ class RandomAffine(torch.nn.Module): """ def __init__( - self, degrees, translate=None, scale=None, shear=None, interpolation=0, fill=0, fillcolor=None, resample=None + self, degrees, translate=None, scale=None, shear=None, interpolation=InterpolationModes.NEAREST, fill=0, + fillcolor=None, resample=None ): super().__init__() if resample is not None: warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - interpolation = resample + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + interpolation = inverse_modes_mapping[resample] if fillcolor is not None: warnings.warn( @@ -1345,13 +1359,13 @@ def __repr__(self): s += ', scale={scale}' if self.shear is not None: s += ', shear={shear}' - if self.interpolation > 0: + if self.interpolation != InterpolationModes.NEAREST: s += ', interpolation={interpolation}' if self.fill != 0: s += ', fill={fill}' s += ')' d = dict(self.__dict__) - d['interpolation'] = _pil_interpolation_to_str[d['interpolation']] + d['interpolation'] = self.interpolation.value return s.format(name=self.__class__.__name__, **d) From 37685193ec9ccf2ab6630810595422af99fc0471 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 20 Nov 2020 13:26:54 +0000 Subject: [PATCH 6/8] Added supported for int values for interpolation for BC --- test/test_functional_tensor.py | 34 +++++++++- test/test_transforms.py | 11 ++++ torchvision/transforms/functional.py | 71 +++++++++++++++------ torchvision/transforms/functional_tensor.py | 2 +- torchvision/transforms/transforms.py | 70 ++++++++++++++------ 5 files changed, 148 insertions(+), 40 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index c01865b4539..643915b2381 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -392,7 +392,7 @@ def test_resize(self): resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation) ) - if interpolation != NEAREST: + if interpolation not in [NEAREST, ]: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] @@ -410,6 +410,11 @@ def test_resize(self): script_size = [size, ] else: script_size = size + + # skip test if interpolation is int + if isinstance(interpolation, int): + continue + resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) @@ -417,10 +422,17 @@ def test_resize(self): batch_tensors, F.resize, size=script_size, interpolation=interpolation ) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.resize(tensor, size=32, interpolation=2) + res2 = F.resize(tensor, size=32, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + def test_resized_crop(self): # test values of F.resized_crop in several cases: # 1) resize to the same size, crop to the same size => should be identity tensor, _ = self._create_data(26, 36, device=self.device) + for mode in [NEAREST, BILINEAR, BICUBIC]: out_tensor = F.resized_crop(tensor, top=0, left=0, height=26, width=36, size=[26, 36], interpolation=mode) self.assertTrue(tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])) @@ -617,6 +629,12 @@ def test_affine(self): res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) self.assertTrue(res1.equal(res2)) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + with self.assertWarnsRegex(UserWarning, r"Argument fillcolor is deprecated and will be removed"): res1 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fillcolor=10) res2 = F.affine(pil_img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], fill=10) @@ -699,6 +717,12 @@ def test_rotate(self): res2 = F.rotate(tensor, 45, interpolation=BILINEAR) self.assertTrue(res1.equal(res2)) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.rotate(tensor, 45, interpolation=2) + res2 = F.rotate(tensor, 45, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + def _test_perspective(self, tensor, pil_img, scripted_transform, test_configs): dt = tensor.dtype for r in [NEAREST, ]: @@ -765,6 +789,14 @@ def test_perspective(self): batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=NEAREST ) + # assert changed type warning + spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] + epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2) + res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR) + self.assertTrue(res1.equal(res2)) + def test_gaussian_blur(self): small_image_tensor = torch.from_numpy( np.arange(3 * 10 * 12, dtype="uint8").reshape((10, 12, 3)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 8f533ba4b33..cf49bb29097 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1494,6 +1494,11 @@ def test_random_rotation(self): t = transforms.RandomRotation((-10, 10), resample=2) self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + t = transforms.RandomRotation((-10, 10), interpolation=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + def test_random_affine(self): with self.assertRaises(ValueError): @@ -1546,6 +1551,12 @@ def test_random_affine(self): t = transforms.RandomAffine(10, fillcolor=10) self.assertEqual(t.fill, 10) + # assert changed type warning + with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationModes"): + t = transforms.RandomAffine(10, interpolation=2) + self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) + + def test_to_grayscale(self): """Unit tests for grayscale transform""" diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index f0f931a73f3..840de1107c2 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -32,6 +32,20 @@ class InterpolationModes(Enum): LANCZOS = "lanczos" +# TODO: Once torchscript supports Enums with staticmethod +# this can be put into InterpolationModes as staticmethod +def _interpolation_modes_from_int(i: int) -> InterpolationModes: + inverse_modes_mapping = { + 0: InterpolationModes.NEAREST, + 2: InterpolationModes.BILINEAR, + 3: InterpolationModes.BICUBIC, + 4: InterpolationModes.BOX, + 5: InterpolationModes.HAMMING, + 1: InterpolationModes.LANCZOS, + } + return inverse_modes_mapping[i] + + pil_modes_mapping = { InterpolationModes.NEAREST: 0, InterpolationModes.BILINEAR: 2, @@ -325,10 +339,19 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationModes = Int :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. Returns: PIL Image or Tensor: Resized image. """ + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + if not isinstance(interpolation, InterpolationModes): raise TypeError("Argument interpolation should be a InterpolationModes") @@ -463,6 +486,7 @@ def resized_crop( :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. Returns: PIL Image or Tensor: Cropped image. @@ -540,6 +564,7 @@ def perspective( interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor @@ -551,6 +576,14 @@ def perspective( coeffs = _get_perspective_coeffs(startpoints, endpoints) + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + if not isinstance(interpolation, InterpolationModes): raise TypeError("Argument interpolation should be a InterpolationModes") @@ -842,6 +875,7 @@ def rotate( interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -866,15 +900,15 @@ def rotate( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - inverse_modes_mapping = { - 0: InterpolationModes.NEAREST, - 2: InterpolationModes.BILINEAR, - 3: InterpolationModes.BICUBIC, - 4: InterpolationModes.BOX, - 5: InterpolationModes.HAMMING, - 1: InterpolationModes.LANCZOS, - } - interpolation = inverse_modes_mapping[resample] + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -921,6 +955,7 @@ def affine( interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (int): Optional fill color for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -936,15 +971,15 @@ def affine( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - inverse_modes_mapping = { - 0: InterpolationModes.NEAREST, - 2: InterpolationModes.BILINEAR, - 3: InterpolationModes.BICUBIC, - 4: InterpolationModes.BOX, - 5: InterpolationModes.HAMMING, - 1: InterpolationModes.LANCZOS, - } - interpolation = inverse_modes_mapping[resample] + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) if fillcolor is not None: warnings.warn( diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index da8fdf53bf7..4f3e72a62ce 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -851,7 +851,7 @@ def _assert_grid_transform_inputs( if coeffs is not None and len(coeffs) != 8: raise ValueError("Argument coeffs should have 8 float values") - if fill is not None: + if fill is not None and not (isinstance(fill, (int, float)) and fill == 0): warnings.warn("Argument fill is not supported for Tensor input. Fill value is zero") if interpolation not in supported_interpolation_modes: diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index ebbfcbfcfc1..8889201c7fe 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -6,7 +6,6 @@ from typing import Tuple, List, Optional import torch -from PIL import Image from torch import Tensor try: @@ -15,7 +14,7 @@ accimage = None from . import functional as F -from .functional import InterpolationModes +from .functional import InterpolationModes, _interpolation_modes_from_int __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale", @@ -239,6 +238,7 @@ class Resize(torch.nn.Module): :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. """ @@ -249,6 +249,15 @@ def __init__(self, size, interpolation=InterpolationModes.BILINEAR): if isinstance(size, Sequence) and len(size) not in (1, 2): raise ValueError("If size is a sequence, it should have 1 or 2 values") self.size = size + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation def forward(self, img): @@ -657,6 +666,7 @@ class RandomPerspective(torch.nn.Module): interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (n-tuple or int or float): Pixel fill value for area outside the rotated image. If int or float, the value is used for all bands respectively. Default is 0. This option is only available for ``pillow>=5.0.0``. This option is not supported for Tensor @@ -666,6 +676,15 @@ class RandomPerspective(torch.nn.Module): def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationModes.BILINEAR, fill=0): super().__init__() self.p = p + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.distortion_scale = distortion_scale self.fill = fill @@ -743,6 +762,7 @@ class RandomResizedCrop(torch.nn.Module): :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.BILINEAR``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` and ``InterpolationModes.BICUBIC`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. """ @@ -757,6 +777,14 @@ def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolat if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.scale = scale self.ratio = ratio @@ -1134,6 +1162,7 @@ class RandomRotation(torch.nn.Module): interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1160,15 +1189,15 @@ def __init__( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - inverse_modes_mapping = { - 0: InterpolationModes.NEAREST, - 2: InterpolationModes.BILINEAR, - 3: InterpolationModes.BICUBIC, - 4: InterpolationModes.BOX, - 5: InterpolationModes.HAMMING, - 1: InterpolationModes.LANCZOS, - } - interpolation = inverse_modes_mapping[resample] + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2, )) @@ -1239,6 +1268,7 @@ class RandomAffine(torch.nn.Module): interpolation (InterpolationModes): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationModes`. Default is ``InterpolationModes.NEAREST``. If input is Tensor, only ``InterpolationModes.NEAREST``, ``InterpolationModes.BILINEAR`` are supported. + For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. fill (tuple or int): Optional fill color (Tuple for RGB Image and int for grayscale) for the area outside the transform in the output image (Pillow>=5.0.0). This option is not supported for Tensor input. Fill value for the area outside the transform in the output image is always 0. @@ -1260,15 +1290,15 @@ def __init__( warnings.warn( "Argument resample is deprecated and will be removed since v0.10.0. Please, use interpolation instead" ) - inverse_modes_mapping = { - 0: InterpolationModes.NEAREST, - 2: InterpolationModes.BILINEAR, - 3: InterpolationModes.BICUBIC, - 4: InterpolationModes.BOX, - 5: InterpolationModes.HAMMING, - 1: InterpolationModes.LANCZOS, - } - interpolation = inverse_modes_mapping[resample] + interpolation = _interpolation_modes_from_int(resample) + + # Backward compatibility with integer value + if isinstance(interpolation, int): + warnings.warn( + "Argument interpolation should be of type InterpolationModes instead of int. " + "Please, use InterpolationModes enum." + ) + interpolation = _interpolation_modes_from_int(interpolation) if fillcolor is not None: warnings.warn( From cf3d711e5a589b2ab442c043d65103ad2c6af39a Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 20 Nov 2020 13:32:49 +0000 Subject: [PATCH 7/8] Removed useless test code --- test/test_functional_tensor.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 643915b2381..d181dd94be2 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -411,10 +411,6 @@ def test_resize(self): else: script_size = size - # skip test if interpolation is int - if isinstance(interpolation, int): - continue - resize_result = script_fn(tensor, size=script_size, interpolation=interpolation) self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation)) From 690750590d79bee074fa15840ca72945ed565d67 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 20 Nov 2020 13:52:16 +0000 Subject: [PATCH 8/8] Fix flake8 --- test/test_transforms.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 0fbe1f3e7b9..f113f9ee653 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1564,7 +1564,6 @@ def test_random_affine(self): t = transforms.RandomAffine(10, interpolation=2) self.assertEqual(t.interpolation, transforms.InterpolationModes.BILINEAR) - def test_to_grayscale(self): """Unit tests for grayscale transform"""