From 7cad286c152cad40439421b2fabe2376c15586af Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 14:00:31 +0000 Subject: [PATCH 1/9] Undeprecate ToGrayScale transforms and functionals --- torchvision/prototype/datapoints/_image.py | 8 +- torchvision/prototype/datapoints/_video.py | 8 +- torchvision/prototype/transforms/__init__.py | 4 +- torchvision/prototype/transforms/_color.py | 61 +++++++++++++- .../prototype/transforms/_deprecated.py | 82 +------------------ .../transforms/functional/__init__.py | 3 +- .../prototype/transforms/functional/_color.py | 42 +++++++++- .../transforms/functional/_deprecated.py | 29 ------- .../prototype/transforms/functional/_meta.py | 9 -- 9 files changed, 119 insertions(+), 127 deletions(-) diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index ece95169ac3..2c4a9bab7a2 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union import PIL.Image import torch @@ -169,6 +169,12 @@ def elastic( ) return Image.wrap_like(self, output) + def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Image: + output = self._F.rgb_to_grayscale_tensor( + self.as_subclass(torch.Tensor), num_output_channels=num_output_channels + ) + return Image.wrap_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Image: output = self._F.adjust_brightness_image_tensor( self.as_subclass(torch.Tensor), brightness_factor=brightness_factor diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 5a73d35368a..179bfa72f8a 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Optional, Tuple, Union +from typing import Any, List, Literal, Optional, Tuple, Union import torch from torchvision.transforms.functional import InterpolationMode @@ -173,6 +173,12 @@ def elastic( ) return Video.wrap_like(self, output) + def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Video: + output = self._F.rgb_to_grayscale_tensor( + self.as_subclass(torch.Tensor), num_output_channels=num_output_channels + ) + return Video.wrap_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Video: output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor) return Video.wrap_like(self, output) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index fa75cf63339..132edb1b6fc 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -9,9 +9,11 @@ from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide from ._color import ( ColorJitter, + Grayscale, RandomAdjustSharpness, RandomAutocontrast, RandomEqualize, + RandomGrayscale, RandomInvert, RandomPhotometricDistort, RandomPosterize, @@ -54,4 +56,4 @@ from ._temporal import UniformTemporalSubsample from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage -from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip +from ._deprecated import ToTensor # usort: skip diff --git a/torchvision/prototype/transforms/_color.py 
b/torchvision/prototype/transforms/_color.py index 0eb20e57764..607ce49d3b3 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,5 +1,5 @@ import collections.abc -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -11,6 +11,21 @@ from .utils import is_simple_tensor, query_chw +class GrayScale(Transform): + def __init__(self, num_output_channels: Literal[1, 3] = 1): + super().__init__() + self.num_output_channels = num_output_channels + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) + + +class RandomGrayScale(_RandomApplyTransform): + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + num_output_channels = F.get_num_channels(inpt) + return F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) + + class ColorJitter(Transform): def __init__( self, @@ -198,3 +213,47 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) + + +class Grayscale(Transform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: + self.num_output_channels = num_output_channels + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + return output + + +class RandomGrayscale(_RandomApplyTransform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, p: float = 0.1) -> None: + super().__init__(p=p) + + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_input_channels, *_ = query_chw(flat_inputs) + return dict(num_input_channels=num_input_channels) + + def _transform( + self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] + ) -> Union[datapoints.ImageType, datapoints.VideoType]: + output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) + if isinstance(inpt, (datapoints.Image, datapoints.Video)): + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] + return output diff --git a/torchvision/prototype/transforms/_deprecated.py b/torchvision/prototype/transforms/_deprecated.py index 974fe2b2741..cd37f4d73d0 100644 --- a/torchvision/prototype/transforms/_deprecated.py +++ b/torchvision/prototype/transforms/_deprecated.py @@ -1,17 +1,12 @@ import warnings -from typing import Any, Dict, List, Union +from typing import Any, Dict, Union import numpy as np import PIL.Image import torch -from torchvision.prototype import datapoints from torchvision.prototype.transforms import Transform from torchvision.transforms import functional as _F -from typing_extensions import Literal - -from ._transform import _RandomApplyTransform -from .utils import is_simple_tensor, query_chw class ToTensor(Transform): @@ -26,78 +21,3 @@ def __init__(self) -> None: def _transform(self, 
inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor: return _F.to_tensor(inpt) - - -# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray? -class Grayscale(Transform): - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: - deprecation_msg = ( - f"The transform `Grayscale(num_output_channels={num_output_channels})` " - f"is deprecated and will be removed in a future release." - ) - if num_output_channels == 1: - replacement_msg = ( - "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)" - ) - else: - replacement_msg = ( - "transforms.Compose(\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" - ")" - ) - warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}") - - super().__init__() - self.num_output_channels = num_output_channels - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - return output - - -class RandomGrayscale(_RandomApplyTransform): - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def __init__(self, p: float = 0.1) -> None: - warnings.warn( - "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. 
" - "Instead, please use\n\n" - "transforms.RandomApply(\n" - " transforms.Compose(\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n" - " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n" - " )\n" - " p=...,\n" - ")" - ) - - super().__init__(p=p) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - num_input_channels, *_ = query_chw(flat_inputs) - return dict(num_input_channels=num_input_channels) - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - return output diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 57b4cc4423a..283a4ad3bbf 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -71,6 +71,7 @@ posterize_image_pil, posterize_image_tensor, posterize_video, + rgb_to_grayscale, solarize, solarize_image_pil, solarize_image_tensor, @@ -167,4 +168,4 @@ from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image -from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip +from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 53de1f407c8..1a87aa61e95 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,3 +1,5 @@ +from typing import Literal, Union + import PIL.Image import torch from torch.nn.functional import conv2d @@ -7,10 +9,44 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor +from ._meta import _num_value_bits, convert_dtype_image_tensor, get_num_channels from ._utils import is_simple_tensor +def rgb_to_grayscale_tensor(image: torch.Tensor, num_output_channels: Literal[1, 3] = 1) -> torch.Tensor: + r, g, b = image.unbind(dim=-3) + l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) + l_img = l_img.unsqueeze(dim=-3) + if num_output_channels == 3: + return l_img.expand(image.shape) + return l_img + + +def rgb_to_grayscale( + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: Literal[1, 3] = 1 +) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: + if not torch.jit.is_scripting(): + _log_api_usage_once(rgb_to_grayscale) + if num_output_channels not in (1, 3): + raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") + num_channels = get_num_channels(inpt) + if num_channels != 3: + raise ValueError( + "Image is expected to have 3 channels (RGB) to be converted to grayscale" f"Got {num_channels}" + ) + if torch.jit.is_scripting() or is_simple_tensor(inpt): + return rgb_to_grayscale_tensor(inpt, num_output_channels=num_output_channels) + elif isinstance(inpt, datapoints._datapoint.Datapoint): + return inpt.rgb_to_grayscale(inpt, 
num_output_channels=num_output_channels) + elif isinstance(inpt, PIL.Image.Image): + return _FP.to_grayscale(inpt, num_output_channels=num_output_channels) + else: + raise TypeError( + f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " + f"but got {type(inpt)} instead." + ) + + def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor: ratio = float(ratio) fp = image1.is_floating_point() @@ -68,7 +104,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float if c == 1: # Match PIL behaviour return image - grayscale_image = _rgb_to_gray(image, cast=False) + grayscale_image = rgb_to_grayscale_tensor(image) if not image.is_floating_point(): grayscale_image = grayscale_image.floor_() @@ -110,7 +146,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") fp = image.is_floating_point() if c == 3: - grayscale_image = _rgb_to_gray(image, cast=False) + grayscale_image = rgb_to_grayscale_tensor(image) if not fp: grayscale_image = grayscale_image.floor_() else: diff --git a/torchvision/prototype/transforms/functional/_deprecated.py b/torchvision/prototype/transforms/functional/_deprecated.py index a89bcae7b90..09870216059 100644 --- a/torchvision/prototype/transforms/functional/_deprecated.py +++ b/torchvision/prototype/transforms/functional/_deprecated.py @@ -7,8 +7,6 @@ from torchvision.prototype import datapoints from torchvision.transforms import functional as _F -from ._utils import is_simple_tensor - @torch.jit.unused def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image: @@ -24,33 +22,6 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Ima return _F.to_grayscale(inpt, num_output_channels=num_output_channels) -def rgb_to_grayscale( - inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 -) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: - old_color_space = None # TODO: remove when un-deprecating - if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance( - inpt, (datapoints.Image, datapoints.Video) - ): - inpt = inpt.as_subclass(torch.Tensor) - - call = ", num_output_channels=3" if num_output_channels == 3 else "" - replacement = ( - f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY" - f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})" - ) - if num_output_channels == 3: - replacement = ( - f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB" - f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})" - ) - warnings.warn( - f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. 
" - f"Instead, please use `{replacement}`.", - ) - - return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) - - @torch.jit.unused def to_tensor(inpt: Any) -> torch.Tensor: warnings.warn( diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index b76dc7d7b68..31d86bec256 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -225,15 +225,6 @@ def clamp_bounding_box( return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True) -def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor: - r, g, b = image.unbind(dim=-3) - l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) - if cast: - l_img = l_img.to(image.dtype) - l_img = l_img.unsqueeze(dim=-3) - return l_img - - def _num_value_bits(dtype: torch.dtype) -> int: if dtype == torch.uint8: return 8 From 9d79d60a41c83c7a2d31946a9df392a1c5b8333f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 14:14:58 +0000 Subject: [PATCH 2/9] Remove duplicated --- torchvision/prototype/transforms/_color.py | 65 +++++++--------------- 1 file changed, 19 insertions(+), 46 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 607ce49d3b3..ab09843ccb7 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -11,7 +11,14 @@ from .utils import is_simple_tensor, query_chw -class GrayScale(Transform): +class Grayscale(Transform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + def __init__(self, num_output_channels: Literal[1, 3] = 1): super().__init__() self.num_output_channels = num_output_channels @@ -20,7 +27,17 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) -class RandomGrayScale(_RandomApplyTransform): +class RandomGrayscale(_RandomApplyTransform): + _transformed_types = ( + datapoints.Image, + PIL.Image.Image, + is_simple_tensor, + datapoints.Video, + ) + + def __init__(self, p: float = 0.1) -> None: + super().__init__(p=p) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: num_output_channels = F.get_num_channels(inpt) return F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) @@ -213,47 +230,3 @@ def __init__(self, sharpness_factor: float, p: float = 0.5) -> None: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: return F.adjust_sharpness(inpt, sharpness_factor=self.sharpness_factor) - - -class Grayscale(Transform): - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None: - self.num_output_channels = num_output_channels - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels) - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - return output - - -class RandomGrayscale(_RandomApplyTransform): - _transformed_types = ( - datapoints.Image, - PIL.Image.Image, - is_simple_tensor, - datapoints.Video, - ) - - def 
__init__(self, p: float = 0.1) -> None: - super().__init__(p=p) - - def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: - num_input_channels, *_ = query_chw(flat_inputs) - return dict(num_input_channels=num_input_channels) - - def _transform( - self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any] - ) -> Union[datapoints.ImageType, datapoints.VideoType]: - output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) - if isinstance(inpt, (datapoints.Image, datapoints.Video)): - output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] - return output From 1c08b2b7d2e75585236b7b658b249bb6a2f0c1d7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 14:30:49 +0000 Subject: [PATCH 3/9] put back _get_params --- torchvision/prototype/transforms/_color.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index ab09843ccb7..7e282098cda 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -38,9 +38,12 @@ class RandomGrayscale(_RandomApplyTransform): def __init__(self, p: float = 0.1) -> None: super().__init__(p=p) + def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: + num_input_channels, *_ = query_chw(flat_inputs) + return dict(num_input_channels=num_input_channels) + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - num_output_channels = F.get_num_channels(inpt) - return F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) + return F.rgb_to_grayscale(inpt, num_output_channels=params["num_output_channels"]) class ColorJitter(Transform): From 23802c5c006700ddc8433a64970694029a78ab0a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Jan 2023 16:54:12 +0000 Subject: [PATCH 4/9] =?UTF-8?q?=F0=9F=98=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- torchvision/prototype/datapoints/_image.py | 3 ++- torchvision/prototype/datapoints/_video.py | 3 ++- torchvision/prototype/transforms/_color.py | 3 ++- torchvision/prototype/transforms/functional/_color.py | 3 ++- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 2c4a9bab7a2..57a68f6825b 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import PIL.Image import torch from torchvision.transforms.functional import InterpolationMode +from typing_extensions import Literal from ._datapoint import Datapoint, FillTypeJIT diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index 179bfa72f8a..b8099ab79fd 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Any, List, Literal, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torchvision.transforms.functional import InterpolationMode +from typing_extensions import Literal from ._datapoint import Datapoint, FillTypeJIT diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 
7e282098cda..72813a7710a 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -1,11 +1,12 @@ import collections.abc -from typing import Any, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import PIL.Image import torch from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform +from typing_extensions import Literal from ._transform import _RandomApplyTransform from .utils import is_simple_tensor, query_chw diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 1a87aa61e95..3eb8a6016c7 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -1,4 +1,4 @@ -from typing import Literal, Union +from typing import Union import PIL.Image import torch @@ -8,6 +8,7 @@ from torchvision.transforms.functional_tensor import _max_value from torchvision.utils import _log_api_usage_once +from typing_extensions import Literal from ._meta import _num_value_bits, convert_dtype_image_tensor, get_num_channels from ._utils import is_simple_tensor From d98c98c7c3c41e145cf77e7927f5b81b0938d87a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 26 Jan 2023 10:36:46 +0000 Subject: [PATCH 5/9] Literal -> int. Why? You know why... --- torchvision/prototype/transforms/functional/_color.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 3eb8a6016c7..5718728f00d 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -8,13 +8,12 @@ from torchvision.transforms.functional_tensor import _max_value from torchvision.utils import _log_api_usage_once -from typing_extensions import Literal from ._meta import _num_value_bits, convert_dtype_image_tensor, get_num_channels from ._utils import is_simple_tensor -def rgb_to_grayscale_tensor(image: torch.Tensor, num_output_channels: Literal[1, 3] = 1) -> torch.Tensor: +def rgb_to_grayscale_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = l_img.unsqueeze(dim=-3) @@ -24,7 +23,7 @@ def rgb_to_grayscale_tensor(image: torch.Tensor, num_output_channels: Literal[1, def rgb_to_grayscale( - inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: Literal[1, 3] = 1 + inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: if not torch.jit.is_scripting(): _log_api_usage_once(rgb_to_grayscale) From d77bf1761bc201cbad27f1fe7b95f6e75242d48d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 26 Jan 2023 11:01:05 +0000 Subject: [PATCH 6/9] Fix (some?) 
tests --- test/test_prototype_transforms_consistency.py | 5 +++- torchvision/prototype/datapoints/_image.py | 2 +- torchvision/prototype/datapoints/_video.py | 2 +- torchvision/prototype/transforms/_color.py | 2 +- .../transforms/functional/__init__.py | 2 ++ .../prototype/transforms/functional/_color.py | 25 +++++++++++++------ 6 files changed, 27 insertions(+), 11 deletions(-) diff --git a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index 3b69b72dd4f..fac9837342a 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -148,7 +148,8 @@ def __init__( ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=3), ], - make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]), + # Use default tolerances of `torch.testing.assert_close` + closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( prototype_transforms.ConvertDtype, @@ -271,6 +272,8 @@ def __init__( ArgsKwargs(p=0), ArgsKwargs(p=1), ], + # Use default tolerances of `torch.testing.assert_close` + closeness_kwargs=dict(rtol=None, atol=None), ), ConsistencyConfig( prototype_transforms.RandomResizedCrop, diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 57a68f6825b..2365638ee0d 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -171,7 +171,7 @@ def elastic( return Image.wrap_like(self, output) def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Image: - output = self._F.rgb_to_grayscale_tensor( + output = self._F.rgb_to_grayscale_image_tensor( self.as_subclass(torch.Tensor), num_output_channels=num_output_channels ) return Image.wrap_like(self, output) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index b8099ab79fd..ad582353dff 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -175,7 +175,7 @@ def elastic( return Video.wrap_like(self, output) def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Video: - output = self._F.rgb_to_grayscale_tensor( + output = self._F.rgb_to_grayscale_image_tensor( self.as_subclass(torch.Tensor), num_output_channels=num_output_channels ) return Video.wrap_like(self, output) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 72813a7710a..19573d7148d 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -44,7 +44,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]: return dict(num_input_channels=num_input_channels) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - return F.rgb_to_grayscale(inpt, num_output_channels=params["num_output_channels"]) + return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"]) class ColorJitter(Transform): diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 283a4ad3bbf..0909b763441 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -72,6 +72,8 @@ posterize_image_tensor, posterize_video, rgb_to_grayscale, + rgb_to_grayscale_image_pil, + rgb_to_grayscale_image_tensor, solarize, solarize_image_pil, solarize_image_tensor, diff --git 
a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 5718728f00d..de6e2207414 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -13,15 +13,26 @@ from ._utils import is_simple_tensor -def rgb_to_grayscale_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: +def _rgb_to_grayscale_image_tensor( + image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True +) -> torch.Tensor: r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = l_img.unsqueeze(dim=-3) + if preserve_dtype: + l_img = l_img.to(image.dtype) if num_output_channels == 3: return l_img.expand(image.shape) return l_img +def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor: + return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True) + + +rgb_to_grayscale_image_pil = _FP.to_grayscale + + def rgb_to_grayscale( inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1 ) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]: @@ -32,14 +43,14 @@ def rgb_to_grayscale( num_channels = get_num_channels(inpt) if num_channels != 3: raise ValueError( - "Image is expected to have 3 channels (RGB) to be converted to grayscale" f"Got {num_channels}" + f"Image is expected to have 3 channels (RGB) to be converted to grayscale. Got {num_channels}." ) if torch.jit.is_scripting() or is_simple_tensor(inpt): - return rgb_to_grayscale_tensor(inpt, num_output_channels=num_output_channels) + return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) elif isinstance(inpt, datapoints._datapoint.Datapoint): - return inpt.rgb_to_grayscale(inpt, num_output_channels=num_output_channels) + return inpt.to_grayscale(num_output_channels=num_output_channels) elif isinstance(inpt, PIL.Image.Image): - return _FP.to_grayscale(inpt, num_output_channels=num_output_channels) + return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels) else: raise TypeError( f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, " @@ -104,7 +115,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float if c == 1: # Match PIL behaviour return image - grayscale_image = rgb_to_grayscale_tensor(image) + grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) if not image.is_floating_point(): grayscale_image = grayscale_image.floor_() @@ -146,7 +157,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}") fp = image.is_floating_point() if c == 3: - grayscale_image = rgb_to_grayscale_tensor(image) + grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False) if not fp: grayscale_image = grayscale_image.floor_() else: From 14fe16cbffa3d210f66b769a08a1fad22c596583 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 26 Jan 2023 11:18:46 +0000 Subject: [PATCH 7/9] Put back pass-through when C==1 (I shouldn't have removed it) --- test/test_prototype_transforms_consistency.py | 2 ++ .../prototype/transforms/functional/_color.py | 12 +++++------- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git 
a/test/test_prototype_transforms_consistency.py b/test/test_prototype_transforms_consistency.py index fac9837342a..b416dae20e0 100644 --- a/test/test_prototype_transforms_consistency.py +++ b/test/test_prototype_transforms_consistency.py @@ -148,6 +148,7 @@ def __init__( ArgsKwargs(num_output_channels=1), ArgsKwargs(num_output_channels=3), ], + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]), # Use default tolerances of `torch.testing.assert_close` closeness_kwargs=dict(rtol=None, atol=None), ), @@ -272,6 +273,7 @@ def __init__( ArgsKwargs(p=0), ArgsKwargs(p=1), ], + make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]), # Use default tolerances of `torch.testing.assert_close` closeness_kwargs=dict(rtol=None, atol=None), ), diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index de6e2207414..719bd801e74 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -9,20 +9,23 @@ from torchvision.utils import _log_api_usage_once -from ._meta import _num_value_bits, convert_dtype_image_tensor, get_num_channels +from ._meta import _num_value_bits, convert_dtype_image_tensor from ._utils import is_simple_tensor def _rgb_to_grayscale_image_tensor( image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True ) -> torch.Tensor: + if image.shape[-3] == 1: + return image.clone() + r, g, b = image.unbind(dim=-3) l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114) l_img = l_img.unsqueeze(dim=-3) if preserve_dtype: l_img = l_img.to(image.dtype) if num_output_channels == 3: - return l_img.expand(image.shape) + l_img = l_img.expand(image.shape) return l_img @@ -40,11 +43,6 @@ def rgb_to_grayscale( _log_api_usage_once(rgb_to_grayscale) if num_output_channels not in (1, 3): raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.") - num_channels = get_num_channels(inpt) - if num_channels != 3: - raise ValueError( - f"Image is expected to have 3 channels (RGB) to be converted to grayscale. Got {num_channels}." 
- ) if torch.jit.is_scripting() or is_simple_tensor(inpt): return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels) elif isinstance(inpt, datapoints._datapoint.Datapoint): From 2d8cf4fa273d5be9da73ba99ac4aff7569d92ee5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 26 Jan 2023 11:31:48 +0000 Subject: [PATCH 8/9] #NOTmypy --- torchvision/prototype/datapoints/_datapoint.py | 3 +++ torchvision/prototype/datapoints/_image.py | 3 +-- torchvision/prototype/datapoints/_video.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/datapoints/_datapoint.py b/torchvision/prototype/datapoints/_datapoint.py index d6472301e99..5c50542b07d 100644 --- a/torchvision/prototype/datapoints/_datapoint.py +++ b/torchvision/prototype/datapoints/_datapoint.py @@ -230,6 +230,9 @@ def elastic( ) -> Datapoint: return self + def to_grayscale(self, num_output_channels: int = 1) -> Datapoint: + return self + def adjust_brightness(self, brightness_factor: float) -> Datapoint: return self diff --git a/torchvision/prototype/datapoints/_image.py b/torchvision/prototype/datapoints/_image.py index 2365638ee0d..0b2ab7453ff 100644 --- a/torchvision/prototype/datapoints/_image.py +++ b/torchvision/prototype/datapoints/_image.py @@ -5,7 +5,6 @@ import PIL.Image import torch from torchvision.transforms.functional import InterpolationMode -from typing_extensions import Literal from ._datapoint import Datapoint, FillTypeJIT @@ -170,7 +169,7 @@ def elastic( ) return Image.wrap_like(self, output) - def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Image: + def to_grayscale(self, num_output_channels: int = 1) -> Image: output = self._F.rgb_to_grayscale_image_tensor( self.as_subclass(torch.Tensor), num_output_channels=num_output_channels ) diff --git a/torchvision/prototype/datapoints/_video.py b/torchvision/prototype/datapoints/_video.py index ad582353dff..50f9110f40c 100644 --- a/torchvision/prototype/datapoints/_video.py +++ b/torchvision/prototype/datapoints/_video.py @@ -4,7 +4,6 @@ import torch from torchvision.transforms.functional import InterpolationMode -from typing_extensions import Literal from ._datapoint import Datapoint, FillTypeJIT @@ -174,7 +173,7 @@ def elastic( ) return Video.wrap_like(self, output) - def to_grayscale(self, num_output_channels: Literal[1, 3] = 1) -> Video: + def to_grayscale(self, num_output_channels: int = 1) -> Video: output = self._F.rgb_to_grayscale_image_tensor( self.as_subclass(torch.Tensor), num_output_channels=num_output_channels ) From 29be7b518dcd1c6e2e2854f2e106ce8e385e97d5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 26 Jan 2023 13:31:35 +0000 Subject: [PATCH 9/9] remove Literal for good --- torchvision/prototype/transforms/_color.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/_color.py b/torchvision/prototype/transforms/_color.py index 19573d7148d..6ab997b1e93 100644 --- a/torchvision/prototype/transforms/_color.py +++ b/torchvision/prototype/transforms/_color.py @@ -6,7 +6,6 @@ from torchvision.prototype import datapoints from torchvision.prototype.transforms import functional as F, Transform -from typing_extensions import Literal from ._transform import _RandomApplyTransform from .utils import is_simple_tensor, query_chw @@ -20,7 +19,7 @@ class Grayscale(Transform): datapoints.Video, ) - def __init__(self, num_output_channels: Literal[1, 3] = 1): + def __init__(self, num_output_channels: int = 1): super().__init__() 
self.num_output_channels = num_output_channels
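
For reference, a self-contained sketch of the tensor kernel this series converges on. It mirrors the final `_rgb_to_grayscale_image_tensor` above: the classic luma weights (L = 0.2989 R + 0.587 G + 0.114 B), the pass-through for single-channel inputs restored in patch 7, dtype preservation, and the optional broadcast back to three channels. The function name here is illustrative; only the behavior is taken from the patches.

import torch

def grayscale_sketch(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
    if image.shape[-3] == 1:
        # Grayscale inputs pass through unchanged (behavior restored in patch 7).
        return image.clone()
    r, g, b = image.unbind(dim=-3)
    # r.mul(0.2989) allocates a fresh tensor, so the in-place add_() calls
    # below never modify the input image.
    l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
    # An integer input is promoted to float by the multiplication above;
    # .to(image.dtype) restores the original dtype (the preserve_dtype=True path).
    l_img = l_img.unsqueeze(dim=-3).to(image.dtype)
    if num_output_channels == 3:
        # expand() returns a broadcast view, so no extra memory is allocated.
        l_img = l_img.expand(image.shape)
    return l_img

img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
assert grayscale_sketch(img).shape == (1, 32, 32)
assert grayscale_sketch(img, num_output_channels=3).shape == (3, 32, 32)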
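
After the series, the undeprecated transforms and dispatcher are used like their stable `torchvision.transforms` counterparts. A minimal usage sketch, assuming the prototype namespace as laid out in these patches (the prototype API is unstable, so these import paths may change):

import torch
from torchvision.prototype import datapoints, transforms
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8))

# Class form: Grayscale forwards num_output_channels to the dispatcher.
three_channel_gray = transforms.Grayscale(num_output_channels=3)(img)

# RandomGrayscale queries the input channel count in _get_params and reuses it,
# so an RGB input stays 3-channel even when the conversion fires.
maybe_gray = transforms.RandomGrayscale(p=0.5)(img)

# Functional form: dispatches to the tensor kernel, the Datapoint method,
# or the PIL helper depending on the input type.
gray = F.rgb_to_grayscale(img, num_output_channels=1)
assert gray.shape[-3] == 1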
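
On the `Literal[1, 3] -> int` changes in patches 5, 8 and 9 ("Why? You know why..."): TorchScript does not support `typing.Literal` annotations, and these kernels are meant to stay scriptable, which is presumably the "why". A sketch of the constraint, assuming the functional is importable at the path exported in patch 6:

import torch
from torchvision.prototype.transforms.functional import rgb_to_grayscale_image_tensor

# Scripting works with the plain `int` annotation; with Literal[1, 3] the
# TorchScript compiler would reject the signature.
scripted = torch.jit.script(rgb_to_grayscale_image_tensor)
out = scripted(torch.rand(3, 8, 8), num_output_channels=3)
assert out.shape == (3, 8, 8)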