From 88aede3df07b0439d2bf9dee4a1682335e05c6e5 Mon Sep 17 00:00:00 2001
From: Francisco Massa
Date: Sun, 19 May 2019 11:38:42 -0700
Subject: [PATCH] Fix RoIAlign and RoIPool for non-contiguous gradients

---
Reviewer notes (text between "---" and the diffstat is not part of the commit):

* Each of the four backward entry points touched here now passes grad's own
  data pointer to the kernel instead of the pointer of a contiguous copy.
  Presumably the kernels index grad through its strides, so the copy's layout
  did not match the indexing; that code is not visible in these hunks, so
  confirm against the full sources.
* The new test runs RoIPool and RoIAlign backward twice on the same input:
  once with a contiguous upstream gradient and once with a permuted view
  holding the same values, and asserts that both input gradients match.
* The new test name ends in "_cpu" but it also runs on CUDA when available.
* The dispatch label in ROIAlign_backward_cpu appears to read
  "ROIAlign_forward" (unchanged context line); possible follow-up.
* This copy was rebuilt from a flattened rendering of the patch: line breaks,
  indentation, blank context lines, the <scalar_t>/<int> template arguments
  and the CUDA launch configuration were restored by hand, and the author
  e-mail address is missing. Re-verify against upstream before applying.

 test/test_ops.py                       | 35 ++++++++++++++++++++++++++
 torchvision/csrc/cpu/ROIAlign_cpu.cpp  |  2 +-
 torchvision/csrc/cpu/ROIPool_cpu.cpp   |  2 +-
 torchvision/csrc/cuda/ROIAlign_cuda.cu |  2 +-
 torchvision/csrc/cuda/ROIPool_cuda.cu  |  2 +-
 5 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/test/test_ops.py b/test/test_ops.py
index 6374b4c93c8..737d9186df7 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -135,6 +135,41 @@ def test_roi_pool_gradient_cpu(self):
 
         assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool'
 
+    def test_roi_pool_align_non_cont_grad_cpu(self):
+        devices = ['cpu']
+        if torch.cuda.is_available():
+            devices.append('cuda')
+
+        for d in devices:
+            device = torch.device(d)
+            rois = torch.tensor([
+                [0, 0, 0, 9, 9],
+                [0, 0, 5, 5, 9],
+                [0, 5, 5, 9, 9]], dtype=self.dtype, device=device)
+
+            grad_cont = torch.rand(3, 1, 5, 5, dtype=self.dtype, device=device)
+            grad = grad_cont.permute(2, 1, 3, 0).contiguous().permute(3, 1, 0, 2)
+
+            for op in ['RoIPool', 'RoIAlign']:
+                x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
+                kwargs = {}
+                if op == 'RoIAlign':
+                    kwargs['sampling_ratio'] = 1
+                m = getattr(ops, op)((5, 5), 1, **kwargs)
+
+                y = m(x, rois)
+                y.backward(grad_cont)
+
+                g1 = x.grad.detach().clone()
+                del x.grad
+
+                y = m(x, rois)
+                y.backward(grad)
+
+                g2 = x.grad.detach().clone()
+                del x.grad
+                assert torch.allclose(g1, g2), 'gradient incorrect for {}'.format(op)
+
     def test_roi_pool_gradcheck_cpu(self):
         device = torch.device('cpu')
         x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
diff --git a/torchvision/csrc/cpu/ROIAlign_cpu.cpp b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
index f854455efc9..8d7841a7861 100644
--- a/torchvision/csrc/cpu/ROIAlign_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIAlign_cpu.cpp
@@ -456,7 +456,7 @@ at::Tensor ROIAlign_backward_cpu(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
     ROIAlignBackward<scalar_t>(
         grad.numel(),
-        grad.contiguous().data<scalar_t>(),
+        grad.data<scalar_t>(),
         spatial_scale,
         channels,
         height,
diff --git a/torchvision/csrc/cpu/ROIPool_cpu.cpp b/torchvision/csrc/cpu/ROIPool_cpu.cpp
index 9d81a728d16..6ca3f46cfd7 100644
--- a/torchvision/csrc/cpu/ROIPool_cpu.cpp
+++ b/torchvision/csrc/cpu/ROIPool_cpu.cpp
@@ -205,7 +205,7 @@ at::Tensor ROIPool_backward_cpu(
 
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
     RoIPoolBackward<scalar_t>(
-        grad.contiguous().data<scalar_t>(),
+        grad.data<scalar_t>(),
         argmax.data<int>(),
         num_rois,
         channels,
diff --git a/torchvision/csrc/cuda/ROIAlign_cuda.cu b/torchvision/csrc/cuda/ROIAlign_cuda.cu
index 8d68d20ee53..d7e999f08e6 100644
--- a/torchvision/csrc/cuda/ROIAlign_cuda.cu
+++ b/torchvision/csrc/cuda/ROIAlign_cuda.cu
@@ -396,7 +396,7 @@ at::Tensor ROIAlign_backward_cuda(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
     RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
-        grad.contiguous().data<scalar_t>(),
+        grad.data<scalar_t>(),
         spatial_scale,
         channels,
         height,
diff --git a/torchvision/csrc/cuda/ROIPool_cuda.cu b/torchvision/csrc/cuda/ROIPool_cuda.cu
index ff60447a92f..3ad9f1518e9 100644
--- a/torchvision/csrc/cuda/ROIPool_cuda.cu
+++ b/torchvision/csrc/cuda/ROIPool_cuda.cu
@@ -221,7 +221,7 @@ at::Tensor ROIPool_backward_cuda(
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
     RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
         grad.numel(),
-        grad.contiguous().data<scalar_t>(),
+        grad.data<scalar_t>(),
         argmax.contiguous().data<int>(),
         num_rois,
         spatial_scale,