From 2f513d8a95ade3b8521837a7c2e121afa3878c52 Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Thu, 14 Oct 2021 16:40:02 -0700 Subject: [PATCH] Fix missing kernel guards (#4620) Summary: Pull Request resolved: https://github.com/pytorch/vision/pull/4620 Pull Request resolved: https://github.com/pytorch/nestedtensor/pull/455 Fixes missing kernel guards as identified by D30072495 Reviewed By: jingsh, xush6528 Differential Revision: D31553158 fbshipit-source-id: 80de017ba2ddc52e2a684df9b3eae5de84ed49f4 --- .../csrc/ops/cuda/deform_conv2d_kernel.cu | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu index 8d3beab888f..f1b0f75b008 100644 --- a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu +++ b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu @@ -228,7 +228,9 @@ void deformable_im2col( int deformable_group, bool use_mask, at::Tensor data_col) { - int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; + at::cuda::CUDAGuard device_guard(input.get_device()); + + const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; const unsigned int threads = GET_THREADS(); const unsigned int blocks = GET_BLOCKS(threads, num_kernels); @@ -408,12 +410,14 @@ void compute_grad_input( int n_offset_grps, bool use_mask, at::Tensor grad_im) { - int out_h = + at::cuda::CUDAGuard device_guard(columns.get_device()); + + const int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = + const int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int64_t num_kernels = + const int64_t num_kernels = (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; const unsigned int threads = GET_THREADS(); @@ -650,11 +654,13 @@ void compute_grad_offset_and_mask( bool use_mask, at::Tensor grad_offset, at::Tensor grad_mask) { - int out_h = + at::cuda::CUDAGuard device_guard(columns.get_device()); + + const int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = + const int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * + const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; const unsigned int threads = GET_THREADS();