@@ -304,20 +304,20 @@ def test_qroialign(self):
304304        pool_size  =  5 
305305        img_size  =  10 
306306        n_channels  =  2 
307-         num_batches  =  2 
307+         num_imgs  =  2 
308308        dtype  =  torch .float 
309309
310310        def  make_rois (num_rois = 1000 ):
311311            rois  =  torch .randint (0 , img_size  //  2 , size = (num_rois , 5 )).to (dtype )
312-             rois [:, 0 ] =  torch .randint (0 , num_batches , size = (num_rois ,))  # set batch index 
312+             rois [:, 0 ] =  torch .randint (0 , num_imgs , size = (num_rois ,))  # set batch index 
313313            rois [:, 3 :] +=  rois [:, 1 :3 ]  # make sure boxes aren't degenerate 
314314            return  rois 
315315
316316        for  aligned  in  (True , False ):
317317            for  scale , zero_point  in  ((1 , 0 ), (2 , 10 ), (0.1 , 50 )):
318318                for  qdtype  in  (torch .qint8 , torch .quint8 , torch .qint32 ):
319319
320-                     x  =  torch .randint (50 , 100 , size = (num_batches , n_channels , img_size , img_size )).to (dtype )
320+                     x  =  torch .randint (50 , 100 , size = (num_imgs , n_channels , img_size , img_size )).to (dtype )
321321                    qx  =  torch .quantize_per_tensor (x , scale = scale , zero_point = zero_point , dtype = qdtype )
322322
323323                    rois  =  make_rois ()
@@ -364,6 +364,13 @@ def make_rois(num_rois=1000):
364364                        t_scale  =  torch .full_like (abs_diff , fill_value = scale )
365365                        self .assertTrue (torch .allclose (abs_diff , t_scale , atol = 1e-5 ))
366366
367+         x  =  torch .randint (50 , 100 , size = (129 , 3 , 10 , 10 )).to (dtype )
368+         qx  =  torch .quantize_per_tensor (x , scale = 0 , zero_point = 1 , dtype = torch .qint8 )
369+         rois  =  make_rois (10 )
370+         qrois  =  torch .quantize_per_tensor (rois , scale = 0 , zero_point = 1 , dtype = torch .qint8 )
371+         with  self .assertRaisesRegex (RuntimeError , "There are 129 input images in the batch, but the RoIs tensor" ):
372+             ops .roi_align (qx , qrois , output_size = pool_size )
373+ 
367374
368375class  PSRoIAlignTester (RoIOpTester , unittest .TestCase ):
369376    def  fn (self , x , rois , pool_h , pool_w , spatial_scale = 1 , sampling_ratio = - 1 , ** kwargs ):
0 commit comments