From 9781b8321804d40306473f31dbbcaac6c7addb7f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 15:51:29 +0200 Subject: [PATCH 1/2] remove unneccesary checks from posterize_image_tensor --- torchvision/prototype/transforms/functional/_color.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 17878b0c698..5e69524c7c0 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -295,7 +295,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) -posterize_image_tensor = _FT.posterize +def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: + if image.dtype != torch.uint8: + raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}") + + # JIT-friendly for: ~(2 ** (8 - bits) - 1) + mask = -int(2 ** (8 - bits)) + return mask & image + + posterize_image_pil = _FP.posterize From 292b8bba558803a51e26c36758b72135b0837cf2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 19:14:57 +0200 Subject: [PATCH 2/2] fix JIT --- torchvision/prototype/transforms/functional/_color.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 5e69524c7c0..35e0abd833a 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -301,7 +301,7 @@ def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: # JIT-friendly for: ~(2 ** (8 - bits) - 1) mask = -int(2 ** (8 - bits)) - return mask & image + return image & mask posterize_image_pil = _FP.posterize