diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index 17878b0c698..742b344cf71 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -313,7 +313,8 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT: def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor: - if threshold > _FT._max_value(image.dtype): + bound = 1 if image.is_floating_point() else 255 + if threshold > bound: raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}") return torch.where(image >= threshold, invert_image_tensor(image), image) @@ -466,7 +467,7 @@ def invert_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.dtype == torch.uint8: return image.bitwise_not() else: - return _FT._max_value(image.dtype) - image # type: ignore[no-any-return] + return (1 if image.is_floating_point() else 255) - image # type: ignore[no-any-return] invert_image_pil = _FP.invert