diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 34f1f875a05..c8cca77e0db 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -19,7 +19,6 @@ make_video_loaders, mark_framework_limitation, TestMark, - VALID_EXTRA_DIMS, ) from torchvision.prototype import features from torchvision.transforms.functional_tensor import _max_value as get_max_value @@ -215,16 +214,6 @@ def sample_inputs_resize_image_tensor(): ): yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation) - # We have a speed hack in place for nearest interpolation and single channel images (grayscale) - for image_loader in make_image_loaders( - sizes=["random"], - color_spaces=[features.ColorSpace.GRAY], - extra_dims=VALID_EXTRA_DIMS, - ): - yield ArgsKwargs( - image_loader, size=[min(image_loader.image_size) + 1], interpolation=F.InterpolationMode.NEAREST - ) - yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 2c064245e8a..93df59ad646 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -14,12 +14,7 @@ pil_to_tensor, to_pil_image, ) -from torchvision.transforms.functional_tensor import ( - _cast_squeeze_in, - _cast_squeeze_out, - _parse_pad_padding, - interpolate, -) +from torchvision.transforms.functional_tensor import _parse_pad_padding from ._meta import ( convert_format_bounding_box, @@ -130,32 +125,12 @@ def resize_image_tensor( if image.numel() > 0: image = image.view(-1, num_channels, old_height, old_width) - # This is a perf hack to avoid slow channels_last upsample code path - # Related issue: https://github.com/pytorch/pytorch/issues/83840 - # We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path - if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST: - # Below code is copied from _FT.resize - # This is due to the fact that we need to apply the hack on casted image and not before - # Otherwise, image will be copied while cast to float and interpolate will work on twice more data - image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64]) - - shape = (image.shape[0], 2, image.shape[2], image.shape[3]) - image = image.expand(shape) - - image = interpolate( - image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False - ) - - image = image[:, 0, ...] - image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype) - - else: - image = _FT.resize( - image, - size=[new_height, new_width], - interpolation=interpolation.value, - antialias=antialias, - ) + image = _FT.resize( + image, + size=[new_height, new_width], + interpolation=interpolation.value, + antialias=antialias, + ) return image.view(extra_dims + (num_channels, new_height, new_width))