diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py
index 3b69b72dd4f..b416dae20e0 100644
--- a/test/test_prototype_transforms_consistency.py
+++ b/test/test_prototype_transforms_consistency.py
@@ -149,6 +149,8 @@ def __init__(
             ArgsKwargs(num_output_channels=3),
         ],
         make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
+        # Use default tolerances of `torch.testing.assert_close`
+        closeness_kwargs=dict(rtol=None, atol=None),
     ),
     ConsistencyConfig(
         prototype_transforms.ConvertDtype,
@@ -271,6 +273,9 @@ def __init__(
             ArgsKwargs(p=0),
             ArgsKwargs(p=1),
         ],
+        make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
+        # Use default tolerances of `torch.testing.assert_close`
+        closeness_kwargs=dict(rtol=None, atol=None),
     ),
     ConsistencyConfig(
         prototype_transforms.RandomResizedCrop,
diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py
index d6472301e99..5c50542b07d 100644
--- a/torchvision/prototype/datapoints/_datapoint.py
+++ b/torchvision/prototype/datapoints/_datapoint.py
@@ -230,6 +230,9 @@ def elastic(
     ) -> Datapoint:
         return self

+    def to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
+        return self
+
     def adjust_brightness(self, brightness_factor: float) -> Datapoint:
         return self

diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py
index ece95169ac3..0b2ab7453ff 100644
--- a/torchvision/prototype/datapoints/_image.py
+++ b/torchvision/prototype/datapoints/_image.py
@@ -169,6 +169,12 @@ def elastic(
         )
         return Image.wrap_like(self, output)

+    def to_grayscale(self, num_output_channels: int = 1) -> Image:
+        output = self._F.rgb_to_grayscale_image_tensor(
+            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
+        )
+        return Image.wrap_like(self, output)
+
     def adjust_brightness(self, brightness_factor: float) -> Image:
         output = self._F.adjust_brightness_image_tensor(
             self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py
index 5a73d35368a..50f9110f40c 100644
--- a/torchvision/prototype/datapoints/_video.py
+++ b/torchvision/prototype/datapoints/_video.py
@@ -173,6 +173,12 @@ def elastic(
         )
         return Video.wrap_like(self, output)

+    def to_grayscale(self, num_output_channels: int = 1) -> Video:
+        output = self._F.rgb_to_grayscale_image_tensor(
+            self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
+        )
+        return Video.wrap_like(self, output)
+
     def adjust_brightness(self, brightness_factor: float) -> Video:
         output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
         return Video.wrap_like(self, output)
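Note: the datapoint overrides above re-wrap the kernel output via `wrap_like`, so the subclass survives the conversion. A minimal sketch of the intended behavior, assuming the patched prototype API (the shapes are illustrative):

```python
import torch

from torchvision.prototype import datapoints

img = datapoints.Image(torch.rand(3, 8, 8))

# Image/Video override `to_grayscale` to call the tensor kernel and re-wrap,
# so the result is still a datapoint of the same type.
out = img.to_grayscale(num_output_channels=3)
assert isinstance(out, datapoints.Image)
assert out.shape == (3, 8, 8)  # three identical luma channels

# The base `Datapoint.to_grayscale` is a passthrough, leaving non-image
# datapoints (e.g. bounding boxes) untouched under dispatch.
```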
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index fa75cf63339..132edb1b6fc 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -9,9 +9,11 @@
 from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
 from ._color import (
     ColorJitter,
+    Grayscale,
     RandomAdjustSharpness,
     RandomAutocontrast,
     RandomEqualize,
+    RandomGrayscale,
     RandomInvert,
     RandomPhotometricDistort,
     RandomPosterize,
@@ -54,4 +56,4 @@
 from ._temporal import UniformTemporalSubsample
 from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

-from ._deprecated import Grayscale, RandomGrayscale, ToTensor  # usort: skip
+from ._deprecated import ToTensor  # usort: skip
diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py
index 0eb20e57764..6ab997b1e93 100644
--- a/torchvision/prototype/transforms/_color.py
+++ b/torchvision/prototype/transforms/_color.py
@@ -11,6 +11,41 @@
 from .utils import is_simple_tensor, query_chw


+class Grayscale(Transform):
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )
+
+    def __init__(self, num_output_channels: int = 1):
+        super().__init__()
+        self.num_output_channels = num_output_channels
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
+
+
+class RandomGrayscale(_RandomApplyTransform):
+    _transformed_types = (
+        datapoints.Image,
+        PIL.Image.Image,
+        is_simple_tensor,
+        datapoints.Video,
+    )
+
+    def __init__(self, p: float = 0.1) -> None:
+        super().__init__(p=p)
+
+    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
+        num_input_channels, *_ = query_chw(flat_inputs)
+        return dict(num_input_channels=num_input_channels)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
+
+
 class ColorJitter(Transform):
     def __init__(
         self,
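A short usage sketch for the un-deprecated transforms, assuming the prototype `transforms` namespace above. `RandomGrayscale` queries the input channel count via `query_chw` and emits the same number of channels, so a converted sample keeps the shape downstream code expects:

```python
import torch

from torchvision.prototype import transforms

img = torch.rand(3, 32, 32)  # plain tensors are handled via `is_simple_tensor`

# Deterministic conversion; `num_output_channels` may be 1 or 3.
gray = transforms.Grayscale(num_output_channels=1)(img)  # shape (1, 32, 32)

# Converts with probability `p` (inherited from `_RandomApplyTransform`),
# preserving the input's channel count.
maybe_gray = transforms.RandomGrayscale(p=0.5)(img)  # shape (3, 32, 32)
```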
diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py
index 974fe2b2741..cd37f4d73d0 100644
--- a/torchvision/prototype/transforms/_deprecated.py
+++ b/torchvision/prototype/transforms/_deprecated.py
@@ -1,17 +1,12 @@
 import warnings
-from typing import Any, Dict, List, Union
+from typing import Any, Dict, Union

 import numpy as np
 import PIL.Image
 import torch

-from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import Transform
 from torchvision.transforms import functional as _F
-from typing_extensions import Literal
-
-from ._transform import _RandomApplyTransform
-from .utils import is_simple_tensor, query_chw


 class ToTensor(Transform):
@@ -26,78 +21,3 @@ def __init__(self) -> None:

     def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
         return _F.to_tensor(inpt)
-
-
-# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
-class Grayscale(Transform):
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
-    def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
-        deprecation_msg = (
-            f"The transform `Grayscale(num_output_channels={num_output_channels})` "
-            f"is deprecated and will be removed in a future release."
-        )
-        if num_output_channels == 1:
-            replacement_msg = (
-                "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
-            )
-        else:
-            replacement_msg = (
-                "transforms.Compose(\n"
-                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
-                "    transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
-                ")"
-            )
-        warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")
-
-        super().__init__()
-        self.num_output_channels = num_output_channels
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
-
-
-class RandomGrayscale(_RandomApplyTransform):
-    _transformed_types = (
-        datapoints.Image,
-        PIL.Image.Image,
-        is_simple_tensor,
-        datapoints.Video,
-    )
-
-    def __init__(self, p: float = 0.1) -> None:
-        warnings.warn(
-            "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
-            "Instead, please use\n\n"
-            "transforms.RandomApply(\n"
-            "    transforms.Compose(\n"
-            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
-            "        transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
-            "    )\n"
-            "    p=...,\n"
-            ")"
-        )
-
-        super().__init__(p=p)
-
-    def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
-        num_input_channels, *_ = query_chw(flat_inputs)
-        return dict(num_input_channels=num_input_channels)
-
-    def _transform(
-        self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[datapoints.ImageType, datapoints.VideoType]:
-        output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
-        if isinstance(inpt, (datapoints.Image, datapoints.Video)):
-            output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]
-        return output
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 57b4cc4423a..0909b763441 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -71,6 +71,9 @@
     posterize_image_pil,
     posterize_image_tensor,
     posterize_video,
+    rgb_to_grayscale,
+    rgb_to_grayscale_image_pil,
+    rgb_to_grayscale_image_tensor,
     solarize,
     solarize_image_pil,
     solarize_image_tensor,
@@ -167,4 +170,4 @@
 from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
 from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image

-from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor  # usort: skip
+from ._deprecated import get_image_size, to_grayscale, to_tensor  # usort: skip
diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 53de1f407c8..719bd801e74 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -1,3 +1,5 @@
+from typing import Union
+
 import PIL.Image
 import torch
 from torch.nn.functional import conv2d
@@ -7,10 +9,53 @@
 from torchvision.utils import _log_api_usage_once

-from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, convert_dtype_image_tensor
 from ._utils import is_simple_tensor


+def _rgb_to_grayscale_image_tensor(
+    image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
+) -> torch.Tensor:
+    if image.shape[-3] == 1:
+        return image.clone()
+
+    r, g, b = image.unbind(dim=-3)
+    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
+    l_img = l_img.unsqueeze(dim=-3)
+    if preserve_dtype:
+        l_img = l_img.to(image.dtype)
+    if num_output_channels == 3:
+        l_img = l_img.expand(image.shape)
+    return l_img
+
+
+def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
+    return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)
+
+
+rgb_to_grayscale_image_pil = _FP.to_grayscale
+
+
+def rgb_to_grayscale(
+    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
+) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
+    if not torch.jit.is_scripting():
+        _log_api_usage_once(rgb_to_grayscale)
+    if num_output_channels not in (1, 3):
+        raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
+    if torch.jit.is_scripting() or is_simple_tensor(inpt):
+        return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
+    elif isinstance(inpt, datapoints._datapoint.Datapoint):
+        return inpt.to_grayscale(num_output_channels=num_output_channels)
+    elif isinstance(inpt, PIL.Image.Image):
+        return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
+    else:
+        raise TypeError(
+            f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
+            f"but got {type(inpt)} instead."
+        )
+
+
 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
@@ -68,7 +113,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     if c == 1:  # Match PIL behaviour
         return image

-    grayscale_image = _rgb_to_gray(image, cast=False)
+    grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
     if not image.is_floating_point():
         grayscale_image = grayscale_image.floor_()

@@ -110,7 +155,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
         raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
     fp = image.is_floating_point()
     if c == 3:
-        grayscale_image = _rgb_to_gray(image, cast=False)
+        grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
         if not fp:
             grayscale_image = grayscale_image.floor_()
     else:
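A numeric sketch of what the new kernel and dispatcher compute, assuming the module paths above. The coefficients are the ITU-R BT.601 luma weights, and for `num_output_channels=3` the luma channel is expanded (broadcast, not copied) back to the input shape:

```python
import torch

from torchvision.prototype.transforms.functional import (
    rgb_to_grayscale,
    rgb_to_grayscale_image_tensor,
)

image = torch.rand(3, 4, 4)
r, g, b = image.unbind(dim=-3)

# BT.601 luma: L = 0.2989 * R + 0.587 * G + 0.114 * B
expected = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)
torch.testing.assert_close(rgb_to_grayscale_image_tensor(image), expected)

# With three output channels, every channel carries the same luma values.
three = rgb_to_grayscale_image_tensor(image, num_output_channels=3)
torch.testing.assert_close(three[0], three[2])

# The dispatcher validates `num_output_channels` and routes by input type.
assert rgb_to_grayscale(image).shape == (1, 4, 4)
```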
diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py
index a89bcae7b90..09870216059 100644
--- a/torchvision/prototype/transforms/functional/_deprecated.py
+++ b/torchvision/prototype/transforms/functional/_deprecated.py
@@ -7,8 +7,6 @@
 from torchvision.prototype import datapoints
 from torchvision.transforms import functional as _F

-from ._utils import is_simple_tensor
-

 @torch.jit.unused
 def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -24,33 +22,6 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima
     return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


-def rgb_to_grayscale(
-    inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
-) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
-    old_color_space = None  # TODO: remove when un-deprecating
-    if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
-        inpt, (datapoints.Image, datapoints.Video)
-    ):
-        inpt = inpt.as_subclass(torch.Tensor)
-
-    call = ", num_output_channels=3" if num_output_channels == 3 else ""
-    replacement = (
-        f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
-        f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
-    )
-    if num_output_channels == 3:
-        replacement = (
-            f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
-            f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
-        )
-    warnings.warn(
-        f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
-        f"Instead, please use `{replacement}`.",
-    )
-
-    return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)
-
-
 @torch.jit.unused
 def to_tensor(inpt: Any) -> torch.Tensor:
     warnings.warn(
" - f"Instead, please use `{replacement}`.", - ) - - return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) - - @torch.jit.unused def to_tensor(inpt: Any) -> torch.Tensor: warnings.warn( diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index b76dc7d7b68..31d86bec256 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -225,15 +225,6 @@ def clamp_bounding_box( return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True) -def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: - r, g, b = image.unbind(dim=-3) - l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) - if cast: - l_img = l_img.to(image.dtype) - l_img = l_img.unsqueeze(dim=-3) - return l_img - - def _num_value_bits(dtype: torch.dtype) -> int: if dtype == torch.uint8: return 8