@@ -183,28 +183,37 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)


-def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
-    # TODO: we should expect bincount to always be faster than histc, but this
-    # isn't always the case. Once
-    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
-    # block and only use bincount.
-    if img_chan.is_cuda:
-        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
-    else:
-        hist = torch.bincount(img_chan.view(-1), minlength=256)
-
-    nonzero_hist = hist[hist != 0]
-    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
-    if step == 0:
-        return img_chan
-
-    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
-    # Doing inplace clamp and converting lut to uint8 improves perfs
-    lut.clamp_(0, 255)
-    lut = lut.to(torch.uint8)
-    lut = torch.nn.functional.pad(lut[:-1], [1, 0])
-
-    return lut[img_chan.to(torch.int64)]
+def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
+    # input img shape should be [N, H, W]
+    shape = img.shape
+    # Compute image histogram:
+    flat_img = img.flatten(start_dim=1).to(torch.long)  # -> [N, H * W]
+    hist = flat_img.new_zeros(shape[0], 256)
+    hist.scatter_add_(dim=1, index=flat_img, src=flat_img.new_ones(1).expand_as(flat_img))
+
+    # Compute image cdf
+    chist = hist.cumsum_(dim=1)
+    # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255
+    # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()
+    idx = chist.argmax(dim=1).sub_(1)
+    # If histogram is degenerate (hist of zero image), index is -1
+    neg_idx_mask = idx < 0
+    idx.clamp_(min=0)
+    step = chist.gather(dim=1, index=idx.unsqueeze(1))
+    step[neg_idx_mask] = 0
+    step.div_(255, rounding_mode="floor")
+
+    # Compute batched Look-up-table:
+    # Clamp the step to avoid an integer division by zero, which would raise an error
+    clamped_step = step.clamp(min=1)
+    chist.add_(torch.div(step, 2, rounding_mode="floor")).div_(clamped_step, rounding_mode="floor").clamp_(0, 255)
+    lut = chist.to(torch.uint8)  # [N, 256]
+
+    # Pad lut with zeros
+    zeros = lut.new_zeros((1, 1)).expand(shape[0], 1)
+    lut = torch.cat([zeros, lut[:, :-1]], dim=1)
+
+    return torch.where((step == 0).unsqueeze(-1), img, lut.gather(dim=1, index=flat_img).view_as(img))


 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
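The cumulative-histogram trick in the comment above ("nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax()") works because the cumulative sum first reaches its maximum at the last nonzero histogram bin, so the entry one position earlier equals the total count minus that last bin. A toy illustration, not part of this diff (any small integer histogram works):

import torch

# Illustrative sketch only, not code from the PR.
hist = torch.tensor([0, 3, 0, 5, 2, 0, 0], dtype=torch.long)
chist = hist.cumsum(0)                         # [0, 3, 3, 8, 10, 10, 10]
idx = chist.argmax() - 1                       # argmax picks the first maximum (index 4), so idx == 3
nonzero_hist = hist[hist != 0]                 # [3, 5, 2]
assert chist[idx] == nonzero_hist[:-1].sum()   # both equal 8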
@@ -217,10 +226,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:

     if image.numel() == 0:
         return image
-    elif image.ndim == 2:
-        return _scale_channel(image)
-    else:
-        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)
+
+    return _equalize_image_tensor_vec(image.view(-1, height, width)).view(image.shape)


 equalize_image_pil = _FP.equalize
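On the histogram side, the batched scatter_add_ in the new helper computes the same counts as a per-image torch.bincount. A quick sketch mirroring the added lines, illustrative only and assuming uint8 inputs of shape [N, H, W]:

import torch

# Illustrative sketch only, not code from the PR.
imgs = torch.randint(0, 256, (4, 8, 8), dtype=torch.uint8)
flat = imgs.flatten(start_dim=1).to(torch.long)  # [N, H * W]
hist = flat.new_zeros(flat.shape[0], 256)
hist.scatter_add_(dim=1, index=flat, src=flat.new_ones(1).expand_as(flat))

expected = torch.stack([torch.bincount(row, minlength=256) for row in flat])
assert torch.equal(hist, expected)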
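Finally, a rough way to sanity-check the refactor locally is to compare the vectorized path against the removed per-channel loop on random uint8 data. This is not a test from the PR; it assumes both _scale_channel and _equalize_image_tensor_vec are defined in the current scope exactly as in the diff above:

import torch

# Illustrative sketch only; assumes _scale_channel and _equalize_image_tensor_vec
# from the diff above are both available in the current scope.
imgs = torch.randint(0, 256, (4, 16, 16), dtype=torch.uint8)
vectorized = _equalize_image_tensor_vec(imgs)
looped = torch.stack([_scale_channel(chan) for chan in imgs])
assert torch.equal(vectorized, looped)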