@@ -1051,38 +1051,35 @@ def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_firs
10511051        return  value 
10521052
10531053    @staticmethod  
1054-     @torch .jit .unused  
1055-     def  get_params (brightness , contrast , saturation , hue ):
1056-         """Get a randomized transform to be applied on image. 
1054+     def  get_params (brightness : Optional [List [float ]],
1055+                    contrast : Optional [List [float ]],
1056+                    saturation : Optional [List [float ]],
1057+                    hue : Optional [List [float ]]
1058+                    ) ->  Tuple [Tensor , Optional [float ], Optional [float ], Optional [float ], Optional [float ]]:
1059+         """Get the parameters for the randomized transform to be applied on image. 
10571060
1058-         Arguments are same as that of __init__. 
1061+         Args: 
1062+             brightness (tuple of float (min, max), optional): The range from which the brightness_factor is chosen 
1063+                 uniformly. Pass None to turn off the transformation. 
1064+             contrast (tuple of float (min, max), optional): The range from which the contrast_factor is chosen 
1065+                 uniformly. Pass None to turn off the transformation. 
1066+             saturation (tuple of float (min, max), optional): The range from which the saturation_factor is chosen 
1067+                 uniformly. Pass None to turn off the transformation. 
1068+             hue (tuple of float (min, max), optional): The range from which the hue_factor is chosen uniformly. 
1069+                 Pass None to turn off the transformation. 
10591070
10601071        Returns: 
1061-             Transform which randomly adjusts brightness, contrast and  
1062-             saturation in a  random order. 
1072+             tuple: The parameters used to apply the randomized transform  
1073+             along with their  random order. 
10631074        """ 
1064-         transforms  =  []
1065- 
1066-         if  brightness  is  not None :
1067-             brightness_factor  =  random .uniform (brightness [0 ], brightness [1 ])
1068-             transforms .append (Lambda (lambda  img : F .adjust_brightness (img , brightness_factor )))
1069- 
1070-         if  contrast  is  not None :
1071-             contrast_factor  =  random .uniform (contrast [0 ], contrast [1 ])
1072-             transforms .append (Lambda (lambda  img : F .adjust_contrast (img , contrast_factor )))
1073- 
1074-         if  saturation  is  not None :
1075-             saturation_factor  =  random .uniform (saturation [0 ], saturation [1 ])
1076-             transforms .append (Lambda (lambda  img : F .adjust_saturation (img , saturation_factor )))
1077- 
1078-         if  hue  is  not None :
1079-             hue_factor  =  random .uniform (hue [0 ], hue [1 ])
1080-             transforms .append (Lambda (lambda  img : F .adjust_hue (img , hue_factor )))
1075+         fn_idx  =  torch .randperm (4 )
10811076
1082-         random .shuffle (transforms )
1083-         transform  =  Compose (transforms )
1077+         b  =  None  if  brightness  is  None  else  float (torch .empty (1 ).uniform_ (brightness [0 ], brightness [1 ]))
1078+         c  =  None  if  contrast  is  None  else  float (torch .empty (1 ).uniform_ (contrast [0 ], contrast [1 ]))
1079+         s  =  None  if  saturation  is  None  else  float (torch .empty (1 ).uniform_ (saturation [0 ], saturation [1 ]))
1080+         h  =  None  if  hue  is  None  else  float (torch .empty (1 ).uniform_ (hue [0 ], hue [1 ]))
10841081
1085-         return  transform 
1082+         return  fn_idx ,  b ,  c ,  s ,  h 
10861083
10871084    def  forward (self , img ):
10881085        """ 
@@ -1092,26 +1089,17 @@ def forward(self, img):
10921089        Returns: 
10931090            PIL Image or Tensor: Color jittered image. 
10941091        """ 
1095-         fn_idx  =  torch .randperm (4 )
1092+         fn_idx , brightness_factor , contrast_factor , saturation_factor , hue_factor  =  \
1093+             self .get_params (self .brightness , self .contrast , self .saturation , self .hue )
1094+ 
10961095        for  fn_id  in  fn_idx :
1097-             if  fn_id  ==  0  and  self .brightness  is  not None :
1098-                 brightness  =  self .brightness 
1099-                 brightness_factor  =  torch .tensor (1.0 ).uniform_ (brightness [0 ], brightness [1 ]).item ()
1096+             if  fn_id  ==  0  and  brightness_factor  is  not None :
11001097                img  =  F .adjust_brightness (img , brightness_factor )
1101- 
1102-             if  fn_id  ==  1  and  self .contrast  is  not None :
1103-                 contrast  =  self .contrast 
1104-                 contrast_factor  =  torch .tensor (1.0 ).uniform_ (contrast [0 ], contrast [1 ]).item ()
1098+             elif  fn_id  ==  1  and  contrast_factor  is  not None :
11051099                img  =  F .adjust_contrast (img , contrast_factor )
1106- 
1107-             if  fn_id  ==  2  and  self .saturation  is  not None :
1108-                 saturation  =  self .saturation 
1109-                 saturation_factor  =  torch .tensor (1.0 ).uniform_ (saturation [0 ], saturation [1 ]).item ()
1100+             elif  fn_id  ==  2  and  saturation_factor  is  not None :
11101101                img  =  F .adjust_saturation (img , saturation_factor )
1111- 
1112-             if  fn_id  ==  3  and  self .hue  is  not None :
1113-                 hue  =  self .hue 
1114-                 hue_factor  =  torch .tensor (1.0 ).uniform_ (hue [0 ], hue [1 ]).item ()
1102+             elif  fn_id  ==  3  and  hue_factor  is  not None :
11151103                img  =  F .adjust_hue (img , hue_factor )
11161104
11171105        return  img 
0 commit comments