5 changes: 5 additions & 0 deletions test/test_prototype_transforms_consistency.py
@@ -149,6 +149,8 @@ def __init__(
ArgsKwargs(num_output_channels=3),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.ConvertDtype,
@@ -271,6 +273,9 @@ def __init__(
ArgsKwargs(p=0),
ArgsKwargs(p=1),
],
make_images_kwargs=dict(DEFAULT_MAKE_IMAGES_KWARGS, color_spaces=["RGB", "GRAY"]),
# Use default tolerances of `torch.testing.assert_close`
closeness_kwargs=dict(rtol=None, atol=None),
),
ConsistencyConfig(
prototype_transforms.RandomResizedCrop,
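For context: passing `rtol=None` and `atol=None` makes `torch.testing.assert_close` fall back to its dtype-based default tolerances instead of requiring exact equality, which is what the new grayscale kernel needs. A minimal sketch of the effect, independent of the test harness:

```python
import torch

a = torch.rand(3, 8, 8)
b = a + 1e-7  # tiny numeric drift, e.g. from a different op ordering in the new kernel

# rtol=None/atol=None selects dtype-based defaults (rtol=1.3e-6, atol=1e-5
# for float32), so small floating-point differences are tolerated.
torch.testing.assert_close(a, b, rtol=None, atol=None)
```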
3 changes: 3 additions & 0 deletions torchvision/prototype/datapoints/_datapoint.py
@@ -230,6 +230,9 @@ def elastic(
) -> Datapoint:
return self

def to_grayscale(self, num_output_channels: int = 1) -> Datapoint:
return self

def adjust_brightness(self, brightness_factor: float) -> Datapoint:
return self

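The base class gains `to_grayscale` as a no-op so the new dispatcher can route any datapoint through it: only `Image` and `Video` override it, and everything else (bounding boxes, masks, labels) passes through unchanged. A rough sketch, with the `BoundingBox` constructor arguments assumed from the prototype API of the time:

```python
from torchvision.prototype import datapoints

# Hypothetical example input; the constructor arguments are assumptions.
box = datapoints.BoundingBox([[0, 0, 10, 10]], format="XYXY", spatial_size=(32, 32))

# Falls through to Datapoint.to_grayscale, which simply returns self.
assert box.to_grayscale(num_output_channels=1) is box
```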
6 changes: 6 additions & 0 deletions torchvision/prototype/datapoints/_image.py
@@ -169,6 +169,12 @@ def elastic(
)
return Image.wrap_like(self, output)

def to_grayscale(self, num_output_channels: int = 1) -> Image:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Image.wrap_like(self, output)

def adjust_brightness(self, brightness_factor: float) -> Image:
output = self._F.adjust_brightness_image_tensor(
self.as_subclass(torch.Tensor), brightness_factor=brightness_factor
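`Image.to_grayscale` follows the usual datapoint pattern: unwrap to a plain tensor, run the tensor kernel, then re-wrap with `wrap_like` so the datapoint type survives. Sketched usage through the public dispatcher added below in `_color.py`:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

img = datapoints.Image(torch.rand(3, 4, 4))

gray = F.rgb_to_grayscale(img)  # dispatches to Image.to_grayscale
assert isinstance(gray, datapoints.Image) and gray.shape[-3] == 1

gray3 = F.rgb_to_grayscale(img, num_output_channels=3)
assert gray3.shape[-3] == 3  # single luminance channel replicated three times
```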
6 changes: 6 additions & 0 deletions torchvision/prototype/datapoints/_video.py
@@ -173,6 +173,12 @@ def elastic(
)
return Video.wrap_like(self, output)

def to_grayscale(self, num_output_channels: int = 1) -> Video:
output = self._F.rgb_to_grayscale_image_tensor(
self.as_subclass(torch.Tensor), num_output_channels=num_output_channels
)
return Video.wrap_like(self, output)

def adjust_brightness(self, brightness_factor: float) -> Video:
output = self._F.adjust_brightness_video(self.as_subclass(torch.Tensor), brightness_factor=brightness_factor)
return Video.wrap_like(self, output)
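`Video.to_grayscale` deliberately reuses the image kernel: `_rgb_to_grayscale_image_tensor` only operates on the last three dimensions, so leading batch/time dimensions pass through untouched. For instance:

```python
import torch
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F

video = datapoints.Video(torch.rand(8, 3, 16, 16))  # (T, C, H, W)
gray = F.rgb_to_grayscale(video)
assert gray.shape == (8, 1, 16, 16)  # time dimension preserved
```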
4 changes: 3 additions & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -9,9 +9,11 @@
from ._auto_augment import AugMix, AutoAugment, RandAugment, TrivialAugmentWide
from ._color import (
ColorJitter,
Grayscale,
RandomAdjustSharpness,
RandomAutocontrast,
RandomEqualize,
RandomGrayscale,
RandomInvert,
RandomPhotometricDistort,
RandomPosterize,
@@ -54,4 +56,4 @@
from ._temporal import UniformTemporalSubsample
from ._type_conversion import LabelToOneHot, PILToTensor, ToImagePIL, ToImageTensor, ToPILImage

from ._deprecated import Grayscale, RandomGrayscale, ToTensor # usort: skip
from ._deprecated import ToTensor # usort: skip
35 changes: 35 additions & 0 deletions torchvision/prototype/transforms/_color.py
@@ -11,6 +11,41 @@
from .utils import is_simple_tensor, query_chw


class Grayscale(Transform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)

def __init__(self, num_output_channels: int = 1):
super().__init__()
self.num_output_channels = num_output_channels

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)


class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)

def __init__(self, p: float = 0.1) -> None:
super().__init__(p=p)

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])


class ColorJitter(Transform):
def __init__(
self,
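Both transforms restrict `_transformed_types` to image-like inputs, so other datapoints in a sample pass through untouched, and `RandomGrayscale` queries the input's channel count via `query_chw` so its output keeps the same number of channels as the input. A hedged usage sketch:

```python
import torch
from torchvision.prototype import transforms

pipeline = transforms.Compose([
    transforms.RandomGrayscale(p=0.5),            # applied with probability p; keeps channel count
    transforms.Grayscale(num_output_channels=1),  # always collapses to one channel
])
out = pipeline(torch.rand(3, 32, 32))
assert out.shape == (1, 32, 32)
```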
82 changes: 1 addition & 81 deletions torchvision/prototype/transforms/_deprecated.py
@@ -1,17 +1,12 @@
import warnings
from typing import Any, Dict, List, Union
from typing import Any, Dict, Union

import numpy as np
import PIL.Image
import torch

from torchvision.prototype import datapoints
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from .utils import is_simple_tensor, query_chw


class ToTensor(Transform):
@@ -26,78 +21,3 @@ def __init__(self) -> None:

def _transform(self, inpt: Union[PIL.Image.Image, np.ndarray], params: Dict[str, Any]) -> torch.Tensor:
return _F.to_tensor(inpt)


# TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
class Grayscale(Transform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)

def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
f"The transform `Grayscale(num_output_channels={num_output_channels})` "
f"is deprecated and will be removed in a future release."
)
if num_output_channels == 1:
replacement_msg = (
"transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
)
else:
replacement_msg = (
"transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
")"
)
warnings.warn(f"{deprecation_msg} Instead, please use\n\n{replacement_msg}")

super().__init__()
self.num_output_channels = num_output_channels

def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output


class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (
datapoints.Image,
PIL.Image.Image,
is_simple_tensor,
datapoints.Video,
)

def __init__(self, p: float = 0.1) -> None:
warnings.warn(
"The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
"Instead, please use\n\n"
"transforms.RandomApply(\n"
" transforms.Compose(\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n"
" transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n"
" )\n"
" p=...,\n"
")"
)

super().__init__(p=p)

def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
num_input_channels, *_ = query_chw(flat_inputs)
return dict(num_input_channels=num_input_channels)

def _transform(
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
) -> Union[datapoints.ImageType, datapoints.VideoType]:
output = _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]
return output
5 changes: 4 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
@@ -71,6 +71,9 @@
posterize_image_pil,
posterize_image_tensor,
posterize_video,
rgb_to_grayscale,
rgb_to_grayscale_image_pil,
rgb_to_grayscale_image_tensor,
solarize,
solarize_image_pil,
solarize_image_tensor,
@@ -167,4 +170,4 @@
from ._temporal import uniform_temporal_subsample, uniform_temporal_subsample_video
from ._type_conversion import pil_to_tensor, to_image_pil, to_image_tensor, to_pil_image

from ._deprecated import get_image_size, rgb_to_grayscale, to_grayscale, to_tensor # usort: skip
from ._deprecated import get_image_size, to_grayscale, to_tensor # usort: skip
51 changes: 48 additions & 3 deletions torchvision/prototype/transforms/functional/_color.py
@@ -1,3 +1,5 @@
from typing import Union

import PIL.Image
import torch
from torch.nn.functional import conv2d
@@ -7,10 +9,53 @@

from torchvision.utils import _log_api_usage_once

from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
from ._meta import _num_value_bits, convert_dtype_image_tensor
from ._utils import is_simple_tensor


def _rgb_to_grayscale_image_tensor(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor:
if image.shape[-3] == 1:
return image.clone()

r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.unsqueeze(dim=-3)
if preserve_dtype:
l_img = l_img.to(image.dtype)
if num_output_channels == 3:
l_img = l_img.expand(image.shape)
return l_img


def rgb_to_grayscale_image_tensor(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
return _rgb_to_grayscale_image_tensor(image, num_output_channels=num_output_channels, preserve_dtype=True)


rgb_to_grayscale_image_pil = _FP.to_grayscale


def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
if not torch.jit.is_scripting():
_log_api_usage_once(rgb_to_grayscale)
if num_output_channels not in (1, 3):
raise ValueError(f"num_output_channels must be 1 or 3, got {num_output_channels}.")
if torch.jit.is_scripting() or is_simple_tensor(inpt):
return rgb_to_grayscale_image_tensor(inpt, num_output_channels=num_output_channels)
elif isinstance(inpt, datapoints._datapoint.Datapoint):
return inpt.to_grayscale(num_output_channels=num_output_channels)
elif isinstance(inpt, PIL.Image.Image):
return rgb_to_grayscale_image_pil(inpt, num_output_channels=num_output_channels)
else:
raise TypeError(
f"Input can either be a plain tensor, any TorchVision datapoint, or a PIL image, "
f"but got {type(inpt)} instead."
)


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
@@ -68,7 +113,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
if c == 1: # Match PIL behaviour
return image

grayscale_image = _rgb_to_gray(image, cast=False)
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not image.is_floating_point():
grayscale_image = grayscale_image.floor_()

@@ -110,7 +155,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
fp = image.is_floating_point()
if c == 3:
grayscale_image = _rgb_to_gray(image, cast=False)
grayscale_image = _rgb_to_grayscale_image_tensor(image, num_output_channels=1, preserve_dtype=False)
if not fp:
grayscale_image = grayscale_image.floor_()
else:
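The new kernel keeps the BT.601 luma weights (0.2989, 0.587, 0.114) that `_rgb_to_gray` used; the `preserve_dtype=False` path lets `adjust_saturation` and `adjust_contrast` keep the float intermediate exactly as before. A quick equivalence check against a hand-rolled version, for float inputs:

```python
import torch
from torchvision.prototype.transforms.functional import rgb_to_grayscale

def manual_grayscale(image: torch.Tensor) -> torch.Tensor:
    # BT.601 luma weights, matching the kernel above.
    r, g, b = image.unbind(dim=-3)
    return (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)

image = torch.rand(3, 5, 5)
torch.testing.assert_close(rgb_to_grayscale(image), manual_grayscale(image))
```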
29 changes: 0 additions & 29 deletions torchvision/prototype/transforms/functional/_deprecated.py
@@ -7,8 +7,6 @@
from torchvision.prototype import datapoints
from torchvision.transforms import functional as _F

from ._utils import is_simple_tensor


@torch.jit.unused
def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
@@ -24,33 +22,6 @@ def to_grayscale(inpt: PIL.Image.Image, num_output_channels: int = 1) -> PIL.Image.Image:
return _F.to_grayscale(inpt, num_output_channels=num_output_channels)


def rgb_to_grayscale(
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], num_output_channels: int = 1
) -> Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT]:
old_color_space = None # TODO: remove when un-deprecating
if not (torch.jit.is_scripting() or is_simple_tensor(inpt)) and isinstance(
inpt, (datapoints.Image, datapoints.Video)
):
inpt = inpt.as_subclass(torch.Tensor)

call = ", num_output_channels=3" if num_output_channels == 3 else ""
replacement = (
f"convert_color_space(..., color_space=datapoints.ColorSpace.GRAY"
f"{f', old_color_space=datapoints.ColorSpace.{old_color_space}' if old_color_space is not None else ''})"
)
if num_output_channels == 3:
replacement = (
f"convert_color_space({replacement}, color_space=datapoints.ColorSpace.RGB"
f"{f', old_color_space=datapoints.ColorSpace.GRAY' if old_color_space is not None else ''})"
)
warnings.warn(
f"The function `rgb_to_grayscale(...{call})` is deprecated in will be removed in a future release. "
f"Instead, please use `{replacement}`.",
)

return _F.rgb_to_grayscale(inpt, num_output_channels=num_output_channels)


@torch.jit.unused
def to_tensor(inpt: Any) -> torch.Tensor:
warnings.warn(
9 changes: 0 additions & 9 deletions torchvision/prototype/transforms/functional/_meta.py
@@ -225,15 +225,6 @@ def clamp_bounding_box(
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)


def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
if cast:
l_img = l_img.to(image.dtype)
l_img = l_img.unsqueeze(dim=-3)
return l_img


def _num_value_bits(dtype: torch.dtype) -> int:
if dtype == torch.uint8:
return 8