Skip to content

Commit 3591964

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] extend equalize to all integer and floating dtypes (#6851)
Summary: * extend equalize to all integer and floating dtypes * address nits Reviewed By: datumbox Differential Revision: D40851020 fbshipit-source-id: 63b36f02e630b9c230431527b359876fedc52f3e
1 parent f2e76af commit 3591964

File tree

2 files changed

+50
-35
lines changed

2 files changed

+50
-35
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ def sample_inputs_gaussian_blur_video():
13221322

13231323
def sample_inputs_equalize_image_tensor():
13241324
for image_loader in make_image_loaders(
1325-
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
1325+
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
13261326
):
13271327
yield ArgsKwargs(image_loader)
13281328

@@ -1331,27 +1331,41 @@ def reference_inputs_equalize_image_tensor():
13311331
# We are not using `make_image_loaders` here since that uniformly samples the values over the whole value range.
13321332
# Since the whole point of this kernel is to transform an arbitrary distribution of values into a uniform one,
13331333
# the information gain is low if we already provide something really close to the expected value.
1334+
def make_uniform_band_image(shape, dtype, device, *, low_factor, high_factor):
1335+
if dtype.is_floating_point:
1336+
low = low_factor
1337+
high = high_factor
1338+
else:
1339+
max_value = torch.iinfo(dtype).max
1340+
low = int(low_factor * max_value)
1341+
high = int(high_factor * max_value)
1342+
return torch.testing.make_tensor(shape, dtype=dtype, device=device, low=low, high=high)
1343+
1344+
def make_beta_distributed_image(shape, dtype, device, *, alpha, beta):
1345+
image = torch.distributions.Beta(alpha, beta).sample(shape)
1346+
if not dtype.is_floating_point:
1347+
image.mul_(torch.iinfo(dtype).max).round_()
1348+
return image.to(dtype=dtype, device=device)
1349+
13341350
spatial_size = (256, 256)
1335-
for fn, color_space in itertools.product(
1351+
for dtype, color_space, fn in itertools.product(
1352+
[torch.uint8, torch.float32],
1353+
[features.ColorSpace.GRAY, features.ColorSpace.RGB],
13361354
[
1355+
lambda shape, dtype, device: torch.zeros(shape, dtype=dtype, device=device),
1356+
lambda shape, dtype, device: torch.full(
1357+
shape, 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max, dtype=dtype, device=device
1358+
),
13371359
*[
1338-
lambda shape, dtype, device, low=low, high=high: torch.randint(
1339-
low, high, shape, dtype=dtype, device=device
1340-
)
1341-
for low, high in [
1342-
(0, 1),
1343-
(255, 256),
1344-
(0, 64),
1345-
(64, 192),
1346-
(192, 256),
1360+
functools.partial(make_uniform_band_image, low_factor=low_factor, high_factor=high_factor)
1361+
for low_factor, high_factor in [
1362+
(0.0, 0.25),
1363+
(0.25, 0.75),
1364+
(0.75, 1.0),
13471365
]
13481366
],
13491367
*[
1350-
lambda shape, dtype, device, alpha=alpha, beta=beta: torch.distributions.Beta(alpha, beta)
1351-
.sample(shape)
1352-
.mul_(255)
1353-
.round_()
1354-
.to(dtype=dtype, device=device)
1368+
functools.partial(make_beta_distributed_image, alpha=alpha, beta=beta)
13551369
for alpha, beta in [
13561370
(0.5, 0.5),
13571371
(2, 2),
@@ -1360,10 +1374,9 @@ def reference_inputs_equalize_image_tensor():
13601374
]
13611375
],
13621376
],
1363-
[features.ColorSpace.GRAY, features.ColorSpace.RGB],
13641377
):
13651378
image_loader = ImageLoader(
1366-
fn, shape=(get_num_channels(color_space), *spatial_size), dtype=torch.uint8, color_space=color_space
1379+
fn, shape=(get_num_channels(color_space), *spatial_size), dtype=dtype, color_space=color_space
13671380
)
13681381
yield ArgsKwargs(image_loader)
13691382

torchvision/prototype/transforms/functional/_color.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -371,33 +371,34 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
371371

372372

373373
def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
374-
if image.dtype != torch.uint8:
375-
raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
376-
377-
num_channels, height, width = get_dimensions_image_tensor(image)
378-
if num_channels not in (1, 3):
379-
raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}")
380-
381374
if image.numel() == 0:
382375
return image
383376

377+
# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
378+
# would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
379+
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
380+
# unfeasible for `torch.int64`.
381+
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
382+
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
383+
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
384+
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
385+
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
386+
# by far the most common, we choose it as base.
387+
output_dtype = image.dtype
388+
image = convert_dtype_image_tensor(image, torch.uint8)
389+
390+
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
391+
# corresponds to adding 1 to index 127 in the histogram.
384392
batch_shape = image.shape[:-2]
385393
flat_image = image.flatten(start_dim=-2).to(torch.long)
386-
387-
# The algorithm for histogram equalization is mirrored from PIL:
388-
# https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
389-
390-
# Although PyTorch has builtin functionality for histograms, it doesn't support batches. Since we deal with uint8
391-
# images here and thus the values are already binned, the computation is trivial. The histogram is computed by using
392-
# the flattened image as index. For example, a pixel value of 127 in the image corresponds to adding 1 to index 127
393-
# in the histogram.
394394
hist = flat_image.new_zeros(batch_shape + (256,), dtype=torch.int32)
395395
hist.scatter_add_(dim=-1, index=flat_image, src=hist.new_ones(1).expand_as(flat_image))
396396
cum_hist = hist.cumsum(dim=-1)
397397

398398
# The simplest form of lookup-table (LUT) that also achieves histogram equalization is
399399
# `lut = cum_hist / flat_image.shape[-1] * 255`
400400
# However, PIL uses a more elaborate scheme:
401+
# https://github.com/python-pillow/Pillow/blob/eb59cb61d5239ee69cbbf12709a0c6fd7314e6d7/src/PIL/ImageOps.py#L368-L385
401402
# `lut = ((cum_hist + num_non_max_pixels // (2 * 255)) // num_non_max_pixels) * 255`
402403

403404
# The last non-zero element in the histogram is the first element in the cumulative histogram with the maximum
@@ -415,7 +416,7 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
415416
# easy due to our support for batched images. We can only return early if `(step == 0).all()` holds. If it doesn't,
416417
# we have to go through the computation below anyway. Since `step == 0` is an edge case anyway, it makes no sense to
417418
# pay the runtime cost for checking it every time.
418-
no_equalization = step.eq(0).unsqueeze_(-1)
419+
valid_equalization = step.ne(0).unsqueeze_(-1)
419420

420421
# `lut[k]` is computed with `cum_hist[k-1]` with `lut[0] == (step // 2) // step == 0`. Thus, we perform the
421422
# computation only for `lut[1:]` with `cum_hist[:-1]` and add `lut[0] == 0` afterwards.
@@ -434,7 +435,8 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
434435
lut = torch.cat([lut.new_zeros(1).expand(batch_shape + (1,)), lut], dim=-1)
435436
equalized_image = lut.gather(dim=-1, index=flat_image).view_as(image)
436437

437-
return torch.where(no_equalization, image, equalized_image)
438+
output = torch.where(valid_equalization, equalized_image, image)
439+
return convert_dtype_image_tensor(output, output_dtype)
438440

439441

440442
equalize_image_pil = _FP.equalize

0 commit comments

Comments
 (0)