Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,8 @@ def __init__(
ArgsKwargs([32]),
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
Expand All @@ -101,6 +99,15 @@ def __init__(
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.Resize,
legacy_transforms.Resize,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to pull these tests out so as not to affect the atol=1 used for the other tests.

),
ConsistencyConfig(
v2_transforms.CenterCrop,
legacy_transforms.CenterCrop,
Expand Down Expand Up @@ -309,15 +316,22 @@ def __init__(
ArgsKwargs(17, scale=(0.3, 0.7)),
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=v2_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
# atol=1 because Resize v2 uses the native uint8 interpolate path for bilinear and nearest modes
closeness_kwargs=dict(rtol=0, atol=1),
),
ConsistencyConfig(
v2_transforms.RandomResizedCrop,
legacy_transforms.RandomResizedCrop,
[
ArgsKwargs((33, 26), interpolation=v2_transforms.InterpolationMode.BICUBIC, antialias=True),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC, antialias=True),
],
closeness_kwargs=dict(rtol=0, atol=21),
),
ConsistencyConfig(
v2_transforms.RandomErasing,
legacy_transforms.RandomErasing,
Expand Down
28 changes: 23 additions & 5 deletions test/transforms_v2_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,20 @@ def sample_inputs_resize_image_tensor():

for image_loader, interpolation in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=["RGB"]),
[
F.InterpolationMode.NEAREST,
F.InterpolationMode.BILINEAR,
F.InterpolationMode.BICUBIC,
],
[F.InterpolationMode.NEAREST, F.InterpolationMode.BILINEAR],
):
yield ArgsKwargs(image_loader, size=[min(image_loader.spatial_size) + 1], interpolation=interpolation)

yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)


def sample_inputs_resize_image_tensor_bicubic():
    """Yield bicubic-only sample inputs built from random-size RGB image loaders.

    One ``ArgsKwargs`` is produced per loader obtained from
    ``make_image_loaders(sizes=["random"], color_spaces=["RGB"])``. Each one
    carries the loader itself, a single-element ``size`` equal to the minimum
    of the loader's ``spatial_size`` plus one, and
    ``interpolation=F.InterpolationMode.BICUBIC``.
    """
    # Collect every loader up front so none is created lazily between yields.
    loaders = tuple(make_image_loaders(sizes=["random"], color_spaces=["RGB"]))
    bicubic = F.InterpolationMode.BICUBIC
    for loader in loaders:
        target_size = [min(loader.spatial_size) + 1]
        yield ArgsKwargs(loader, size=target_size, interpolation=bicubic)


@pil_reference_wrapper
def reference_resize_image_tensor(*args, **kwargs):
if not kwargs.pop("antialias", False) and kwargs.get("interpolation", F.InterpolationMode.BILINEAR) in {
Expand Down Expand Up @@ -364,6 +367,21 @@ def reference_inputs_resize_bounding_box():
xfail_jit_python_scalar_arg("size"),
],
),
KernelInfo(
F.resize_image_tensor,
sample_inputs_fn=sample_inputs_resize_image_tensor_bicubic,
reference_fn=reference_resize_image_tensor,
reference_inputs_fn=reference_inputs_resize_image_tensor,
float32_vs_uint8=True,
closeness_kwargs={
**pil_reference_pixel_difference(10, mae=True),
**cuda_vs_cpu_pixel_difference(atol=30),
**float32_vs_uint8_pixel_difference(1, mae=True),
},
test_marks=[
xfail_jit_python_scalar_arg("size"),
],
),
KernelInfo(
F.resize_bounding_box,
sample_inputs_fn=sample_inputs_resize_bounding_box,
Expand Down
16 changes: 8 additions & 8 deletions torchvision/transforms/v2/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,14 +190,13 @@ def resize_image_tensor(
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
# uint8 dtype can be included for cpu and cuda input if nearest mode
acceptable_dtypes.append(torch.uint8)
elif (
interpolation == InterpolationMode.BILINEAR
and image.device.type == "cpu"
and "AVX2" in torch.backends.cpu.get_cpu_capability()
):
# uint8 dtype support for bilinear mode is limited to cpu and
# according to our benchmarks non-AVX CPUs should prefer u8->f32->interpolate->u8 path
acceptable_dtypes.append(torch.uint8)
elif image.device.type == "cpu":
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
interpolation == InterpolationMode.BICUBIC
):
acceptable_dtypes.append(torch.uint8)

strides = image.stride()
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
Expand Down Expand Up @@ -227,6 +226,7 @@ def resize_image_tensor(

if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
# This path is hit on non-AVX archs, or on GPU.
image = image.clamp_(min=0, max=255)
if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
image = image.round_()
Expand Down