@@ -20,11 +20,10 @@ def slow_roi_pooling(self, x, rois, pool_h, pool_w, spatial_scale=1,
2020 c = x .size (1 )
2121 y = torch .zeros (rois .size (0 ), c , pool_h , pool_w , dtype = dtype , device = device )
2222
23- rois = torch .round (rois * spatial_scale )
24-
25- for n in range (0 , y .size (0 )):
23+ for n in range (0 , x .size (0 )):
2624 for r , roi in enumerate (rois ):
2725 if roi [0 ] == n :
26+ roi [1 :] = torch .round (roi [1 :] * spatial_scale )
2827 start_h , end_h = int (roi [2 ].item ()), int (roi [4 ].item ()) + 1
2928 start_w , end_w = int (roi [1 ].item ()), int (roi [3 ].item ()) + 1
3029 roi_x = x [roi [0 ].long (), :, start_h :end_h , start_w :end_w ]
@@ -58,6 +57,12 @@ def test_roi_pool_basic_cpu(self):
5857 gt_y = self .slow_roi_pooling (x .permute (0 , 1 , 3 , 2 ), rois , pool_h , pool_w , device = device , dtype = self .dtype )
5958 assert torch .allclose (gt_y , y ), 'RoIPool layer incorrect on CPU'
6059
60+ # spatial-scale != 1
61+ y = ops .RoIPool ((pool_h , pool_w ), 2 )(x .permute (0 , 1 , 3 , 2 ), rois )
62+ gt_y = self .slow_roi_pooling (x .permute (0 , 1 , 3 , 2 ), rois , pool_h , pool_w ,
63+ spatial_scale = 2 , device = device , dtype = self .dtype )
64+ assert torch .allclose (gt_y , y ), 'RoIPool layer incorrect on CPU'
65+
6166 def test_roi_pool_cpu (self ):
6267 device = torch .device ('cpu' )
6368 x = torch .rand (2 , 1 , 10 , 10 , dtype = self .dtype , device = device )
0 commit comments