From d5da6e437e52bdf9aa9d16071df4d1e1d6f71e3a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Fri, 25 Nov 2022 10:01:34 +0100 Subject: [PATCH] use bitshifts for int to int in convert_dtype --- torchvision/prototype/transforms/functional/_meta.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/torchvision/prototype/transforms/functional/_meta.py b/torchvision/prototype/transforms/functional/_meta.py index 0d2bd7bf10b..a2da77b1267 100644 --- a/torchvision/prototype/transforms/functional/_meta.py +++ b/torchvision/prototype/transforms/functional/_meta.py @@ -379,15 +379,7 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f if num_value_bits_input > num_value_bits_output: return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) else: - # The bitshift kernel is not vectorized - # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322 - # This results in the multiplication actually being faster. - # TODO: If the bitshift kernel is optimized in core, replace the computation below with - # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)` - max_value_input = float(_FT._max_value(dtype)) - max_value_output = float(_FT._max_value(image.dtype)) - factor = int((max_value_input + 1) // (max_value_output + 1)) - return image.to(dtype).mul_(factor) + return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) # We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is