From f1a248fd1a062aadc88d2098bf2b2a934eb2d9e7 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 12:34:09 +0200 Subject: [PATCH 1/5] improve performance of invert_image_tensor --- torchvision/prototype/transforms/functional/_color.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 2c268fa4085..52f699cbf22 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -435,7 +435,14 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return equalize_image_pil(inpt) -invert_image_tensor = _FT.invert +def invert_image_tensor(image: torch.Tensor): + num_channels, height, width = get_dimensions_image_tensor(image) + if num_channels not in (1, 3): + raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + + return 1.0 - image if image.is_floating_point() else image.bitwise_not() + + invert_image_pil = _FP.invert From 0cc6cdd42f3b4925cf540981e1be968cf0978fca Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 14:46:59 +0200 Subject: [PATCH 2/5] cleanup --- torchvision/prototype/transforms/functional/_color.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 52f699cbf22..4206575e97f 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -436,11 +436,12 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def invert_image_tensor(image: torch.Tensor): - num_channels, height, width = get_dimensions_image_tensor(image) - if num_channels not in (1, 3): - raise TypeError(f"Input image tensor can have 1 or 3 channels, but found {num_channels}") + _FT._assert_image_tensor(image) - return 1.0 - image if image.is_floating_point() else image.bitwise_not() + if image.dtype == torch.uint8: + return image.bitwise_not() + else: + return _FT._max_value(image.dtype) - image invert_image_pil = _FP.invert From 34c075c71ecaa9561d49eef2c65a02e7eacf3bdb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 14:49:59 +0200 Subject: [PATCH 3/5] lint --- torchvision/prototype/transforms/functional/_color.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 4206575e97f..a1b219d78f2 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -435,13 +435,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: return equalize_image_pil(inpt) -def invert_image_tensor(image: torch.Tensor): +def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: _FT._assert_image_tensor(image) if image.dtype == torch.uint8: return image.bitwise_not() else: - return _FT._max_value(image.dtype) - image + return _FT._max_value(image.dtype) - image # type: ignore[no-any-return] invert_image_pil = _FP.invert From 7adfeb5d135cf0d6ba145a91fcfc09684ec1b448 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 14:51:28 +0200 Subject: [PATCH 4/5] more cleanup --- torchvision/prototype/transforms/functional/_color.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index a1b219d78f2..bcd5413dff2 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -436,8 +436,6 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT: def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: - _FT._assert_image_tensor(image) - if image.dtype == torch.uint8: return image.bitwise_not() else: From f8f68d787208126295962b0efc7983b9345fd617 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 24 Oct 2022 14:56:31 +0200 Subject: [PATCH 5/5] use new invert in solarize --- torchvision/prototype/transforms/functional/_color.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index f742a8b9cfc..17878b0c698 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -312,7 +312,13 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: return posterize_image_pil(inpt, bits=bits) -solarize_image_tensor = _FT.solarize +def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: + if threshold > _FT._max_value(image.dtype): + raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") + + return torch.where(image >= threshold, invert_image_tensor(image), image) + + solarize_image_pil = _FP.solarize