File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed
torchvision/csrc/ops/cuda Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -228,7 +228,9 @@ void deformable_im2col(
228228 int deformable_group,
229229 bool use_mask,
230230 at::Tensor data_col) {
231- int64_t num_kernels = (int64_t )n_in_channels * out_h * out_w * parallel_imgs;
231+ at::cuda::CUDAGuard device_guard (input.get_device ());
232+
233+ const int64_t num_kernels = (int64_t )n_in_channels * out_h * out_w * parallel_imgs;
232234
233235 const unsigned int threads = GET_THREADS ();
234236 const unsigned int blocks = GET_BLOCKS (threads, num_kernels);
@@ -408,12 +410,14 @@ void compute_grad_input(
408410 int n_offset_grps,
409411 bool use_mask,
410412 at::Tensor grad_im) {
411- int out_h =
413+ at::cuda::CUDAGuard device_guard (columns.get_device ());
414+
415+ const int out_h =
412416 (height + 2 * pad_h - (dilation_h * (weight_h - 1 ) + 1 )) / stride_h + 1 ;
413- int out_w =
417+ const int out_w =
414418 (width + 2 * pad_w - (dilation_w * (weight_w - 1 ) + 1 )) / stride_w + 1 ;
415419
416- int64_t num_kernels =
420+ const int64_t num_kernels =
417421 (int64_t )channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
418422
419423 const unsigned int threads = GET_THREADS ();
@@ -650,11 +654,13 @@ void compute_grad_offset_and_mask(
650654 bool use_mask,
651655 at::Tensor grad_offset,
652656 at::Tensor grad_mask) {
653- int out_h =
657+ at::cuda::CUDAGuard device_guard (columns.get_device ());
658+
659+ const int out_h =
654660 (height + 2 * pad_h - (dilation_h * (weight_h - 1 ) + 1 )) / stride_h + 1 ;
655- int out_w =
661+ const int out_w =
656662 (width + 2 * pad_w - (dilation_w * (weight_w - 1 ) + 1 )) / stride_w + 1 ;
657- int64_t num_kernels = (int64_t )out_h * out_w * 2 * weight_h * weight_w *
663+ const int64_t num_kernels = (int64_t )out_h * out_w * 2 * weight_h * weight_w *
658664 n_offset_grps * parallel_imgs;
659665
660666 const unsigned int threads = GET_THREADS ();
You can’t perform that action at this time.
0 commit comments