
Commit a9b38db

vfdev-5 authored
[2/2] Added backward pass on CUDA for interpolation with anti-alias option (#4211)
* WIP on backward op interpolation with AA
* Removed cuda tests and reformat cpp code
* Fixed clang wrong formatting
* Added channels last test case
* Added CUDA support for backward pass, interpolation with AA
* Removed unused buffers

Co-authored-by: vfdev-5 <[email protected]>
1 parent: c8b12fd · commit: a9b38db

File tree: 2 files changed, +270 −45 lines

test/test_functional_tensor.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -579,13 +579,11 @@ def test_assert_resize_antialias(interpolation):
         F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)
 
 
+@pytest.mark.parametrize('device', cpu_and_gpu())
 @pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
 @pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
 @pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
-def test_interpolate_antialias_backward(dt, size, interpolation):
-
-    # temporarily hard-code device as CPU, CUDA support will be done later
-    device = "cpu"
+def test_interpolate_antialias_backward(device, dt, size, interpolation):
 
     if dt == torch.float16 and device == "cpu":
         # skip float16 on CPU case
```
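With the `device` parametrization in place, the same test body now covers both backends. A minimal sketch of what it exercises (assumed shapes and body, not the repository's actual test code):

```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

def check_antialias_backward(device, dt, size, interpolation):
    if dt == torch.float16 and device == "cpu":
        return  # float16 resize is not supported on CPU
    x = torch.rand(1, 3, 32, 29, dtype=dt, device=device, requires_grad=True)
    out = F.resize(x, size=size, interpolation=interpolation, antialias=True)
    # on device="cuda" this now dispatches to the new backward kernel
    out.sum().backward()
    assert x.grad is not None and x.grad.shape == x.shape

check_antialias_backward("cpu", torch.float32, [10, 7], InterpolationMode.BILINEAR)
```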

torchvision/csrc/ops/cuda/interpolate_aa_kernels.cu

Lines changed: 268 additions & 41 deletions
```diff
@@ -165,49 +165,32 @@ __global__ void upsample_gen2d_out_frame(
   // Compute weights
   int xmin, xsize, ymin, ysize;
   typedef scalar_t (*filter_fn_t)(scalar_t);
+  filter_fn_t filter_fn;
   if (interp_size == 2) {
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        w2,
-        width1,
-        rwidth,
-        support_w,
-        wx,
-        interp_width,
-        bilinear_filter,
-        xmin,
-        xsize);
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        h2,
-        height1,
-        rheight,
-        support_h,
-        wy,
-        interp_height,
-        bilinear_filter,
-        ymin,
-        ysize);
+    filter_fn = bilinear_filter;
   } else if (interp_size == 4) {
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        w2,
-        width1,
-        rwidth,
-        support_w,
-        wx,
-        interp_width,
-        bicubic_filter,
-        xmin,
-        xsize);
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        h2,
-        height1,
-        rheight,
-        support_h,
-        wy,
-        interp_height,
-        bicubic_filter,
-        ymin,
-        ysize);
+    filter_fn = bicubic_filter;
   }
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      w2,
+      width1,
+      rwidth,
+      support_w,
+      wx,
+      interp_width,
+      filter_fn,
+      xmin,
+      xsize);
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      h2,
+      height1,
+      rheight,
+      support_h,
+      wy,
+      interp_height,
+      filter_fn,
+      ymin,
+      ysize);
 
   for (int n = 0; n < batchsize; n++) {
     for (int c = 0; c < channels; ++c) {
```
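This refactor collapses four near-duplicate `_compute_weights` calls into one pair by selecting the filter through a function pointer first. For intuition, here is an illustrative Python analogue of the separable weight computation (our own sketch modeled on PIL-style resampling, not the CUDA code; clamping details differ):

```python
def bilinear_filter(x):
    # triangle filter with support 1
    x = abs(x)
    return 1.0 - x if x < 1.0 else 0.0

def compute_weights(out_idx, in_size, scale, support, filter_fn):
    """Return (xmin, xsize, weights) for the input pixels feeding out_idx."""
    center = scale * (out_idx + 0.5)  # output pixel center in input coordinates
    xmin = max(int(center - support + 0.5), 0)
    xsize = min(int(center + support + 0.5), in_size) - xmin
    inv_scale = 1.0 / scale if scale >= 1.0 else 1.0  # widen filter when downscaling
    ws = [filter_fn((xmin + i - center + 0.5) * inv_scale) for i in range(xsize)]
    total = sum(ws)
    return xmin, xsize, [w / total for w in ws]

# Downscaling a 10-pixel row by 2x: output pixel 2 mixes 4 inputs.
print(compute_weights(2, 10, 2.0, support=2.0, filter_fn=bilinear_filter))
# -> (3, 4, [0.125, 0.375, 0.375, 0.125])
```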
```diff
@@ -239,6 +222,8 @@ static void upsample_gen2d_out_cuda_template(
     bool align_corners,
     c10::optional<double> scales_h,
     c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
   TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
   checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});
 
```
```diff
@@ -256,7 +241,7 @@ static void upsample_gen2d_out_cuda_template(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
+      input.scalar_type(), "upsample_gen2d_out_frame", [&] {
         using accscalar_t = at::acc_type<scalar_t, true>;
 
         auto idata = input.packed_accessor64<scalar_t, 4>();
```
```diff
@@ -287,6 +272,174 @@ static void upsample_gen2d_out_cuda_template(
       });
 }
 
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename scalar_t, typename accscalar_t, int interp_size>
+C10_LAUNCH_BOUNDS_1(1024)
+__global__ void upsample_gen2d_backward_out_frame(
+    const int num_elements,
+    const accscalar_t height_scale,
+    const accscalar_t width_scale,
+    const bool align_corners,
+    PackedTensorAccessor64<scalar_t, 4> idata,
+    const PackedTensorAccessor64<scalar_t, 4> odata) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+
+  const int batchsize = idata.size(0);
+  const int channels = idata.size(1);
+  const int input_height = idata.size(2);
+  const int input_width = idata.size(3);
+  const int output_height = odata.size(2);
+  const int output_width = odata.size(3);
+
+  if (index >= num_elements) {
+    return;
+  }
+
+  const int output_x = index % output_width;
+  const int output_y = index / output_width;
+  // special case: output just copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int n = 0; n < batchsize; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const scalar_t val = odata[n][c][output_y][output_x];
+        idata[n][c][output_y][output_x] = val;
+      }
+    }
+    return;
+  }
+
+  const accscalar_t support_h = static_cast<accscalar_t>(
+      (height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
+                            : interp_size * 0.5);
+  const accscalar_t support_w = static_cast<accscalar_t>(
+      (width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
+                           : interp_size * 0.5);
+
+  const int interp_height = (int)ceilf(support_h) * 2 + 1;
+  const int interp_width = (int)ceilf(support_w) * 2 + 1;
+
+  // Setup local buffers
+  // TODO: maybe we can specify dynamic shared memory size before calling the
+  // cuda code, however we should then ensure that device has enough shared
+  // memory
+  scalar_t wx[256];
+  scalar_t wy[256];
+
+  // Compute weights
+  int xmin, xsize, ymin, ysize;
+  typedef scalar_t (*filter_fn_t)(scalar_t);
+  filter_fn_t filter_fn;
+  if (interp_size == 2) {
+    filter_fn = bilinear_filter;
+  } else if (interp_size == 4) {
+    filter_fn = bicubic_filter;
+  }
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      output_x,
+      input_width,
+      width_scale,
+      support_w,
+      wx,
+      interp_width,
+      filter_fn,
+      xmin,
+      xsize);
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      output_y,
+      input_height,
+      height_scale,
+      support_h,
+      wy,
+      interp_height,
+      filter_fn,
+      ymin,
+      ysize);
+
+  for (int n = 0; n < batchsize; n++) {
+    for (int c = 0; c < channels; ++c) {
+      scalar_t out_value = odata[n][c][output_y][output_x];
+      for (int y = 0; y < ysize; y++) {
+        for (int x = 0; x < xsize; x++) {
+          upsample_increment_value_bounded<scalar_t, accscalar_t>(
+              idata,
+              n,
+              c,
+              input_height,
+              input_width,
+              ymin + y,
+              xmin + x,
+              wx[x] * wy[y] * out_value);
+        }
+      }
+    }
+  }
+}
+
```
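The kernel assigns one thread per output (gradient) pixel and scatters `wx * wy * grad_out` into every input pixel under the filter support: the exact adjoint of the forward gather. A single-channel Python analogue (illustrative names and structure, ours):

```python
def backward_accumulate(grad_out, grad_in, weights_for):
    """Adjoint of a separable weighted resize: scatter instead of gather.

    grad_out: [out_h][out_w] nested lists; grad_in: zero-initialized
    [in_h][in_w]; weights_for(axis, idx) -> (start, [w0, w1, ...]).
    """
    for oy in range(len(grad_out)):
        ymin, wy = weights_for("y", oy)
        for ox in range(len(grad_out[0])):
            xmin, wx = weights_for("x", ox)
            g = grad_out[oy][ox]
            for j, wyj in enumerate(wy):
                for i, wxi in enumerate(wx):
                    # identical (pixel, weight) pairs as the forward pass, so
                    # <forward(x), g> == <x, backward(g)> holds exactly
                    grad_in[ymin + j][xmin + i] += wyj * wxi * g
```

Threads handling nearby output pixels write to overlapping input regions, which is why the CUDA version accumulates through `upsample_increment_value_bounded` (ATen's helper that clamps indices and adds atomically) rather than plain stores. The hunk continues with the host-side launcher: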
```diff
+template <int interp_size>
+static void upsample_gen2d_backward_out_cuda_template(
+    const Tensor& grad_input,
+    const Tensor& grad_output_,
+    IntArrayRef output_size,
+    IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
+  TensorArg grad_input_arg{grad_input, "grad_input", 1},
+      grad_output_arg{grad_output_, "grad_output_", 2};
+  checkAllSameGPU(
+      "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});
+
+  int output_height = output_size[0];
+  int output_width = output_size[1];
+
+  int nbatch = input_size[0];
+  int channels = input_size[1];
+  int input_height = input_size[2];
+  int input_width = input_size[3];
+
+  Tensor grad_output = grad_output_.contiguous();
+
+  grad_input.zero_();
+
+  const int num_kernels = output_height * output_width;
+  const int num_threads = std::min(
+      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
+        using accscalar_t = at::acc_type<scalar_t, true>;
+
+        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
+        auto odata = grad_output.packed_accessor64<scalar_t, 4>();
+
+        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
+            input_height, output_height, align_corners, scales_h);
+        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
+            input_width, output_width, align_corners, scales_w);
+
+        // We are using static buffer memory of 256 * sizeof(float) per thread
+        // to store weights. Size of weights array is
+        // interp_size = scale * 2 + 1 for bilinear mode
+        TORCH_CHECK(
+            rheight < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+        TORCH_CHECK(
+            rwidth < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+
+        upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
+            <<<cuda::ATenCeilDiv(num_kernels, num_threads),
+               num_threads,
+               0,
+               stream>>>(
+                num_kernels, rheight, rwidth, align_corners, idata, odata);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
```
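The per-thread weight buffers are fixed at 256 entries per axis, and for a downscale by factor `s` the filter spans about `interp_size * s` input pixels, which is where the `255 / interp_size` guard and the 127/63 limits come from. A quick arithmetic check (our own, mirroring the comment in the code):

```python
import math

def interp_width(interp_size, scale):
    # filter support per output pixel, as computed in the kernel
    support = (interp_size * 0.5) * scale if scale >= 1.0 else interp_size * 0.5
    return 2 * math.ceil(support) + 1  # number of weights stored per axis

assert interp_width(2, 127) <= 256  # bilinear at the documented limit fits
assert interp_width(4, 63) <= 256   # bicubic at the documented limit fits
assert interp_width(2, 128) > 256   # one step beyond would overflow the buffer
```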
```diff
@@ -371,6 +524,56 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
   return output;
 }
 
+template <int interp_size>
+at::Tensor interpolate_gen2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
+  auto grad_input = at::empty({0}, grad_output.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input_size, output_size, scale_factors);
+  auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
+
+  auto full_output_size = upsample_2d_common_check(input_size, osize);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
+      grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(",
+        i,
+        ") = ",
+        full_output_size[i],
+        " but got grad_output.size(",
+        i,
+        ") = ",
+        grad_output.size(i));
+  }
+
+  grad_input.resize_(input_size, grad_output.suggest_memory_format());
+
+  at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
+      interp_size>(
+      grad_input,
+      grad_output,
+      {full_output_size[2], full_output_size[3]},
+      input_size,
+      align_corners,
+      scale_h,
+      scale_w);
+  return grad_input;
+}
+
 at::Tensor interpolate_bilinear2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
```
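The wrapper validates that `grad_output` matches the forward output shape, allocates `grad_input` in `grad_output`'s suggested memory format (so channels-last round-trips), and returns a tensor shaped like the original input. A hedged sketch of that contract through the private op (the schema is assumed to mirror the kernel signature, and the op only exists in builds that compile these CUDA kernels):

```python
import torch
import torchvision  # loads torchvision's C++ operator library

# gradient w.r.t. a 64x64 output produced from a 32x32 input
go = torch.rand(2, 3, 64, 64, device="cuda")
gi = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
    go, [64, 64], [2, 3, 32, 32], False)
assert gi.shape == (2, 3, 32, 32)  # grad_input always matches input_size
```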
```diff
@@ -387,6 +590,24 @@ at::Tensor interpolate_bicubic2d_aa_forward_kernel(
       input, output_size, align_corners);
 }
 
+at::Tensor interpolate_bilinear2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<2>(
+      grad_output, output_size, input_size, align_corners);
+}
+
+at::Tensor interpolate_bicubic2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<4>(
+      grad_output, output_size, input_size, align_corners);
+}
+
 } // namespace
 
 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
```
```diff
@@ -396,6 +617,12 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
      TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
+      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
+      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
 }
 
 } // namespace ops
```
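With both forward and backward registered for CUDA, antialiased resize becomes differentiable end-to-end on GPU, and its gradients can be cross-checked against the existing CPU path; a small sketch (our own check, not part of the PR):

```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 29, dtype=torch.float64)
x_cpu = x.clone().requires_grad_(True)
x_gpu = x.cuda().requires_grad_(True)

for xi in (x_cpu, x_gpu):
    out = F.resize(xi, size=[10, 7],
                   interpolation=InterpolationMode.BILINEAR, antialias=True)
    out.sum().backward()

# CPU and CUDA backward kernels should agree to numerical precision
torch.testing.assert_close(x_cpu.grad, x_gpu.grad.cpu())
```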
