Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,41 @@ def test_roi_pool_gradient_cpu(self):

assert torch.allclose(x.grad, gt_grad), 'gradient incorrect for roi_pool'

def test_roi_pool_align_non_cont_grad_cpu(self):
devices = ['cpu']
if torch.cuda.is_available():
devices.append('cuda')

for d in devices:
device = torch.device(d)
rois = torch.tensor([
[0, 0, 0, 9, 9],
[0, 0, 5, 5, 9],
[0, 5, 5, 9, 9]], dtype=self.dtype, device=device)

grad_cont = torch.rand(3, 1, 5, 5, dtype=self.dtype, device=device)
grad = grad_cont.permute(2, 1, 3, 0).contiguous().permute(3, 1, 0, 2)

for op in ['RoIPool', 'RoIAlign']:
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
kwargs = {}
if op == 'RoIAlign':
kwargs['sampling_ratio'] = 1
m = getattr(ops, op)((5, 5), 1, **kwargs)

y = m(x, rois)
y.backward(grad_cont)

g1 = x.grad.detach().clone()
del x.grad

y = m(x, rois)
y.backward(grad)

g2 = x.grad.detach().clone()
del x.grad
assert torch.allclose(g1, g2), 'gradient incorrect for {}'.format(op)

def test_roi_pool_gradcheck_cpu(self):
device = torch.device('cpu')
x = torch.rand(1, 1, 10, 10, dtype=self.dtype, device=device, requires_grad=True)
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/ROIAlign_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ at::Tensor ROIAlign_backward_cpu(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_forward", [&] {
ROIAlignBackward<scalar_t>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
grad.data<scalar_t>(),
spatial_scale,
channels,
height,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/ROIPool_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ at::Tensor ROIPool_backward_cpu(

AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
RoIPoolBackward<scalar_t>(
grad.contiguous().data<scalar_t>(),
grad.data<scalar_t>(),
argmax.data<int>(),
num_rois,
channels,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cuda/ROIAlign_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ at::Tensor ROIAlign_backward_cuda(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIAlign_backward", [&] {
RoIAlignBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
grad.data<scalar_t>(),
spatial_scale,
channels,
height,
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cuda/ROIPool_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ at::Tensor ROIPool_backward_cuda(
AT_DISPATCH_FLOATING_TYPES_AND_HALF(grad.type(), "ROIPool_backward", [&] {
RoIPoolBackward<scalar_t><<<grid, block, 0, stream>>>(
grad.numel(),
grad.contiguous().data<scalar_t>(),
grad.data<scalar_t>(),
argmax.contiguous().data<int>(),
num_rois,
spatial_scale,
Expand Down