Skip to content

Commit 3428a7d

Browse files
authored
Added test for aligned=True (#3540)
1 parent 0139808 commit 3428a7d

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

test/test_ops.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def _test_backward(self, device, contiguous):
5454

5555

5656
class RoIOpTester(OpTester):
57-
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
57+
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
5858
x_dtype = self.dtype if x_dtype is None else x_dtype
5959
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
6060
pool_size = 5
@@ -70,11 +70,11 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
7070
dtype=rois_dtype, device=device)
7171

7272
pool_h, pool_w = pool_size, pool_size
73-
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
73+
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs)
7474
# the following should be true whether we're running an autocast test or not.
7575
self.assertTrue(y.dtype == x.dtype)
7676
gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
77-
sampling_ratio=-1, device=device, dtype=self.dtype)
77+
sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs)
7878

7979
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
8080
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))
@@ -304,6 +304,10 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
304304
def _test_boxes_shape(self):
305305
self._helper_boxes_shape(ops.roi_align)
306306

307+
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, **kwargs):
308+
for aligned in (True, False):
309+
super()._test_forward(device, contiguous, x_dtype, rois_dtype, aligned=aligned)
310+
307311

308312
class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
309313
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):

0 commit comments

Comments
 (0)