@@ -54,7 +54,7 @@ def _test_backward(self, device, contiguous):
5454
5555
5656class RoIOpTester (OpTester ):
57- def _test_forward (self , device , contiguous , x_dtype = None , rois_dtype = None ):
57+ def _test_forward (self , device , contiguous , x_dtype = None , rois_dtype = None , ** kwargs ):
5858 x_dtype = self .dtype if x_dtype is None else x_dtype
5959 rois_dtype = self .dtype if rois_dtype is None else rois_dtype
6060 pool_size = 5
@@ -70,11 +70,11 @@ def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
7070 dtype = rois_dtype , device = device )
7171
7272 pool_h , pool_w = pool_size , pool_size
73- y = self .fn (x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 )
73+ y = self .fn (x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs )
7474 # the following should be true whether we're running an autocast test or not.
7575 self .assertTrue (y .dtype == x .dtype )
7676 gt_y = self .expected_fn (x , rois , pool_h , pool_w , spatial_scale = 1 ,
77- sampling_ratio = - 1 , device = device , dtype = self .dtype )
77+ sampling_ratio = - 1 , device = device , dtype = self .dtype , ** kwargs )
7878
7979 tol = 1e-3 if (x_dtype is torch .half or rois_dtype is torch .half ) else 1e-5
8080 self .assertTrue (torch .allclose (gt_y .to (y .dtype ), y , rtol = tol , atol = tol ))
@@ -304,6 +304,10 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
304304 def _test_boxes_shape (self ):
305305 self ._helper_boxes_shape (ops .roi_align )
306306
307+ def _test_forward (self , device , contiguous , x_dtype = None , rois_dtype = None , ** kwargs ):
308+ for aligned in (True , False ):
309+ super ()._test_forward (device , contiguous , x_dtype , rois_dtype , aligned = aligned )
310+
307311
308312class PSRoIAlignTester (RoIOpTester , unittest .TestCase ):
309313 def fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
0 commit comments