ResizeV2 is calling round() on float16 #7667

@NicolasHug

Description

import torch
from torchvision.transforms import Resize
from torchvision.transforms.v2 import Resize as ResizeV2

img = torch.randint(0, 256, size=(1, 3, 512, 521), dtype=torch.float16)
out = Resize((10, 10))(img)
outV2 = ResizeV2((10, 10))(img)
print(out)  # Not rounded, OK
print(outV2)  # all rounded, bad
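Rather than eyeballing the printed tensors, the rounding can be checked programmatically. A minimal sketch, assuming a hypothetical helper `is_integer_valued` (not part of torchvision) that upcasts to float32 before comparing:

```python
import torch


def is_integer_valued(t: torch.Tensor) -> bool:
    """Hypothetical helper: True if every element of `t` is a whole number."""
    # Upcast so the comparison does not depend on float16 kernel support.
    t32 = t.to(torch.float32)
    return torch.equal(t32, t32.round())
```

With the snippet above, `is_integer_valued(outV2)` returning True while `is_integer_valued(out)` returns False would be consistent with the behaviour described here.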

This is because of these lines, which round regardless of the original dtype:

if need_cast:
    if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
        image = image.clamp_(min=0, max=255)
    image = image.round_().to(dtype=dtype)

We should instead only round for integer dtypes, as in the V1 version:

if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
    # it is better to round before cast
    img = torch.round(img)
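One possible shape for the fix, as a sketch only: `cast_back` is a hypothetical standalone helper, not the actual torchvision code, and it assumes the interpolated result is a float tensor that has to be cast back to the input dtype. It uses `dtype.is_floating_point` instead of the explicit list of integer dtypes from V1.

```python
import torch


def cast_back(image: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical helper: cast an interpolated tensor back to `dtype`,
    rounding only when the target dtype is not floating point."""
    if not dtype.is_floating_point:
        # Rounding before the cast avoids truncation toward zero.
        image = image.round()
    return image.to(dtype=dtype)
```

With this, a float16 input would keep its fractional values, while uint8 and the other integer dtypes would still be rounded before the cast.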
