@@ -371,33 +371,34 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
 
 
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype != torch.uint8:
-        raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
-
-    num_channels, height, width = get_dimensions_image_tensor(image)
-    if num_channels not in (1, 3):
-        raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
-
     if image.numel() == 0:
         return image
 
+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32`, and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    # Since we need to convert in most cases anyway and, of the acceptable dtypes mentioned in 1., `torch.uint8` is by
+    # far the most common, we choose it as the base.
+    output_dtype = image.dtype
+    image = convert_dtype_image_tensor(image, torch.uint8)
+
+    # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
+    # corresponds to adding 1 to index 127 in the histogram.
     batch_shape = image.shape[:-2]
     flat_image = image.flatten(start_dim=-2).to(torch.long)
-
-    # The algorithm for histogram equalization is mirrored from PIL:
-    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
-
-    # Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
-    # images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
-    # the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
-    # in the histogram.
     hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
     hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
     cum_hist = hist.cumsum(dim=-1)
 
     # The simplest form of lookup-table (LUT) that also achieves histogram equalization is
     # `lut = cum_hist / flat_image.shape[-1] * 255`
     # However, PIL uses a more elaborate scheme:
+    # https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
     # `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
 
     # The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
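
The added comment in the hunk above describes how the per-image histogram is built by using each flattened pixel value as an index into a 256-bin tensor. A minimal standalone sketch of that indexing trick, with an assumed batch shape and illustrative variable names (not the actual kernel):

```python
import torch

# Illustrative only: build per-image histograms for a small batch of uint8 images with
# scatter_add_, mirroring the indexing trick described in the diff comment.
image = torch.randint(0, 256, (4, 3, 8, 8), dtype=torch.uint8)

batch_shape = image.shape[:-2]                            # (4, 3)
flat_image = image.flatten(start_dim=-2).to(torch.long)   # (4, 3, 64)

# Every pixel value acts as an index into the last dimension of `hist`; scatter_add_
# adds 1 at that index, i.e. it counts how often each of the 256 values occurs.
hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))

# Each histogram sums to the number of pixels per channel (8 * 8 = 64).
assert (hist.sum(dim=-1) == 64).all()
```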
@@ -415,7 +416,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     # easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
     # we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
     # pay the runtime cost for checking it every time.
-    no_equalization = step.eq(0).unsqueeze_(-1)
+    valid_equalization = step.ne(0).unsqueeze_(-1)
 
     # `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
     # computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
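
The comments in the hunks above reference PIL's LUT scheme and the constant-image (`step == 0`) edge case. Below is a minimal, non-batched sketch of that scheme for a single 2D uint8 image; the function name and the helpers `num_non_max_pixels` and `step` are assumptions chosen to mirror the comments, not the batched kernel itself:

```python
import torch

def equalize_single_uint8(image: torch.Tensor) -> torch.Tensor:
    flat = image.flatten().to(torch.long)
    hist = torch.bincount(flat, minlength=256)
    cum_hist = hist.cumsum(dim=0)

    # `step` is derived from the number of pixels that do not hold the largest
    # occurring value; a constant image therefore yields `step == 0`.
    num_non_max_pixels = flat.numel() - hist[hist.nonzero().max()]
    step = num_non_max_pixels // 255
    if step == 0:
        # Nothing to equalize: return the input unchanged, as PIL does.
        return image

    # `lut[k]` uses `cum_hist[k - 1]`, so compute `lut[1:]` and prepend `lut[0] == 0`.
    lut = (cum_hist[:-1] + step // 2) // step
    lut = torch.cat([lut.new_zeros(1), lut]).clamp_(0, 255).to(torch.uint8)
    return lut[flat].view_as(image)
```

For example, a constant input like `torch.full((8, 8), 3, dtype=torch.uint8)` hits the `step == 0` branch and is returned unchanged; the batched kernel cannot return early like this, which is why it masks with `valid_equalization` via `torch.where` instead.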
@@ -434,7 +435,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
     equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
 
-    return torch.where(no_equalization, image, equalized_image)
+    output = torch.where(valid_equalization, equalized_image, image)
+    return convert_dtype_image_tensor(output, output_dtype)
 
 
 equalize_image_pil = _FP.equalize
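
Taken together, the changes above mean the tensor kernel no longer rejects non-uint8 inputs: it converts to `torch.uint8`, equalizes, and converts back to the original dtype. A hedged usage sketch; the import path is an assumption based on the prototype namespace this file lives in:

```python
import torch
from torchvision.prototype.transforms.functional import equalize_image_tensor  # assumed import path

# A float image is now accepted: it is equalized via the uint8 path and returned
# in its original dtype and shape.
float_image = torch.rand(3, 32, 32)  # float32 in [0, 1]
out = equalize_image_tensor(float_image)
assert out.dtype == torch.float32 and out.shape == float_image.shape
```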