Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions test/prototype_transforms_kernel_infos.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
make_video_loaders,
mark_framework_limitation,
TestMark,
VALID_EXTRA_DIMS,
)
from torchvision.prototype import features
from torchvision.transforms.functional_tensor import _max_value as get_max_value
Expand Down Expand Up @@ -215,16 +214,6 @@ def sample_inputs_resize_image_tensor():
):
yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation)

# We have a speed hack in place for nearest interpolation and single channel images (grayscale)
for image_loader in make_image_loaders(
sizes=["random"],
color_spaces=[features.ColorSpace.GRAY],
extra_dims=VALID_EXTRA_DIMS,
):
yield ArgsKwargs(
image_loader, size=[min(image_loader.image_size) + 1], interpolation=F.InterpolationMode.NEAREST
)

yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)


Expand Down
39 changes: 7 additions & 32 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import (
_cast_squeeze_in,
_cast_squeeze_out,
_parse_pad_padding,
interpolate,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding

from ._meta import (
convert_format_bounding_box,
Expand Down Expand Up @@ -130,32 +125,12 @@ def resize_image_tensor(
if image.numel() > 0:
image = image.view(-1, num_channels, old_height, old_width)

# This is a perf hack to avoid slow channels_last upsample code path
# Related issue: https://github.com/pytorch/pytorch/issues/83840
# We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
# Below code is copied from _FT.resize
# This is due to the fact that we need to apply the hack on casted image and not before
# Otherwise, image will be copied while cast to float and interpolate will work on twice more data
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])

shape = (image.shape[0], 2, image.shape[2], image.shape[3])
image = image.expand(shape)

image = interpolate(
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
)

image = image[:, 0, ...]
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)

else:
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)

return image.view(extra_dims + (num_channels, new_height, new_width))

Expand Down