@@ -165,49 +165,32 @@ __global__ void upsample_gen2d_out_frame(
   // Compute weights
   int xmin, xsize, ymin, ysize;
   typedef scalar_t (*filter_fn_t)(scalar_t);
+  filter_fn_t filter_fn;
   if (interp_size == 2) {
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        w2,
-        width1,
-        rwidth,
-        support_w,
-        wx,
-        interp_width,
-        bilinear_filter,
-        xmin,
-        xsize);
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        h2,
-        height1,
-        rheight,
-        support_h,
-        wy,
-        interp_height,
-        bilinear_filter,
-        ymin,
-        ysize);
+    filter_fn = bilinear_filter;
   } else if (interp_size == 4) {
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        w2,
-        width1,
-        rwidth,
-        support_w,
-        wx,
-        interp_width,
-        bicubic_filter,
-        xmin,
-        xsize);
-    _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
-        h2,
-        height1,
-        rheight,
-        support_h,
-        wy,
-        interp_height,
-        bicubic_filter,
-        ymin,
-        ysize);
+    filter_fn = bicubic_filter;
   }
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      w2,
+      width1,
+      rwidth,
+      support_w,
+      wx,
+      interp_width,
+      filter_fn,
+      xmin,
+      xsize);
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      h2,
+      height1,
+      rheight,
+      support_h,
+      wy,
+      interp_height,
+      filter_fn,
+      ymin,
+      ysize);

   for (int n = 0; n < batchsize; n++) {
     for (int c = 0; c < channels; ++c) {
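The hunk above replaces two duplicated `_compute_weights` call chains with a single pair of calls parameterized by a filter function pointer. For context, the two filters assigned to `filter_fn` are the Pillow-style antialias kernels defined earlier in this file; the sketch below is a reconstruction under that assumption, not code from this patch:

```cpp
// Sketch of the filter functions selected above. The definitions follow the
// classic Pillow/PIL antialias kernels; treat the exact bodies as assumed.
template <typename scalar_t>
__device__ __forceinline__ scalar_t bilinear_filter(scalar_t x) {
  if (x < 0.0) {
    x = -x;
  }
  if (x < 1.0) {
    return 1.0 - x; // triangle (tent) kernel, support = 1
  }
  return 0.0;
}

template <typename scalar_t>
__device__ __forceinline__ scalar_t bicubic_filter(scalar_t x) {
  const scalar_t a = -0.5; // cubic convolution coefficient used by PIL
  if (x < 0.0) {
    x = -x;
  }
  if (x < 1.0) {
    return ((a + 2.0) * x - (a + 3.0)) * x * x + 1;
  }
  if (x < 2.0) {
    return (((x - 5) * x + 8) * x - 4) * a; // support = 2
  }
  return 0.0;
}
```

Since `interp_size` is a compile-time template parameter, the `if`/`else if` collapses to a single assignment after instantiation, so sharing the two `_compute_weights` calls costs nothing extra.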
@@ -239,6 +222,8 @@ static void upsample_gen2d_out_cuda_template(
     bool align_corners,
     c10::optional<double> scales_h,
     c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_out_cuda_template
   TensorArg input_arg{input, "input", 1}, output_arg{output, "output", 2};
   checkAllSameGPU("upsample_gen2d_out_cuda", {input_arg, output_arg});

@@ -256,7 +241,7 @@ static void upsample_gen2d_out_cuda_template(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "upsample_bilinear2d_out_frame", [&] {
+      input.scalar_type(), "upsample_gen2d_out_frame", [&] {
         using accscalar_t = at::acc_type<scalar_t, true>;

         auto idata = input.packed_accessor64<scalar_t, 4>();
@@ -287,6 +272,174 @@ static void upsample_gen2d_out_cuda_template(
       });
 }

+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template <typename scalar_t, typename accscalar_t, int interp_size>
+C10_LAUNCH_BOUNDS_1(1024)
+__global__ void upsample_gen2d_backward_out_frame(
+    const int num_elements,
+    const accscalar_t height_scale,
+    const accscalar_t width_scale,
+    const bool align_corners,
+    PackedTensorAccessor64<scalar_t, 4> idata,
+    const PackedTensorAccessor64<scalar_t, 4> odata) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+
+  const int batchsize = idata.size(0);
+  const int channels = idata.size(1);
+  const int input_height = idata.size(2);
+  const int input_width = idata.size(3);
+  const int output_height = odata.size(2);
+  const int output_width = odata.size(3);
+
+  if (index >= num_elements) {
+    return;
+  }
+
+  const int output_x = index % output_width;
+  const int output_y = index / output_width;
+  // special case: output just copy
+  if (input_height == output_height && input_width == output_width) {
+    for (int n = 0; n < batchsize; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const scalar_t val = odata[n][c][output_y][output_x];
+        idata[n][c][output_y][output_x] = val;
+      }
+    }
+    return;
+  }
+
+  const accscalar_t support_h = static_cast<accscalar_t>(
+      (height_scale >= 1.0) ? (interp_size * 0.5) * height_scale
+                            : interp_size * 0.5);
+  const accscalar_t support_w = static_cast<accscalar_t>(
+      (width_scale >= 1.0) ? (interp_size * 0.5) * width_scale
+                           : interp_size * 0.5);
+
+  const int interp_height = (int)ceilf(support_h) * 2 + 1;
+  const int interp_width = (int)ceilf(support_w) * 2 + 1;
+
+  // Setup local buffers
+  // TODO: maybe we can specify dynamic shared memory size before calling the
+  // cuda code, however we should then ensure that device has enough shared
+  // memory
+  scalar_t wx[256];
+  scalar_t wy[256];
+
+  // Compute weights
+  int xmin, xsize, ymin, ysize;
+  typedef scalar_t (*filter_fn_t)(scalar_t);
+  filter_fn_t filter_fn;
+  if (interp_size == 2) {
+    filter_fn = bilinear_filter;
+  } else if (interp_size == 4) {
+    filter_fn = bicubic_filter;
+  }
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      output_x,
+      input_width,
+      width_scale,
+      support_w,
+      wx,
+      interp_width,
+      filter_fn,
+      xmin,
+      xsize);
+  _compute_weights<scalar_t, accscalar_t, filter_fn_t>(
+      output_y,
+      input_height,
+      height_scale,
+      support_h,
+      wy,
+      interp_height,
+      filter_fn,
+      ymin,
+      ysize);
+
+  for (int n = 0; n < batchsize; n++) {
+    for (int c = 0; c < channels; ++c) {
+      scalar_t out_value = odata[n][c][output_y][output_x];
+      for (int y = 0; y < ysize; y++) {
+        for (int x = 0; x < xsize; x++) {
+          upsample_increment_value_bounded<scalar_t, accscalar_t>(
+              idata,
+              n,
+              c,
+              input_height,
+              input_width,
+              ymin + y,
+              xmin + x,
+              wx[x] * wy[y] * out_value);
+        }
+      }
+    }
+  }
+}
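The scatter in the inner loop relies on `upsample_increment_value_bounded`, which is not part of this diff; it comes from PyTorch's `ATen/native/cuda/UpSample.cuh`. Roughly (a sketch from memory, so the exact signature may differ), it clamps the tap coordinates to the input and accumulates atomically, which is what makes concurrent writes from neighboring output pixels safe:

```cpp
// Approximate shape of the helper used above (see ATen's UpSample.cuh).
template <typename scalar_t, typename accscalar_t>
__device__ __forceinline__ void upsample_increment_value_bounded(
    PackedTensorAccessor64<scalar_t, 4>& data,
    int batch, int channel, int height, int width, int y, int x,
    accscalar_t value) {
  // Clamp to the valid range so out-of-bounds taps fold onto the border.
  int access_y = max(min(y, height - 1), 0);
  int access_x = max(min(x, width - 1), 0);
  // Atomic add: several output pixels may touch the same input location.
  gpuAtomicAdd(&data[batch][channel][access_y][access_x],
               static_cast<scalar_t>(value));
}
```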
+
+template <int interp_size>
+static void upsample_gen2d_backward_out_cuda_template(
+    const Tensor& grad_input,
+    const Tensor& grad_output_,
+    IntArrayRef output_size,
+    IntArrayRef input_size,
+    bool align_corners,
+    c10::optional<double> scales_h,
+    c10::optional<double> scales_w) {
+  // Copied and adapted from
+  // UpSampleBicubic2d.cu::upsample_bicubic2d_backward_out_cuda_template
+  TensorArg grad_input_arg{grad_input, "grad_input", 1},
+      grad_output_arg{grad_output_, "grad_output_", 2};
+  checkAllSameGPU(
+      "upsample_gen2d_backward_out_cuda", {grad_output_arg, grad_input_arg});
+
+  int output_height = output_size[0];
+  int output_width = output_size[1];
+
+  int nbatch = input_size[0];
+  int channels = input_size[1];
+  int input_height = input_size[2];
+  int input_width = input_size[3];
+
+  Tensor grad_output = grad_output_.contiguous();
+
+  grad_input.zero_();
+
+  const int num_kernels = output_height * output_width;
+  const int num_threads = std::min(
+      at::cuda::getCurrentDeviceProperties()->maxThreadsPerBlock, 1024);
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      grad_output.scalar_type(), "upsample_gen2d_backward_out_frame", [&] {
+        using accscalar_t = at::acc_type<scalar_t, true>;
+
+        auto idata = grad_input.packed_accessor64<scalar_t, 4>();
+        auto odata = grad_output.packed_accessor64<scalar_t, 4>();
+
+        const accscalar_t rheight = area_pixel_compute_scale<accscalar_t>(
+            input_height, output_height, align_corners, scales_h);
+        const accscalar_t rwidth = area_pixel_compute_scale<accscalar_t>(
+            input_width, output_width, align_corners, scales_w);
+
+        // We are using static buffer memory of 256 * sizeof(float) per thread
+        // to store weights. Size of weights array is
+        // interp_size = scale * 2 + 1 for bilinear mode
+        TORCH_CHECK(
+            rheight < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+        TORCH_CHECK(
+            rwidth < (255 / interp_size),
+            "Max supported scale factor is 127 (bilinear), 63 (bicubic)");
+
+        upsample_gen2d_backward_out_frame<scalar_t, accscalar_t, interp_size>
+            <<<cuda::ATenCeilDiv(num_kernels, num_threads),
+               num_threads,
+               0,
+               stream>>>(
+                num_kernels, rheight, rwidth, align_corners, idata, odata);
+        C10_CUDA_KERNEL_LAUNCH_CHECK();
+      });
+}
+
 } // namespace internal_upsample
 } // namespace native
 } // namespace at
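The two `TORCH_CHECK`s above guard the fixed 256-element `wx`/`wy` buffers declared in the kernel. The bound follows from the kernel's own formula: for a downscale factor `s >= 1`, `support = interp_size / 2 * s`, and the weight count is `2 * ceil(support) + 1`. A standalone host-side check of that arithmetic (names are illustrative only):

```cpp
#include <cassert>
#include <cmath>

// Mirrors the kernel's weight-count formula to verify the 256-entry bound.
static int required_weights(int interp_size, double scale) {
  const double support =
      (scale >= 1.0) ? (interp_size * 0.5) * scale : interp_size * 0.5;
  return static_cast<int>(std::ceil(support)) * 2 + 1;
}

int main() {
  assert(required_weights(2, 127.0) <= 256); // bilinear: 255 weights fit
  assert(required_weights(2, 128.0) > 256);  // scale 128 would overflow wx[256]
  assert(required_weights(4, 63.0) <= 256);  // bicubic: 253 weights fit
  assert(required_weights(4, 64.0) > 256);   // scale 64 would overflow
  return 0;
}
```

This matches the error message: at most 127x downscaling for bilinear and 63x for bicubic.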
@@ -371,6 +524,56 @@ at::Tensor interpolate_gen2d_aa_forward_kernel(
   return output;
 }

+template <int interp_size>
+at::Tensor interpolate_gen2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  c10::optional<c10::ArrayRef<double>> scale_factors = {};
+
+  // Copied from UpSampleBicubic2d.cpp::upsample_bicubic2d_backward
+  auto grad_input = at::empty({0}, grad_output.options());
+  auto osize = at::native::upsample::compute_output_size(
+      input_size, output_size, scale_factors);
+  auto scale_h = at::native::upsample_cuda::get_scale_value(scale_factors, 0);
+  auto scale_w = at::native::upsample_cuda::get_scale_value(scale_factors, 1);
+
+  auto full_output_size = upsample_2d_common_check(input_size, osize);
+
+  TORCH_CHECK(
+      grad_output.dim() == 4,
+      "Expected grad_output to be a tensor of dimension 4 but got: dimension ",
+      grad_output.dim());
+
+  for (int i = 0; i < 4; ++i) {
+    TORCH_CHECK(
+        grad_output.size(i) == full_output_size[i],
+        "Expected grad_output to have the same shape as output;",
+        " output.size(",
+        i,
+        ") = ",
+        full_output_size[i],
+        " but got grad_output.size(",
+        i,
+        ") = ",
+        grad_output.size(i));
+  }
+
+  grad_input.resize_(input_size, grad_output.suggest_memory_format());
+
+  at::native::internal_upsample::upsample_gen2d_backward_out_cuda_template<
+      interp_size>(
+      grad_input,
+      grad_output,
+      {full_output_size[2], full_output_size[3]},
+      input_size,
+      align_corners,
+      scale_h,
+      scale_w);
+  return grad_input;
+}
+
 at::Tensor interpolate_bilinear2d_aa_forward_kernel(
     const at::Tensor& input,
     at::IntArrayRef output_size,
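Because the wrapper passes an empty `scale_factors`, `get_scale_value` returns `nullopt` and the templates derive the scale from the input/output sizes via `area_pixel_compute_scale`. For reference, that helper from PyTorch's `UpSample.h` behaves roughly like the sketch below (reconstructed from memory, not part of this patch):

```cpp
#include <cstdint>
#include <c10/util/Optional.h>

// Approximation of PyTorch's area_pixel_compute_scale (UpSample.h).
template <typename scalar_t>
scalar_t area_pixel_compute_scale(
    int64_t input_size,
    int64_t output_size,
    bool align_corners,
    const c10::optional<double> scale) {
  if (output_size <= 1) {
    return static_cast<scalar_t>(0);
  }
  if (align_corners) {
    // Endpoints map exactly onto each other.
    return static_cast<scalar_t>(input_size - 1) / (output_size - 1);
  }
  // With no explicit scale factor, fall back to the size ratio.
  return (scale.has_value() && *scale > 0.)
      ? static_cast<scalar_t>(1.0 / *scale)
      : static_cast<scalar_t>(input_size) / output_size;
}
```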
@@ -387,6 +590,24 @@ at::Tensor interpolate_bicubic2d_aa_forward_kernel(
       input, output_size, align_corners);
 }

+at::Tensor interpolate_bilinear2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<2>(
+      grad_output, output_size, input_size, align_corners);
+}
+
+at::Tensor interpolate_bicubic2d_aa_backward_kernel(
+    const at::Tensor& grad_output,
+    at::IntArrayRef output_size,
+    at::IntArrayRef input_size,
+    bool align_corners) {
+  return interpolate_gen2d_aa_backward_kernel<4>(
+      grad_output, output_size, input_size, align_corners);
+}
+
 } // namespace

 TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
@@ -396,6 +617,12 @@ TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa"),
       TORCH_FN(interpolate_bicubic2d_aa_forward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bilinear2d_aa_backward"),
+      TORCH_FN(interpolate_bilinear2d_aa_backward_kernel));
+  m.impl(
+      TORCH_SELECTIVE_NAME("torchvision::_interpolate_bicubic2d_aa_backward"),
+      TORCH_FN(interpolate_bicubic2d_aa_backward_kernel));
 }

 } // namespace ops
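With the two `m.impl` registrations above, the backward kernels become reachable through the dispatcher. A minimal C++ smoke test (hypothetical; the op name and argument order are taken from the registrations and kernel signatures above, and the registered schema is assumed to match):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Looks up the registered backward op and calls it; assumes the torchvision
// library has already been loaded so the registration above has run.
at::Tensor call_bilinear_aa_backward(
    const at::Tensor& grad_output,
    at::IntArrayRef output_size,
    at::IntArrayRef input_size,
    bool align_corners) {
  static auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow(
              "torchvision::_interpolate_bilinear2d_aa_backward", "")
          .typed<at::Tensor(
              const at::Tensor&, at::IntArrayRef, at::IntArrayRef, bool)>();
  return op.call(grad_output, output_size, input_size, align_corners);
}
```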