@@ -65,11 +65,11 @@ def func(z):
         gradcheck(func, (x,))
         gradcheck(script_func, (x,))
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
-    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
     def test_autocast(self, device, x_dtype, rois_dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             self.test_forward(torch.device(device), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
 
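Aside: the device-conditional autocast dispatch that each of these test_autocast methods repeats can be exercised standalone. A minimal sketch, assuming PyTorch 1.10+ (where torch.cpu.amp.autocast is available); run_op is a hypothetical stand-in for the torchvision op under test:

import torch

def run_op(x):
    # Hypothetical op standing in for roi_align / nms / deform_conv2d.
    return x @ x

device = "cuda" if torch.cuda.is_available() else "cpu"
# Same dispatch as the tests: pick the autocast context manager per device.
cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
x = torch.randn(8, 8, device=device)
with cm():
    out = run_op(x)
print(out.dtype)  # torch.bfloat16 on CPU, torch.float16 on CUDA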
@@ -285,15 +285,16 @@ def test_forward(self, device, contiguous, aligned, x_dtype=None, rois_dtype=None):
             device=device, contiguous=contiguous, x_dtype=x_dtype, rois_dtype=rois_dtype, aligned=aligned
         )
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('aligned', (True, False))
-    @pytest.mark.parametrize('x_dtype', (torch.float, torch.half))
-    @pytest.mark.parametrize('rois_dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("aligned", (True, False))
+    @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
+    @pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
     def test_autocast(self, device, aligned, x_dtype, rois_dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
-            self.test_forward(torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype,
-                              rois_dtype=rois_dtype)
+            self.test_forward(
+                torch.device(device), contiguous=False, aligned=aligned, x_dtype=x_dtype, rois_dtype=rois_dtype
+            )
 
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
         rois = torch.randint(0, img_size // 2, size=(num_rois, 5)).to(dtype)
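Note that stacking pytest.mark.parametrize decorators, as above, collects the test once per element of the cartesian product of all parameter sets. A minimal illustration (hypothetical test, not part of this diff):

import pytest

@pytest.mark.parametrize("aligned", (True, False))
@pytest.mark.parametrize("x_dtype", ("float", "half"))
def test_product(aligned, x_dtype):
    # Collected as 2 * 2 = 4 test cases, one per (x_dtype, aligned) pair.
    assert aligned in (True, False)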
@@ -533,12 +534,12 @@ def test_nms_cuda(self, iou, dtype=torch.float64):
         is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize("iou", (.2, .5, .8))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, device, iou, dtype):
-        test_fn = self.test_nms_ref if device == 'cpu' else partial(self.test_nms_cuda, dtype=dtype)
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        test_fn = self.test_nms_ref if device == "cpu" else partial(self.test_nms_cuda, dtype=dtype)
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             test_fn(iou=iou)
 
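The NMS test above additionally picks the reference function by device, using functools.partial to bind dtype so both branches accept the same iou-only call. A minimal sketch of that idiom, with check_cpu and check_cuda as hypothetical stand-ins:

from functools import partial

def check_cpu(iou):
    print(f"cpu reference, iou={iou}")

def check_cuda(iou, dtype="half"):
    print(f"cuda reference, iou={iou}, dtype={dtype}")

device = "cpu"
# Bind dtype up front so both callables are invoked identically.
test_fn = check_cpu if device == "cpu" else partial(check_cuda, dtype="half")
test_fn(iou=0.5)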
@@ -826,11 +827,11 @@ def test_compare_cpu_cuda_grads(self, contiguous):
         res_grads = init_weight.grad.to("cpu")
         torch.testing.assert_close(true_cpu_grads, res_grads)
 
-    @pytest.mark.parametrize('device', cpu_and_gpu())
-    @pytest.mark.parametrize('batch_sz', (0, 33))
-    @pytest.mark.parametrize('dtype', (torch.float, torch.half))
+    @pytest.mark.parametrize("device", cpu_and_gpu())
+    @pytest.mark.parametrize("batch_sz", (0, 33))
+    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, device, batch_sz, dtype):
-        cm = torch.cpu.amp.autocast if device == 'cpu' else torch.cuda.amp.autocast
+        cm = torch.cpu.amp.autocast if device == "cpu" else torch.cuda.amp.autocast
         with cm():
             self.test_forward(torch.device(device), contiguous=False, batch_sz=batch_sz, dtype=dtype)
 