diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index 9b4c1b5f9af..30ee144888c 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -579,13 +579,11 @@ def test_assert_resize_antialias(interpolation):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
 
 
+@pytest.mark.parametrize('device', cpu_and_gpu())
 @pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
 @pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
 @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
-def test_interpolate_antialias_backward(dt, size, interpolation):
-
-    # temporarily hard-code device as CPU, CUDA support will be done later
-    device = "cpu"
+def test_interpolate_antialias_backward(device, dt, size, interpolation):
 
     if dt == torch.float16 and device == "cpu":
         # skip float16 on CPU case
diff --git a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu b/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
index 46525f4dd01..f52793408f4 100644
--- a/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
+++ b/torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu
@@ -165,49 +165,32 @@ __global__ void upsample_gen2d_out_frame(
     // Compute weights
     int xmin, xsize, ymin, ysize;
     typedef scalar_t (*filter_fn_t)(scalar_t);
+    filter_fn_t filter_fn;
     if (interp_size == 2) {
-      _compute_weights(
-          w2,
-          width1,
-          rwidth,
-          support_w,
-          wx,
-          interp_width,
-          bilinear_filter,
-          xmin,
-          xsize);
-      _compute_weights(
-          h2,
-          height1,
-          rheight,
-          support_h,
-          wy,
-          interp_height,
-          bilinear_filter,
-          ymin,
-          ysize);
+      filter_fn = bilinear_filter;
     } else if (interp_size == 4) {
-      _compute_weights(
-          w2,
-          width1,
-          rwidth,
-          support_w,
-          wx,
-          interp_width,
-          bicubic_filter,
-          xmin,
-          xsize);
-      _compute_weights(
-          h2,
-          height1,
-          rheight,
-          support_h,
-          wy,
-          interp_height,
-          bicubic_filter,
-          ymin,
-          ysize);
+      filter_fn = bicubic_filter;
     }
+    _compute_weights(
+        w2,
+        width1,
+        rwidth,
+        support_w,
+        wx,
+        interp_width,
+        filter_fn,
+        xmin,
+        xsize);
+    _compute_weights(
+        h2,
+        height1,
+        rheight,
+        support_h,
+        wy,
+        interp_height,
+        filter_fn,
+        ymin,
+        ysize);
 
     for (int n = 0; n < batchsize; n++) {
       for (int c = 0; c < channels; ++c) {
@@ -239,6 +222,8 @@ static void upsample_gen2d_out_cuda_template(
     bool align_corners,
     c10::optional<double> scales_h,
     c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
   TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
   checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});
 
@@ -256,7 +241,7 @@ static void upsample_gen2d_out_cuda_template(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
+      input.scalar_type(), "upsample_gen2d_out_frame", [&] {
         using accscalar_t = at::acc_type<scalar_t, true>;
 
         auto idata = input.packed_accessor64<scalar_t, 4>();
@@ -287,6 +272,174 @@ static void upsample_gen2d_out_cuda_template(
       });
 }
 
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename scalar_t, typename accscalar_t, int interp_size>
+C10_LAUNCH_BOUNDS_1(1024)
+__global__ void upsample_gen2d_backward_out_frame(
+    const int num_elements,
+    const accscalar_t height_scale,
+    const accscalar_t width_scale,
+    const bool align_corners,
+    PackedTensorAccessor64<scalar_t, 4> idata,
+    const PackedTensorAccessor64<scalar_t, 4> odata) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+
+  const int batchsize = idata.size(0);
+  const int channels = idata.size(1);
+  const int input_height = idata.size(2);
+  const int input_width = idata.size(3);
+  const int output_height = odata.size(2);
+  const int output_width = odata.size(3);
+
+  if (index >= num_elements) {
+    return;
+  }
+
+  const int output_x = index % output_width;
+  const int output_y = index / output_width;
+  // special case: output just copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int n = 0; n < batchsize; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const scalar_t val = odata[n][c][output_y][output_x];
+        idata[n][c][output_y][output_x] = val;
+      }
+    }
+    return;
+  }
+
+  const accscalar_t support_h = static_cast<accscalar_t>(
+      (height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
+                            : interp_size * 0.5);
+  const accscalar_t support_w = static_cast<accscalar_t>(
+      (width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
+                           : interp_size * 0.5);
+
+  const int interp_height = (int)ceilf(support_h) * 2 + 1;
+  const int interp_width = (int)ceilf(support_w) * 2 + 1;
+
+  // Setup local buffers
+  // TODO: maybe we can specify dynamic shared memory size before calling the
+  // cuda code, however we should then ensure that device has enough shared
+  // memory
+  scalar_t wx[256];
+  scalar_t wy[256];
+
+  // Compute weights
+  int xmin, xsize, ymin, ysize;
+  typedef scalar_t (*filter_fn_t)(scalar_t);
+  filter_fn_t filter_fn;
+  if (interp_size == 2) {
+    filter_fn = bilinear_filter;
+  } else if (interp_size == 4) {
+    filter_fn = bicubic_filter;
+  }
+  _compute_weights(
+      output_x,
+      input_width,
+      width_scale,
+      support_w,
+      wx,
+      interp_width,
+      filter_fn,
+      xmin,
+      xsize);
+  _compute_weights(
+      output_y,
+      input_height,
+      height_scale,
+      support_h,
+      wy,
+      interp_height,
+      filter_fn,
+      ymin,
+      ysize);
+
+  for (int n = 0; n < batchsize; n++) {
+    for (int c = 0; c < channels; ++c) {
+      scalar_t out_value = odata[n][c][output_y][output_x];
+      for (int y = 0; y < ysize; y++) {
+        for (int x = 0; x < xsize; x++) {
+          upsample_increment_value_bounded(
+              idata,
+              n,
+              c,
+              input_height,
+              input_width,
+              ymin + y,
+              xmin + x,
+              wx[x] * wy[y] * out_value);
+        }
+      }
+    }
+  }
+}
+
+template <int interp_size>
+static void upsample_gen2d_backward_out_cuda_template(
+    const Tensor& grad_input,
+    const Tensor& grad_output_,
+    IntArrayRef output_size,
+    IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
+  TensorArg grad_input_arg{grad_input, "grad_input", 1},
+      grad_output_arg{grad_output_, "grad_output_", 2};
+  checkAllSameGPU(
+      "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});
+
+  int output_height = output_size[0];
+  int output_width = output_size[1];
+
+  int nbatch = input_size[0];
+  int channels = input_size[1];
+  int input_height = input_size[2];
+  int input_width = input_size[3];
+
+  Tensor grad_output = grad_output_.contiguous();
+
+  grad_input.zero_();
+
+  const int num_kernels = output_height * output_width;
+  const int num_threads = std::min(
+      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
+        using accscalar_t = at::acc_type<scalar_t, true>;
+
+        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
+        auto odata = grad_output.packed_accessor64<scalar_t, 4>();
+
+        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
+            input_height, output_height, align_corners, scales_h);
+        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
+            input_width, output_width, align_corners, scales_w);
+
+        // We are using static buffer memory of 256 * sizeof(float) per thread
+        // to store weights. Size of weights array is
+        // interp_size = scale * 2 + 1 for bilinear mode
+        TORCH_CHECK(
+            rheight < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+        TORCH_CHECK(
+            rwidth < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+
+        upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
+            <<<cuda::ATenCeilDiv(num_kernels, num_threads), num_threads, 0, stream>>>(
+                num_kernels, rheight, rwidth, align_corners, idata, odata);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
@@ -371,6 +524,56 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
   return output;
 }
 
+template <int interp_size>
+at::Tensor interpolate_gen2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  c10::optional<std::vector<double>> scale_factors = {};
+
+  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
+  auto grad_input = at::empty({0}, grad_output.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input_size, output_size, scale_factors);
+  auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
+
+  auto full_output_size = upsample_2d_common_check(input_size, osize);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
+      grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(",
+        i,
+        ") = ",
+        full_output_size[i],
+        " but got grad_output.size(",
+        i,
+        ") = ",
+        grad_output.size(i));
+  }
+
+  grad_input.resize_(input_size, grad_output.suggest_memory_format());
+
+  at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
+      interp_size>(
+      grad_input,
+      grad_output,
+      {full_output_size[2], full_output_size[3]},
+      input_size,
+      align_corners,
+      scale_h,
+      scale_w);
+  return grad_input;
+}
+
 at::Tensor interpolate_bilinear2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
@@ -387,6 +590,24 @@ at::Tensor interpolate_bicubic2d_aa_forward_kernel(
       input, output_size, align_corners);
 }
 
+at::Tensor interpolate_bilinear2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<2>(
+      grad_output, output_size, input_size, align_corners);
+}
+
+at::Tensor interpolate_bicubic2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<4>(
+      grad_output, output_size, input_size, align_corners);
+}
+
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
@@ -396,6 +617,12 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
       TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
+      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
+      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
 }
 
 } // namespace ops
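
Note (not part of the diff): a minimal Python sketch of how the newly registered CUDA backward kernels get exercised, along the same lines as test_interpolate_antialias_backward above. It assumes that F.resize with antialias=True dispatches to torchvision::_interpolate_bilinear2d_aa and that autograd routes the gradient through the *_aa_backward ops registered here; the nondet_tol is a guess to absorb the non-deterministic atomic accumulation in the backward kernel.

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

device = "cuda"  # the new kernels only cover the CUDA dispatch key

# float64 input so gradcheck's finite differences are meaningful
x = torch.rand(1, 3, 32, 29, dtype=torch.float64, device=device, requires_grad=True)

# Forward through the antialiased resize; backward() should hit
# torchvision::_interpolate_bilinear2d_aa_backward on CUDA.
y = F.resize(x, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True)
y.sum().backward()
assert x.grad is not None and x.grad.shape == x.shape

# Stricter check; nondet_tol allows for atomic-add accumulation order.
torch.autograd.gradcheck(
    lambda t: F.resize(t, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True),
    (x,),
    nondet_tol=1e-8,
)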