From 7d19a645fd738e0f2444cb5722ff61c4fa552ca6 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 16 Feb 2023 00:52:02 +0100 Subject: [PATCH 1/2] Resize() rely on interpolate()'s native uint8 handling instead of converting to and from float. --- torchvision/prototype/transforms/functional/_geometry.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 22731bb157f..07b33d78a15 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -183,7 +183,10 @@ def resize_image_tensor( image = image.reshape(-1, num_channels, old_height, old_width) dtype = image.dtype - need_cast = dtype not in (torch.float32, torch.float64) + acceptable_dtypes = [torch.float32, torch.float64] + if interpolation.value in ["nearest", "bilinear"]: + acceptable_dtypes.append(torch.uint8) + need_cast = dtype not in acceptable_dtypes if need_cast: image = image.to(dtype=torch.float32) From c4b0f96fc217d4efd543c751bf6681f76ee9e98c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 16 Feb 2023 12:40:05 +0100 Subject: [PATCH 2/2] WIP, Added resize benchmark on arm64 m1 macosx --- .github/workflows/test-m1.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/test-m1.yml b/.github/workflows/test-m1.yml index c03fa9f76e4..98112740696 100644 --- a/.github/workflows/test-m1.yml +++ b/.github/workflows/test-m1.yml @@ -46,5 +46,12 @@ jobs: run: | . ~/miniconda3/etc/profile.d/conda.sh set -ex + + # Run resize benchmark + echo "--- Run resize benchmark ---" + wget https://gist.githubusercontent.com/vfdev-5/a2e30ed50b5996807c9b09d5d33d8bc2/raw/f6691d472e1729e39448d6753bd350a8433d0f08/check_resize_uint8.py + conda run -p ${ENV_NAME} python3 -u check_resize_uint8.py + echo "--- END Run resize benchmark ---" + conda run -p ${ENV_NAME} --no-capture-output python3 -u -mpytest -v --tb=long --durations 20 conda env remove -p ${ENV_NAME}