Skip to content

Commit a4d1adf

Browse files
committed
Improve RoIPool test
1 parent a129b6b commit a4d1adf

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

test/test_ops.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
2020
c = x.size(1)
2121
y = torch.zeros(rois.size(0), c, pool_h, pool_w, dtype=dtype, device=device)
2222

23-
rois = torch.round(rois * spatial_scale)
24-
25-
for n in range(0, y.size(0)):
23+
for n in range(0, x.size(0)):
2624
for r, roi in enumerate(rois):
2725
if roi[0] == n:
26+
roi[1:] = torch.round(roi[1:] * spatial_scale)
2827
start_h, end_h = int(roi[2].item()), int(roi[4].item()) + 1
2928
start_w, end_w = int(roi[1].item()), int(roi[3].item()) + 1
3029
roi_x = x[roi[0].long(), :, start_h:end_h, start_w:end_w]
@@ -58,6 +57,12 @@ def test_roi_pool_basic_cpu(self):
5857
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w, device=device, dtype=self.dtype)
5958
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU'
6059

60+
# spatial-scale != 1
61+
y = ops.RoIPool((pool_h, pool_w), 2)(x.permute(0, 1, 3, 2), rois)
62+
gt_y = self.slow_roi_pooling(x.permute(0, 1, 3, 2), rois, pool_h, pool_w,
63+
spatial_scale=2, device=device, dtype=self.dtype)
64+
assert torch.allclose(gt_y, y), 'RoIPool layer incorrect on CPU'
65+
6166
def test_roi_pool_cpu(self):
6267
device = torch.device('cpu')
6368
x = torch.rand(2, 1, 10, 10, dtype=self.dtype, device=device)

0 commit comments

Comments
 (0)