@@ -228,39 +228,6 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
228228        return  autocontrast_image_pil (inpt )
229229
230230
231- def  _equalize_image_tensor_vec (image : torch .Tensor ) ->  torch .Tensor :
232-     # input image shape should be [N, H, W] 
233-     shape  =  image .shape 
234-     # Compute image histogram: 
235-     flat_image  =  image .flatten (start_dim = 1 ).to (torch .long )  # -> [N, H * W] 
236-     hist  =  flat_image .new_zeros (shape [0 ], 256 )
237-     hist .scatter_add_ (dim = 1 , index = flat_image , src = flat_image .new_ones (1 ).expand_as (flat_image ))
238- 
239-     # Compute image cdf 
240-     chist  =  hist .cumsum_ (dim = 1 )
241-     # Compute steps, where step per channel is nonzero_hist[:-1].sum() // 255 
242-     # Trick: nonzero_hist[:-1].sum() == chist[idx - 1], where idx = chist.argmax() 
243-     idx  =  chist .argmax (dim = 1 ).sub_ (1 )
244-     # If histogram is degenerate (hist of zero image), index is -1 
245-     neg_idx_mask  =  idx  <  0 
246-     idx .clamp_ (min = 0 )
247-     step  =  chist .gather (dim = 1 , index = idx .unsqueeze (1 ))
248-     step [neg_idx_mask ] =  0 
249-     step .div_ (255 , rounding_mode = "floor" )
250- 
251-     # Compute batched Look-up-table: 
252-     # Necessary to avoid an integer division by zero, which raises 
253-     clamped_step  =  step .clamp (min = 1 )
254-     chist .add_ (torch .div (step , 2 , rounding_mode = "floor" )).div_ (clamped_step , rounding_mode = "floor" ).clamp_ (0 , 255 )
255-     lut  =  chist .to (torch .uint8 )  # [N, 256] 
256- 
257-     # Pad lut with zeros 
258-     zeros  =  lut .new_zeros ((1 , 1 )).expand (shape [0 ], 1 )
259-     lut  =  torch .cat ([zeros , lut [:, :- 1 ]], dim = 1 )
260- 
261-     return  torch .where ((step  ==  0 ).unsqueeze (- 1 ), image , lut .gather (dim = 1 , index = flat_image ).reshape_as (image ))
262- 
263- 
264231def  equalize_image_tensor (image : torch .Tensor ) ->  torch .Tensor :
265232    if  image .dtype  !=  torch .uint8 :
266233        raise  TypeError (f"Only torch.uint8 image tensors are supported, but found { image .dtype }  )
@@ -272,7 +239,60 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
272239    if  image .numel () ==  0 :
273240        return  image 
274241
275-     return  _equalize_image_tensor_vec (image .reshape (- 1 , height , width )).reshape (image .shape )
242+     batch_shape  =  image .shape [:- 2 ]
243+     flat_image  =  image .flatten (start_dim = - 2 ).to (torch .long )
244+ 
245+     # The algorithm for histogram equalization is mirrored from PIL: 
246+     # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385 
247+ 
248+     # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8 
249+     # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using 
250+     # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127 
251+     # in the histogram. 
252+     hist  =  flat_image .new_zeros (batch_shape  +  (256 ,), dtype = torch .int32 )
253+     hist .scatter_add_ (dim = - 1 , index = flat_image , src = hist .new_ones (1 ).expand_as (flat_image ))
254+     cum_hist  =  hist .cumsum (dim = - 1 )
255+ 
256+     # The simplest form of lookup-table (LUT) that also achieves histogram equalization is 
257+     # `lut = cum_hist / flat_image.shape[-1] * 255` 
258+     # However, PIL uses a more elaborate scheme: 
259+     # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255` 
260+ 
261+     # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum 
262+     # value. Thus, the "max" in `num_non_max_pixels` does not refer to 255 as the maximum value of uint8 images, but 
263+     # rather the maximum value in the image, which might be or not be 255. 
264+     index  =  cum_hist .argmax (dim = - 1 )
265+     num_non_max_pixels  =  flat_image .shape [- 1 ] -  hist .gather (dim = - 1 , index = index .unsqueeze_ (- 1 ))
266+ 
267+     # This is performance optimization that saves us one multiplication later. With this, the LUT computation simplifies 
268+     # to `lut = (cum_hist + step // 2) // step` and thus saving the final multiplication by 255 while keeping the 
269+     # division count the same. PIL uses the variable name `step` for this, so we keep that for easier comparison. 
270+     step  =  num_non_max_pixels .div_ (255 , rounding_mode = "floor" )
271+ 
272+     # Although it looks like we could return early if we find `step == 0` like PIL does, that is unfortunately not as 
273+     # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't, 
274+     # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to 
275+     # pay the runtime cost for checking it every time. 
276+     no_equalization  =  step .eq (0 ).unsqueeze_ (- 1 )
277+ 
278+     # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the 
279+     # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards. 
280+     cum_hist  =  cum_hist [..., :- 1 ]
281+     (
282+         cum_hist .add_ (step  //  2 )
283+         # We need the `clamp_`(min=1) call here to avoid zero division since they fail for integer dtypes. This has no 
284+         # effect on the returned result of this kernel since images inside the batch with `step == 0` are returned as is 
285+         # instead of equalized version. 
286+         .div_ (step .clamp_ (min = 1 ), rounding_mode = "floor" )
287+         # We need the `clamp_` call here since PILs LUT computation scheme can produce values outside the valid value 
288+         # range of uint8 images 
289+         .clamp_ (0 , 255 )
290+     )
291+     lut  =  cum_hist .to (torch .uint8 )
292+     lut  =  torch .cat ([lut .new_zeros (1 ).expand (batch_shape  +  (1 ,)), lut ], dim = - 1 )
293+     equalized_image  =  lut .gather (dim = - 1 , index = flat_image ).view_as (image )
294+ 
295+     return  torch .where (no_equalization , image , equalized_image )
276296
277297
278298equalize_image_pil  =  _FP .equalize 
0 commit comments