diff --git a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu index 8d3beab888f..f1b0f75b008 100644 --- a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu +++ b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu @@ -228,7 +228,9 @@ void deformable_im2col( int deformable_group, bool use_mask, at::Tensor data_col) { - int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; + at::cuda::CUDAGuard device_guard(input.get_device()); + + const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs; const unsigned int threads = GET_THREADS(); const unsigned int blocks = GET_BLOCKS(threads, num_kernels); @@ -408,12 +410,14 @@ void compute_grad_input( int n_offset_grps, bool use_mask, at::Tensor grad_im) { - int out_h = + at::cuda::CUDAGuard device_guard(columns.get_device()); + + const int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = + const int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int64_t num_kernels = + const int64_t num_kernels = (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs; const unsigned int threads = GET_THREADS(); @@ -650,11 +654,13 @@ void compute_grad_offset_and_mask( bool use_mask, at::Tensor grad_offset, at::Tensor grad_mask) { - int out_h = + at::cuda::CUDAGuard device_guard(columns.get_device()); + + const int out_h = (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1; - int out_w = + const int out_w = (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1; - int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * + const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs; const unsigned int threads = GET_THREADS();