Closed
import torch
from torchvision.transforms import Resize
from torchvision.transforms.v2 import Resize as ResizeV2
img = torch.randint(0, 256, size=(1, 3, 512, 521), dtype=torch.float16)
out = Resize((10, 10))(img)
outV2 = ResizeV2((10, 10))(img)
print(out) # Not rounded, OK
print(outV2) # all rounded, bad

This is because of these lines, which round regardless of the original dtype:
vision/torchvision/transforms/v2/functional/_geometry.py
Lines 228 to 231 in 906c2e9
if need_cast:
    if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
        image = image.clamp_(min=0, max=255)
    image = image.round_().to(dtype=dtype)
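
To make the effect of that last line concrete, here is a minimal sketch in plain PyTorch. It assumes the interpolated result is held in float32 before being cast back to the input dtype; the tensor values are made up for illustration and are not taken from the report.

```python
import torch

# Stand-in for an interpolated result with fractional values (made-up data).
interpolated = torch.tensor([0.25, 1.25, 2.75], dtype=torch.float32)
dtype = torch.float16  # dtype of the original input, as in the repro above

# Same pattern as the v2 snippet: round unconditionally, then cast back.
rounded = interpolated.clone().round_().to(dtype=dtype)

# Casting without rounding keeps the fractional part.
cast_only = interpolated.to(dtype=dtype)

print(rounded)    # fractional part is gone
print(cast_only)  # fractional part is preserved
```

For a float16 input, the second form is presumably what a user expects, since float16 can represent these values without rounding.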
We should instead round only for integer dtypes, as in the V1 version:
vision/torchvision/transforms/_functional_tensor.py
Lines 537 to 539 in 906c2e9
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
    # it is better to round before cast
    img = torch.round(img)
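
One possible shape for the fix, sketched as a standalone function. The name `cast_back` is hypothetical and is not part of torchvision. The sketch uses `dtype.is_floating_point` rather than the explicit tuple of integer dtypes from V1, and it leaves out the bicubic `uint8` clamp from the v2 snippet for brevity.

```python
import torch


def cast_back(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: cast an interpolated image back to `dtype`,
    rounding only when the target dtype is not floating point."""
    if not dtype.is_floating_point:
        # Rounding before an integer cast avoids truncation toward zero.
        image = image.round()
    return image.to(dtype=dtype)


values = torch.tensor([0.25, 2.75], dtype=torch.float32)
print(cast_back(values, torch.float16))  # fractional values kept
print(cast_back(values, torch.uint8))    # rounded, then cast
```

With this guard, float16 inputs would keep their fractional values while integer inputs would still be rounded before the cast, matching the V1 behavior described above.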
antoinebrl