11import warnings
2- from typing import Any , Dict , List , Union
2+ from typing import Any , Dict , Union
33
44import numpy as np
55import PIL .Image
66import torch
77
8- from torchvision .prototype import datapoints
98from torchvision .prototype .transforms import Transform
109from torchvision .transforms import functional as _F
11- from typing_extensions import Literal
12-
13- from ._transform import _RandomApplyTransform
14- from .utils import is_simple_tensor , query_chw
1510
1611
1712class ToTensor (Transform ):
@@ -26,78 +21,3 @@ def __init__(self) -> None:
2621
2722 def _transform (self , inpt : Union [PIL .Image .Image , np .ndarray ], params : Dict [str , Any ]) -> torch .Tensor :
2823 return _F .to_tensor (inpt )
29-
30-
31- # TODO: in other PR (?) undeprecate those and make them use _rgb_to_gray?
32- class Grayscale (Transform ):
33- _transformed_types = (
34- datapoints .Image ,
35- PIL .Image .Image ,
36- is_simple_tensor ,
37- datapoints .Video ,
38- )
39-
40- def __init__ (self , num_output_channels : Literal [1 , 3 ] = 1 ) -> None :
41- deprecation_msg = (
42- f"The transform `Grayscale(num_output_channels={ num_output_channels } )` "
43- f"is deprecated and will be removed in a future release."
44- )
45- if num_output_channels == 1 :
46- replacement_msg = (
47- "transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY)"
48- )
49- else :
50- replacement_msg = (
51- "transforms.Compose(\n "
52- " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n "
53- " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n "
54- ")"
55- )
56- warnings .warn (f"{ deprecation_msg } Instead, please use\n \n { replacement_msg } " )
57-
58- super ().__init__ ()
59- self .num_output_channels = num_output_channels
60-
61- def _transform (
62- self , inpt : Union [datapoints .ImageType , datapoints .VideoType ], params : Dict [str , Any ]
63- ) -> Union [datapoints .ImageType , datapoints .VideoType ]:
64- output = _F .rgb_to_grayscale (inpt , num_output_channels = self .num_output_channels )
65- if isinstance (inpt , (datapoints .Image , datapoints .Video )):
66- output = inpt .wrap_like (inpt , output ) # type: ignore[arg-type]
67- return output
68-
69-
70- class RandomGrayscale (_RandomApplyTransform ):
71- _transformed_types = (
72- datapoints .Image ,
73- PIL .Image .Image ,
74- is_simple_tensor ,
75- datapoints .Video ,
76- )
77-
78- def __init__ (self , p : float = 0.1 ) -> None :
79- warnings .warn (
80- "The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
81- "Instead, please use\n \n "
82- "transforms.RandomApply(\n "
83- " transforms.Compose(\n "
84- " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.RGB, color_space=ColorSpace.GRAY),\n "
85- " transforms.ConvertImageColorSpace(old_color_space=ColorSpace.GRAY, color_space=ColorSpace.RGB),\n "
86- " )\n "
87- " p=...,\n "
88- ")"
89- )
90-
91- super ().__init__ (p = p )
92-
93- def _get_params (self , flat_inputs : List [Any ]) -> Dict [str , Any ]:
94- num_input_channels , * _ = query_chw (flat_inputs )
95- return dict (num_input_channels = num_input_channels )
96-
97- def _transform (
98- self , inpt : Union [datapoints .ImageType , datapoints .VideoType ], params : Dict [str , Any ]
99- ) -> Union [datapoints .ImageType , datapoints .VideoType ]:
100- output = _F .rgb_to_grayscale (inpt , num_output_channels = params ["num_input_channels" ])
101- if isinstance (inpt , (datapoints .Image , datapoints .Video )):
102- output = inpt .wrap_like (inpt , output ) # type: ignore[arg-type]
103- return output
0 commit comments