Skip to content

Commit 7d19a64

Browse files
committed
Resize() rely on interpolate()'s native uint8 handling instead of converting to and from float.
1 parent d010e82 commit 7d19a64

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,10 @@ def resize_image_tensor(
183183
image = image.reshape(-1, num_channels, old_height, old_width)
184184

185185
dtype = image.dtype
186-
need_cast = dtype not in (torch.float32, torch.float64)
186+
acceptable_dtypes = [torch.float32, torch.float64]
187+
if interpolation.value in ["nearest", "bilinear"]:
188+
acceptable_dtypes.append(torch.uint8)
189+
need_cast = dtype not in acceptable_dtypes
187190
if need_cast:
188191
image = image.to(dtype=torch.float32)
189192

0 commit comments

Comments
 (0)