11import  torch 
22from  torch .nn .functional  import  conv2d 
33from  torchvision .prototype  import  features 
4- from  torchvision .transforms  import  functional_pil  as  _FP , functional_tensor  as  _FT 
4+ from  torchvision .transforms  import  functional_pil  as  _FP 
5+ from  torchvision .transforms .functional_tensor  import  _max_value 
56
67from  ._meta  import  _num_value_bits , _rgb_to_gray , convert_dtype_image_tensor 
78
89
910def  _blend (image1 : torch .Tensor , image2 : torch .Tensor , ratio : float ) ->  torch .Tensor :
1011    ratio  =  float (ratio )
1112    fp  =  image1 .is_floating_point ()
12-     bound  =  _FT . _max_value (image1 .dtype )
13+     bound  =  _max_value (image1 .dtype )
1314    output  =  image1 .mul (ratio ).add_ (image2 , alpha = (1.0  -  ratio )).clamp_ (0 , bound )
1415    return  output  if  fp  else  output .to (image1 .dtype )
1516
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
1819    if  brightness_factor  <  0 :
1920        raise  ValueError (f"brightness_factor ({ brightness_factor }  )
2021
21-     _FT ._assert_channels (image , [1 , 3 ])
22+     c  =  image .shape [- 3 ]
23+     if  c  not  in 1 , 3 ]:
24+         raise  TypeError (f"Input image tensor permitted channel values are 1 or 3, but found { c }  )
2225
2326    fp  =  image .is_floating_point ()
24-     bound  =  _FT . _max_value (image .dtype )
27+     bound  =  _max_value (image .dtype )
2528    output  =  image .mul (brightness_factor ).clamp_ (0 , bound )
2629    return  output  if  fp  else  output .to (image .dtype )
2730
@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
4851
4952    c  =  image .shape [- 3 ]
5053    if  c  not  in 1 , 3 ]:
51-         raise  TypeError (f"Input image tensor permitted channel values are { [ 1 ,  3 ] } { c }  )
54+         raise  TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c }  )
5255
5356    if  c  ==  1 :  # Match PIL behaviour 
5457        return  image 
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
8285
8386    c  =  image .shape [- 3 ]
8487    if  c  not  in 1 , 3 ]:
85-         raise  TypeError (f"Input image tensor permitted channel values are { [ 1 ,  3 ] } { c }  )
88+         raise  TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c }  )
8689    fp  =  image .is_floating_point ()
8790    if  c  ==  3 :
8891        grayscale_image  =  _rgb_to_gray (image , cast = False )
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
121124    if  image .numel () ==  0  or  height  <=  2  or  width  <=  2 :
122125        return  image 
123126
124-     bound  =  _FT . _max_value (image .dtype )
127+     bound  =  _max_value (image .dtype )
125128    fp  =  image .is_floating_point ()
126129    shape  =  image .shape 
127130
@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
248251
249252    c  =  image .shape [- 3 ]
250253    if  c  not  in 1 , 3 ]:
251-         raise  TypeError (f"Input image tensor permitted channel values are { [ 1 ,  3 ] } { c }  )
254+         raise  TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c }  )
252255
253256    if  c  ==  1 :  # Match PIL behaviour 
254257        return  image 
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
350353
351354
352355def  solarize_image_tensor (image : torch .Tensor , threshold : float ) ->  torch .Tensor :
353-     if  threshold  >  _FT . _max_value (image .dtype ):
356+     if  threshold  >  _max_value (image .dtype ):
354357        raise  TypeError (f"Threshold should be less or equal the maximum value of the dtype, but got { threshold }  )
355358
356359    return  torch .where (image  >=  threshold , invert_image_tensor (image ), image )
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
375378def  autocontrast_image_tensor (image : torch .Tensor ) ->  torch .Tensor :
376379    c  =  image .shape [- 3 ]
377380    if  c  not  in 1 , 3 ]:
378-         raise  TypeError (f"Input image tensor permitted channel values are { [ 1 ,  3 ] } { c }  )
381+         raise  TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c }  )
379382
380383    if  image .numel () ==  0 :
381384        # exit earlier on empty images 
382385        return  image 
383386
384-     bound  =  _FT . _max_value (image .dtype )
387+     bound  =  _max_value (image .dtype )
385388    fp  =  image .is_floating_point ()
386389    float_image  =  image  if  fp  else  image .to (torch .float32 )
387390
0 commit comments