From 6134fd8615b16391757f8d400c473ef62f04c42f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 13 Jun 2023 04:06:52 -0700 Subject: [PATCH] Fix: don't call round() on float images for ResizeV2 --- test/test_transforms_v2_functional.py | 10 ++++++++++ torchvision/transforms/v2/functional/_geometry.py | 4 +++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2_functional.py b/test/test_transforms_v2_functional.py index ed861fee97e..60a06f571b1 100644 --- a/test/test_transforms_v2_functional.py +++ b/test/test_transforms_v2_functional.py @@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg assert expected_stride == output_stride, error_msg_fn("") else: assert False, error_msg_fn("") + + +def test_resize_float16_no_rounding(): + # Make sure Resize() doesn't round float16 images + # Non-regression test for https://github.com/pytorch/vision/issues/7667 + + img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16) + out = F.resize(img, size=(10, 10)) + assert out.dtype == torch.float16 + assert (out.round() - out).sum() > 0 diff --git a/torchvision/transforms/v2/functional/_geometry.py b/torchvision/transforms/v2/functional/_geometry.py index b9124f280bd..ced7ff0b28b 100644 --- a/torchvision/transforms/v2/functional/_geometry.py +++ b/torchvision/transforms/v2/functional/_geometry.py @@ -228,7 +228,9 @@ def resize_image_tensor( if need_cast: if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: image = image.clamp_(min=0, max=255) - image = image.round_().to(dtype=dtype) + if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64): + image = image.round_() + image = image.to(dtype=dtype) return image.reshape(shape[:-3] + (num_channels, new_height, new_width))