 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import get_dimensions_image_tensor
+from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
+
+
+def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
+    ratio = float(ratio)
+    fp = image1.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
+    return output if fp else output.to(image1.dtype)
+
+
+def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
+    if brightness_factor < 0:
+        raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
+
+    _FT._assert_channels(image, [1, 3])
+
+    fp = image.is_floating_point()
+    bound = 1.0 if fp else 255.0
+    output = image.mul(brightness_factor).clamp_(0, bound)
+    return output if fp else output.to(image.dtype)
+

-adjust_brightness_image_tensor = _FT.adjust_brightness
 adjust_brightness_image_pil = _FP.adjust_brightness


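For reference, the _blend helper added above is a weighted combination, ratio * image1 + (1 - ratio) * image2, clamped to the value range of the input dtype (1.0 for floating point, 255 for integer images). A minimal standalone sketch of the same idea in plain PyTorch (illustrative only, not an import from torchvision):

import torch

def blend_sketch(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
    # Weighted combination clamped to the dtype's value range, mirroring the
    # _blend kernel in the diff above.
    ratio = float(ratio)
    fp = image1.is_floating_point()
    bound = 1.0 if fp else 255.0
    out = image1.mul(ratio).add_(image2, alpha=1.0 - ratio).clamp_(0, bound)
    return out if fp else out.to(image1.dtype)

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
gray = img.float().mean(dim=-3, keepdim=True).expand_as(img)
print(blend_sketch(img, gray, ratio=1.5).dtype)  # torch.uint8, values clamped to [0, 255]
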
@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
         return adjust_brightness_image_pil(inpt, brightness_factor=brightness_factor)


-adjust_saturation_image_tensor = _FT.adjust_saturation
+def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
+    if saturation_factor < 0:
+        raise ValueError(f"saturation_factor ({saturation_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(image)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+
+    if c == 1:  # Match PIL behaviour
+        return image
+
+    return _blend(image, _rgb_to_gray(image), saturation_factor)
+
+
 adjust_saturation_image_pil = _FP.adjust_saturation


@@ -38,7 +71,19 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
         return adjust_saturation_image_pil(inpt, saturation_factor=saturation_factor)


-adjust_contrast_image_tensor = _FT.adjust_contrast
+def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
+    if contrast_factor < 0:
+        raise ValueError(f"contrast_factor ({contrast_factor}) is not non-negative.")
+
+    c = get_num_channels_image_tensor(image)
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    grayscale_image = _rgb_to_gray(image) if c == 3 else image
+    mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
+    return _blend(image, mean, contrast_factor)
+
+
 adjust_contrast_image_pil = _FP.adjust_contrast


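Both new kernels above funnel into the same blend: adjust_saturation_image_tensor blends each pixel with its own grayscale value, while adjust_contrast_image_tensor blends with the scalar mean of the grayscale image. A rough sketch of that relationship in plain PyTorch for a float RGB image (the Rec. 601 luma weights below are an assumption standing in for the private _rgb_to_gray helper):

import torch

def gray_sketch(image: torch.Tensor) -> torch.Tensor:
    # Assumed ITU-R BT.601 luma weights; stands in for torchvision's private _rgb_to_gray.
    r, g, b = image.unbind(dim=-3)
    return (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=-3)

image = torch.rand(3, 8, 8)  # float RGB image with values in [0, 1]
gray = gray_sketch(image)

# Saturation: per-pixel blend towards grayscale (factor 0.0 -> fully desaturated).
saturation_factor = 0.5
desaturated = (saturation_factor * image + (1 - saturation_factor) * gray).clamp(0, 1)

# Contrast: blend towards the scalar mean of the grayscale image (factor 0.0 -> flat gray).
contrast_factor = 2.0
mean = gray.mean(dim=(-3, -2, -1), keepdim=True)
punchier = (contrast_factor * image + (1 - contrast_factor) * mean).clamp(0, 1)
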
@@ -74,7 +119,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
     else:
         needs_unsquash = False

-    output = _FT._blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)
+    output = _blend(image, _FT._blurred_degenerate_image(image), sharpness_factor)

     if needs_unsquash:
         output = output.reshape(shape)
@@ -183,13 +228,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


-def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
-    # input img shape should be [N, H, W]
-    shape = img.shape
+def _equalize_image_tensor_vec(image: torch.Tensor) -> torch.Tensor:
+    # input image shape should be [N, H, W]
+    shape = image.shape
     # Compute image histogram:
-    flat_img = img.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
-    hist = flat_img.new_zeros(shape[0], 256)
-    hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
+    flat_image = image.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
+    hist = flat_image.new_zeros(shape[0], 256)
+    hist.scatter_add_(dim=1, index=flat_image, src=flat_image.new_ones(1).expand_as(flat_image))

     # Compute image cdf
     chist = hist.cumsum_(dim=1)
@@ -213,7 +258,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
     zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
     lut = torch.cat([zeros, lut[:, :-1]], dim=1)

-    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).reshape_as(img))
+    return torch.where((step == 0).unsqueeze(-1), image, lut.gather(dim=1, index=flat_image).reshape_as(image))


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
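The rename in _equalize_image_tensor_vec leaves the vectorised histogram trick intact: every flattened pixel value indexes a 256-bin row and scatter_add_ adds one to it, so a whole batch of histograms is built without a Python loop. A small sketch of just that step for a single-channel uint8 batch (illustrative, not the torchvision kernel):

import torch

batch = torch.randint(0, 256, (2, 16, 16), dtype=torch.uint8)  # [N, H, W]
flat = batch.flatten(start_dim=1).to(torch.long)               # -> [N, H * W]

# One 256-bin histogram per image: each pixel value selects a bin, scatter_add_ adds 1.
hist = flat.new_zeros(flat.shape[0], 256)
hist.scatter_add_(dim=1, index=flat, src=flat.new_ones(1).expand_as(flat))

# The cumulative histogram (CDF) is what the LUT in the kernel above is derived from.
cdf = hist.cumsum(dim=1)
assert cdf[:, -1].eq(16 * 16).all()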