diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index e641f8ca62e..9b4c1b5f9af 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -1,3 +1,4 @@
+from functools import partial
 import itertools
 import os
 import colorsys
@@ -578,6 +579,52 @@ def test_assert_resize_antialias(interpolation):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
 
 
+@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
+@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
+@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
+def test_interpolate_antialias_backward(dt, size, interpolation):
+
+    # temporarily hard-code device as CPU, CUDA support will be added later
+    device = "cpu"
+
+    if dt == torch.float16 and device == "cpu":
+        # skip float16 on CPU case
+        return
+
+    torch.manual_seed(12)
+    if interpolation == BILINEAR:
+        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
+        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
+    elif interpolation == BICUBIC:
+        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
+        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward
+
+    class F(torch.autograd.Function):
+
+        @staticmethod
+        def forward(ctx, i):
+            result = forward_op(i, size, False)
+            ctx.save_for_backward(i, result)
+            return result
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            i, result = ctx.saved_tensors
+            ishape = i.shape
+            oshape = result.shape[2:]
+            return backward_op(grad_output, oshape, ishape, False)
+
+    # non-contiguous input (channels moved into place with permute)
+    x = (
+        torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),
+    )
+    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+
+    # contiguous input
+    x = (
+        torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),
+    )
+    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)
+
+
 def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):
 
     script_fn = torch.jit.script(fn)
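
Besides gradcheck, the two ops can be smoke-tested directly: the backward op takes the upstream gradient, the 2-d output size, and the 4-d input size, and returns a gradient shaped like the input. A minimal sketch, assuming torchvision was built from this branch:

```python
import torch
import torchvision  # noqa: F401  (loads the compiled torchvision ops)

x = torch.rand(1, 3, 32, 29, dtype=torch.float64, requires_grad=True)
y = torch.ops.torchvision._interpolate_bilinear2d_aa(x, [10, 7], False)
# Scatter a gradient of ones back through the resize:
g = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
    torch.ones_like(y), [10, 7], list(x.shape), False)
assert g.shape == x.shape  # gradient matches the input shape
```
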
diff --git a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
index 97b025aafb4..32652466916 100644
--- a/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
+++ b/torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp
@@ -1,3 +1,4 @@
+#include
 #include
 #include
 #include
@@ -141,6 +142,41 @@ void ti_cpu_upsample_generic_aa(
 // Helper structs to use with ti_upsample_generic_Nd_kernel_impl
 template <typename index_t, typename scalar_t>
 struct HelperInterpBase {
+  template <typename filter_fn_t>
+  static inline void _compute_weights_aa(
+      const int64_t i,
+      const int64_t input_size,
+      const scalar_t scale,
+      const scalar_t support,
+      scalar_t* wt_ptr,
+      const int64_t interp_size,
+      filter_fn_t filter_fn,
+      int64_t& xmin,
+      int64_t& xsize) {
+    scalar_t center = scale * (i + 0.5);
+    scalar_t total_w = 0.0;
+    scalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
+    xmin = std::max(
+        static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
+    xsize = std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
+        xmin;
+
+    int64_t j = 0;
+    for (; j < xsize; j++) {
+      scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
+      wt_ptr[j] = w;
+      total_w += w;
+    }
+    for (j = 0; j < xsize; j++) {
+      if (total_w != 0.0) {
+        wt_ptr[j] /= total_w;
+      }
+    }
+    for (; j < interp_size; j++) {
+      wt_ptr[j] = static_cast<scalar_t>(0.0);
+    }
+  }
+
   template <typename filter_fn_t>
   static inline std::vector<Tensor> _compute_indices_weights_aa(
       int64_t input_size,
@@ -187,43 +223,30 @@ struct HelperInterpBase {
           empty(new_shape, CPU(c10::CppTypeToScalarType<scalar_t>())));
     }
 
-    scalar_t center, total_w, invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
-    index_t zero = static_cast<index_t>(0);
     int64_t* idx_ptr_xmin = output[0].data_ptr<int64_t>();
     int64_t* idx_ptr_size = output[1].data_ptr<int64_t>();
     int64_t* idx_ptr_stride = output[2].data_ptr<int64_t>();
     scalar_t* wt_ptr = output[3].data_ptr<scalar_t>();
     int64_t* wt_idx_ptr = output[4].data_ptr<int64_t>();
-    int64_t xmin, xmax, j;
+    int64_t xmin, xmax;
 
     for (int64_t i = 0; i < output_size; i++) {
-      center = scale * (i + 0.5);
-      xmin = std::max(static_cast<int64_t>(center - support + 0.5), zero);
-      xmax =
-          std::min(static_cast<int64_t>(center + support + 0.5), input_size) -
-          xmin;
+      HelperInterpBase::_compute_weights_aa(
+          i,
+          input_size,
+          scale,
+          support,
+          wt_ptr + i * interp_size,
+          interp_size,
+          filter_fn,
+          xmin,
+          xmax);
+
       idx_ptr_xmin[i] = xmin * stride;
       idx_ptr_size[i] = xmax;
       idx_ptr_stride[i] = stride;
-
       wt_idx_ptr[i] = i * interp_size * sizeof(scalar_t);
-
-      total_w = 0.0;
-      for (j = 0; j < xmax; j++) {
-        scalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
-        wt_ptr[i * interp_size + j] = w;
-        total_w += w;
-      }
-      for (j = 0; j < xmax; j++) {
-        if (total_w != 0.0) {
-          wt_ptr[i * interp_size + j] /= total_w;
-        }
-      }
-
-      for (; j < interp_size; j++) {
-        wt_ptr[i * interp_size + j] = static_cast<scalar_t>(0.0);
-      }
     }
     return output;
   }
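
The extracted helper is easier to follow in scalar form. A rough Python transcription of `_compute_weights_aa` (hypothetical helper names, for illustration only) shows the Pillow-style recipe: locate the filter window around the output pixel's center, evaluate the filter, normalize to unit sum, then zero-pad the tail:

```python
def compute_weights_aa(i, input_size, scale, support, interp_size, filter_fn):
    center = scale * (i + 0.5)
    invscale = 1.0 / scale if scale >= 1.0 else 1.0
    xmin = max(int(center - support + 0.5), 0)
    xsize = min(int(center + support + 0.5), input_size) - xmin
    # evaluate the filter over the window and normalize to unit sum
    weights = [filter_fn((j + xmin - center + 0.5) * invscale) for j in range(xsize)]
    total = sum(weights)
    if total != 0.0:
        weights = [w / total for w in weights]
    weights += [0.0] * (interp_size - xsize)  # zero-pad up to interp_size
    return xmin, xsize, weights

def bilinear_filter(x):
    return max(0.0, 1.0 - abs(x))
```
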
@@ -475,6 +498,151 @@ void _ti_upsample_bicubic2d_kernel_impl(
       output, input, align_corners, {scales_h, scales_w}, antialias);
 }
 
+template <
+    typename scalar_t,
+    typename scale_type,
+    template <typename, typename> class F>
+void cpu_upsample_genNd_backward_aa(
+    const Tensor& grad_input_,
+    const Tensor& grad_output_,
+    bool align_corners,
+    const scale_type& scales) {
+  TORCH_CHECK(
+      grad_input_.dtype() == grad_output_.dtype(),
+      "expected dtype ",
+      grad_output_.dtype(),
+      " for `grad_input` but got dtype ",
+      grad_input_.dtype());
+
+  auto grad_output = grad_output_.contiguous();
+  auto grad_input = grad_input_.contiguous();
+
+  auto grad_output_data = grad_output.data_ptr<scalar_t>();
+  auto grad_input_data = grad_input.data_ptr<scalar_t>();
+  auto input_sizes = grad_input.sizes().vec();
+  auto output_sizes = grad_output.sizes().vec();
+  auto ndim = input_sizes.size();
+
+  // treat nbatch and channels as one dimension
+  int64_t channels = input_sizes[0] * input_sizes[1];
+  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
+  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
+  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
+  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
+  int64_t input_width = input_sizes[ndim - 1];
+  int64_t output_width = output_sizes[ndim - 1];
+
+  int64_t output_slice_size = output_depth * output_height * output_width;
+  int interp_size = F<int64_t, scalar_t>::interp_size;
+
+  auto loop2d = [&](int64_t begin, int64_t end) {
+    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
+        input_height, output_height, align_corners, scales[0]);
+    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
+        input_width, output_width, align_corners, scales[1]);
+
+    auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
+      return grad_input_data + c * input_height * input_width +
+          h * input_width + w;
+    };
+
+    const scalar_t support_h = (height_scale >= 1.0)
+        ? (interp_size * 0.5) * height_scale
+        : interp_size * 0.5;
+    const scalar_t support_w = (width_scale >= 1.0)
+        ? (interp_size * 0.5) * width_scale
+        : interp_size * 0.5;
+
+    const int interp_height = (int)ceilf(support_h) * 2 + 1;
+    const int interp_width = (int)ceilf(support_w) * 2 + 1;
+
+    std::vector<scalar_t> wx(interp_width, 0.0);
+    std::vector<scalar_t> wy(interp_height, 0.0);
+
+    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+    int64_t xmin, ymin;
+    int64_t xsize, ysize;
+    auto filter_fn = F<int64_t, scalar_t>::_filter;
+
+    for (int64_t oh = 0; oh < output_height; oh++) {
+      F<int64_t, scalar_t>::_compute_weights_aa(
+          oh,
+          input_height,
+          height_scale,
+          support_h,
+          wy.data(),
+          interp_height,
+          filter_fn,
+          ymin,
+          ysize);
+
+      for (int64_t ow = 0; ow < output_width; ow++) {
+        F<int64_t, scalar_t>::_compute_weights_aa(
+            ow,
+            input_width,
+            width_scale,
+            support_w,
+            wx.data(),
+            interp_width,
+            filter_fn,
+            xmin,
+            xsize);
+
+        for (int64_t c = begin; c < end; c++) {
+          scalar_t grad_output_value =
+              grad_output_data[c * output_slice_size + oh * output_width + ow];
+
+          for (int64_t y = 0; y < ysize; y++) {
+            for (int64_t x = 0; x < xsize; x++) {
+              // scatter with the same weights the forward pass used
+              *input_indexr(c, ymin + y, xmin + x) +=
+                  wx[x] * wy[y] * grad_output_value;
+            }
+          }
+        }
+      }
+    }
+  };
+
+  if (ndim == 4) {
+    // upsample bilinear 2d
+    at::parallel_for(
+        0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
+  } else {
+    TORCH_CHECK(false, "Unsupported tensor ndim");
+  }
+
+  if (!grad_input_.is_contiguous()) {
+    grad_input_.copy_(grad_input);
+  }
+}
+
+void _upsample_bilinear2d_aa_backward_kernel_impl(
+    const Tensor& grad_input,
+    const Tensor& grad_output,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  AT_DISPATCH_FLOATING_TYPES(
+      grad_output.scalar_type(), "upsample_bilinear2d_backward_cpu", [&] {
+        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
+            grad_input, grad_output, align_corners, {scales_h, scales_w});
+      });
+}
+
+void _upsample_bicubic2d_aa_backward_kernel_impl(
+    const Tensor& grad_input,
+    const Tensor& grad_output,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  AT_DISPATCH_FLOATING_TYPES(
+      grad_output.scalar_type(), "upsample_bicubic2d_backward_cpu", [&] {
+        cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpCubic>(
+            grad_input, grad_output, align_corners, {scales_h, scales_w});
+      });
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
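
Since each output pixel is a fixed linear combination of input pixels, the backward pass is the transpose of the forward weighting, which is exactly what the `wx[x] * wy[y]` scatter above implements per axis. A tiny 1-D sketch, reusing the hypothetical `compute_weights_aa` and `bilinear_filter` helpers from above:

```python
import torch

n_in, n_out = 8, 3
scale = n_in / n_out
support = scale  # bilinear downscale: support = (interp_size / 2) * scale

# Build the forward weight matrix row by row.
W = torch.zeros(n_out, n_in, dtype=torch.float64)
for i in range(n_out):
    xmin, xsize, w = compute_weights_aa(
        i, n_in, scale, support, n_in, bilinear_filter)
    W[i, xmin:xmin + xsize] = torch.tensor(w[:xsize], dtype=torch.float64)

x = torch.rand(n_in, dtype=torch.float64)
grad_out = torch.rand(n_out, dtype=torch.float64)
y = W @ x                   # forward resample
grad_in = W.t() @ grad_out  # backward: transpose of the same weights
```
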
@@ -484,7 +652,7 @@ namespace ops {
 
 namespace {
 
-at::Tensor interpolate_linear_aa_forward_kernel(
+at::Tensor interpolate_bilinear2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners) {
@@ -515,7 +683,7 @@
   return output;
 }
 
-at::Tensor interpolate_bicubic_aa_forward_kernel(
+at::Tensor interpolate_bicubic2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners) {
@@ -546,26 +714,109 @@
   return output;
 }
 
-// TODO: Implement backward function
-// at::Tensor interpolate_linear_aa_backward_kernel(
-//     const at::Tensor& grad) {
-//   return grad_input;
-// }
+at::Tensor interpolate_bilinear2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBilinear2d.cpp::upsample_bilinear2d_backward
+  auto grad_input = at::empty({0}, grad_output.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input_size, output_size, scale_factors);
+  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
+
+  auto full_output_size =
+      at::native::upsample_2d_common_check(input_size, osize);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
+      grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(",
+        i,
+        ") = ",
+        full_output_size[i],
+        " but got grad_output.size(",
+        i,
+        ") = ",
+        grad_output.size(i));
+  }
+
+  grad_input.resize_(input_size, grad_output.suggest_memory_format());
+  grad_input.zero_();
+  at::native::internal_upsample::_upsample_bilinear2d_aa_backward_kernel_impl(
+      grad_input, grad_output, align_corners, scale_h, scale_w);
+
+  return grad_input;
+}
+
+at::Tensor interpolate_bicubic2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
+  auto grad_input = at::empty({0}, grad_output.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input_size, output_size, scale_factors);
+  auto scale_h = at::native::upsample::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample::get_scale_value(scale_factors, 1);
+
+  auto full_output_size =
+      at::native::upsample_2d_common_check(input_size, osize);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
+      grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(",
+        i,
+        ") = ",
+        full_output_size[i],
+        " but got grad_output.size(",
+        i,
+        ") = ",
+        grad_output.size(i));
+  }
+
+  grad_input.resize_(input_size, grad_output.suggest_memory_format());
+  grad_input.zero_();
+  at::native::internal_upsample::_upsample_bicubic2d_aa_backward_kernel_impl(
+      grad_input, grad_output, align_corners, scale_h, scale_w);
+
+  return grad_input;
+}
 
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
   m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
-      TORCH_FN(interpolate_linear_aa_forward_kernel));
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
+      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
   m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
-      TORCH_FN(interpolate_bicubic_aa_forward_kernel));
-
-  // TODO: Implement backward function
-  // m.impl(
-  //     TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa_backward"),
-  //     TORCH_FN(interpolate_linear_aa_backward_kernel));
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
+      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
+      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
+      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
 }
 
 } // namespace ops
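
Both backward wrappers guard against a mis-shaped upstream gradient before scattering into `grad_input`. In Python terms the check reduces to the following rough equivalent (hypothetical helper, for illustration):

```python
def check_grad_output_shape(grad_output, input_size, output_size):
    # upsample_2d_common_check yields [nbatch, channels, out_h, out_w]
    full_output_size = list(input_size[:2]) + list(output_size)
    assert grad_output.dim() == 4, f"expected 4d grad_output, got {grad_output.dim()}d"
    for i in range(4):
        assert grad_output.size(i) == full_output_size[i], (
            f"dim {i}: expected {full_output_size[i]}, got {grad_output.size(i)}"
        )
```
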
diff --git a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu b/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
index 4259fa2b0e8..46525f4dd01 100644
--- a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
+++ b/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
@@ -62,23 +62,22 @@ __device__ __forceinline__ static accscalar_t bicubic_filter(accscalar_t x) {
 
 template <typename scalar_t, typename accscalar_t, typename filter_fn_t>
 __device__ __forceinline__ static void _compute_weights(
-    const int64_t i,
-    const int64_t input_size,
+    const int i,
+    const int input_size,
     const accscalar_t scale,
     const accscalar_t support,
     scalar_t* wt_ptr,
-    int64_t interp_size,
+    int interp_size,
     filter_fn_t filter_fn,
-    int64_t& xmin,
-    int64_t& xmax) {
+    int& xmin,
+    int& xmax) {
   accscalar_t invscale = (scale >= 1.0) ? 1.0 / scale : 1.0;
   accscalar_t center = scale * (i + 0.5);
-  xmin = max(
-      static_cast<int64_t>(center - support + 0.5), static_cast<int64_t>(0));
-  xmax = min(static_cast<int64_t>(center + support + 0.5), input_size) - xmin;
+  xmin = max(static_cast<int>(center - support + 0.5), static_cast<int>(0));
+  xmax = min(static_cast<int>(center + support + 0.5), input_size) - xmin;
 
   accscalar_t total_w = 0.0;
-  int64_t j = 0;
+  int j = 0;
   for (j = 0; j < xmax; j++) {
     accscalar_t w = filter_fn((j + xmin - center + 0.5) * invscale);
     wt_ptr[j] = static_cast<scalar_t>(w);
@@ -164,7 +163,7 @@ __global__ void upsample_gen2d_out_frame(
   scalar_t buffer2[256];
 
   // Compute weights
-  int64_t xmin, xsize, ymin, ysize;
+  int xmin, xsize, ymin, ysize;
   typedef scalar_t (*filter_fn_t)(scalar_t);
   if (interp_size == 2) {
     _compute_weights(
@@ -213,7 +212,7 @@
   for (int n = 0; n < batchsize; n++) {
     for (int c = 0; c < channels; ++c) {
       // interpolate on x-axis for ymin to ymin + ysize
-      for (int64_t y = 0; y < ysize; y++) {
+      for (int y = 0; y < ysize; y++) {
         // copy data into the local buffer and use
         // interpolate_aa_single_dim method
         for (int x = 0; x < xsize; x++) {
@@ -372,7 +371,7 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
   return output;
 }
 
-at::Tensor interpolate_linear_aa_forward_kernel(
+at::Tensor interpolate_bilinear2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners) {
@@ -380,7 +379,7 @@
       input, output_size, align_corners);
 }
 
-at::Tensor interpolate_bicubic_aa_forward_kernel(
+at::Tensor interpolate_bicubic2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners) {
@@ -392,11 +391,11 @@
 
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_linear_aa"),
-      TORCH_FN(interpolate_linear_aa_forward_kernel));
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa"),
+      TORCH_FN(interpolate_bilinear2d_aa_forward_kernel));
   m.impl(
-      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic_aa"),
-      TORCH_FN(interpolate_bicubic_aa_forward_kernel));
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
+      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
 }
 
 } // namespace ops
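
The CUDA file still registers only the forward kernels, so until a CUDA backward lands, GPU training code would have to route the backward through the CPU kernel. A hypothetical wrapper (not part of this PR) could look like:

```python
import torch

class BilinearAA2d(torch.autograd.Function):
    # Hypothetical autograd glue: CUDA forward, CPU fallback for backward.
    @staticmethod
    def forward(ctx, x, size):
        ctx.input_shape = list(x.shape)
        ctx.size = list(size)
        return torch.ops.torchvision._interpolate_bilinear2d_aa(x, size, False)

    @staticmethod
    def backward(ctx, grad_output):
        grad = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
            grad_output.cpu(), ctx.size, ctx.input_shape, False)
        return grad.to(grad_output.device), None
```
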
diff --git a/torchvision/csrc/ops/interpolate_aa.cpp b/torchvision/csrc/ops/interpolate_aa.cpp
index 90bc26a1fb5..7f1680246a0 100644
--- a/torchvision/csrc/ops/interpolate_aa.cpp
+++ b/torchvision/csrc/ops/interpolate_aa.cpp
@@ -5,54 +5,69 @@
 namespace vision {
 namespace ops {
 
-at::Tensor interpolate_linear_aa(
+at::Tensor _interpolate_bilinear2d_aa(
     const at::Tensor& input, // Input image
     at::IntArrayRef output_size, // Output image size
     bool align_corners) // The flag to align corners
 {
   static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchvision::_interpolate_linear_aa", "")
-                       .typed<decltype(interpolate_linear_aa)>();
+                       .findSchemaOrThrow("torchvision::_interpolate_bilinear2d_aa", "")
+                       .typed<decltype(_interpolate_bilinear2d_aa)>();
   return op.call(input, output_size, align_corners);
 }
 
-at::Tensor interpolate_bicubic_aa(
+at::Tensor _interpolate_bicubic2d_aa(
     const at::Tensor& input, // Input image
     at::IntArrayRef output_size, // Output image size
     bool align_corners) // The flag to align corners
 {
   static auto op = c10::Dispatcher::singleton()
-                       .findSchemaOrThrow("torchvision::_interpolate_bicubic_aa", "")
-                       .typed<decltype(interpolate_bicubic_aa)>();
+                       .findSchemaOrThrow("torchvision::_interpolate_bicubic2d_aa", "")
+                       .typed<decltype(_interpolate_bicubic2d_aa)>();
   return op.call(input, output_size, align_corners);
 }
 
 namespace detail {
 
-// TODO: Implement backward function
-// at::Tensor _interpolate_linear_aa_backward(
-//     const at::Tensor& grad,
-//     at::IntArrayRef output_size,
-//     bool align_corners)
-// {
-//   return at::Tensor();
-// }
+at::Tensor _interpolate_bilinear2d_aa_backward(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow(
+              "torchvision::_interpolate_bilinear2d_aa_backward", "")
+          .typed<decltype(_interpolate_bilinear2d_aa_backward)>();
+  return op.call(grad_output, output_size, input_size, align_corners);
+}
+
+at::Tensor _interpolate_bicubic2d_aa_backward(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow(
+              "torchvision::_interpolate_bicubic2d_aa_backward", "")
+          .typed<decltype(_interpolate_bicubic2d_aa_backward)>();
+  return op.call(grad_output, output_size, input_size, align_corners);
+}
 
 } // namespace detail
 
 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_linear_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+      "torchvision::_interpolate_bilinear2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::_interpolate_bicubic2d_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA(
+      "torchvision::_interpolate_bilinear2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_interpolate_bicubic_aa(Tensor input, int[] output_size, bool align_corners) -> Tensor"));
-  // TODO: Implement backward function
-  // m.def(TORCH_SELECTIVE_SCHEMA(
-  //     "torchvision::_interpolate_linear_aa_backward(Tensor grad, Tensor rois,
-  //     float spatial_scale, int pooled_height, int pooled_width, int
-  //     batch_size, int channels, int height, int width, int sampling_ratio,
-  //     bool aligned) -> Tensor"));
+      "torchvision::_interpolate_bicubic2d_aa_backward(Tensor input, int[] output_size, int[] input_size, bool align_corners) -> Tensor"));
 }
 
 } // namespace ops
diff --git a/torchvision/csrc/ops/interpolate_aa.h b/torchvision/csrc/ops/interpolate_aa.h
index 0a9ffb4b168..283418b3935 100644
--- a/torchvision/csrc/ops/interpolate_aa.h
+++ b/torchvision/csrc/ops/interpolate_aa.h
@@ -6,23 +6,29 @@
 namespace vision {
 namespace ops {
 
-VISION_API at::Tensor _interpolate_linear_aa(
+VISION_API at::Tensor _interpolate_bilinear2d_aa(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners = false);
 
-VISION_API at::Tensor _interpolate_bicubic_aa(
+VISION_API at::Tensor _interpolate_bicubic2d_aa(
     const at::Tensor& input,
     at::IntArrayRef output_size,
     bool align_corners = false);
 
 namespace detail {
 
-// TODO: Implement backward function
-// at::Tensor _interpolate_linear_aa_backward(
-//     const at::Tensor& grad,
-//     at::IntArrayRef output_size,
-//     bool align_corners=false);
+VISION_API at::Tensor _interpolate_bilinear2d_aa_backward(
+    const at::Tensor& grad,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners = false);
+
+VISION_API at::Tensor _interpolate_bicubic2d_aa_backward(
+    const at::Tensor& grad,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners = false);
 
 } // namespace detail
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index a0e32d4237e..5a13bd5d392 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -545,9 +545,9 @@ def resize(
 
     if antialias:
         if interpolation == "bilinear":
-            img = torch.ops.torchvision._interpolate_linear_aa(img, [new_h, new_w], align_corners=False)
+            img = torch.ops.torchvision._interpolate_bilinear2d_aa(img, [new_h, new_w], align_corners=False)
        elif interpolation == "bicubic":
-            img = torch.ops.torchvision._interpolate_bicubic_aa(img, [new_h, new_w], align_corners=False)
+            img = torch.ops.torchvision._interpolate_bicubic2d_aa(img, [new_h, new_w], align_corners=False)
     else:
         img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
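
With the renames in place, the public resize path exercises the new forward ops; gradients still need manual wiring (as in the test above) until an autograd kernel is registered for these ops. A minimal smoke test, assuming a source build of this branch:

```python
import torch
from torchvision.transforms.functional import InterpolationMode, resize

img = torch.rand(1, 3, 64, 64)
out = resize(img, [32, 32], interpolation=InterpolationMode.BILINEAR, antialias=True)
print(out.shape)  # torch.Size([1, 3, 32, 32])
```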