@@ -10,7 +10,7 @@
 import torch
 import torch.fx
 import torch.nn.functional as F
-from common_utils import assert_equal, cpu_and_cuda, needs_cuda
+from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
 from torch.autograd import gradcheck
@@ -96,12 +96,33 @@ def forward(self, imgs: Tensor, boxes: List[Tensor]) -> Tensor:
 
 class RoIOpTester(ABC):
     dtype = torch.float64
+    mps_dtype = torch.float32
+    mps_backward_atol = 2e-2
 
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
-    def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, deterministic=False, **kwargs):
-        x_dtype = self.dtype if x_dtype is None else x_dtype
-        rois_dtype = self.dtype if rois_dtype is None else rois_dtype
+    @pytest.mark.parametrize(
+        "x_dtype",
+        (
+            torch.float16,
+            torch.float32,
+            torch.float64,
+        ),
+        ids=str,
+    )
+    def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
+        if device == "mps" and x_dtype is torch.float64:
+            pytest.skip("MPS does not support float64")
+
+        rois_dtype = x_dtype if rois_dtype is None else rois_dtype
+
+        tol = 1e-5
+        if x_dtype is torch.half:
+            if device == "mps":
+                tol = 5e-3
+            else:
+                tol = 4e-3
+
         pool_size = 5
         # n_channels % (pool_size ** 2) == 0 required for PS operations.
         n_channels = 2 * (pool_size**2)
@@ -120,10 +141,9 @@ def test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None, determ
         # the following should be true whether we're running an autocast test or not.
         assert y.dtype == x.dtype
         gt_y = self.expected_fn(
-            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=self.dtype, **kwargs
+            x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, device=device, dtype=x_dtype, **kwargs
         )
 
-        tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
         torch.testing.assert_close(gt_y.to(y), y, rtol=tol, atol=tol)
 
     @pytest.mark.parametrize("device", cpu_and_cuda())
@@ -155,16 +175,19 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
         torch.testing.assert_close(output_gt, output_fx, rtol=tol, atol=tol)
 
     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic=False):
+        atol = self.mps_backward_atol if device == "mps" else 1e-05
+        dtype = self.mps_dtype if device == "mps" else self.dtype
+
         torch.random.manual_seed(seed)
         pool_size = 2
-        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=self.dtype, device=device, requires_grad=True)
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=dtype, device=device, requires_grad=True)
         if not contiguous:
             x = x.permute(0, 1, 3, 2)
         rois = torch.tensor(
-            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=self.dtype, device=device  # format is (xyxy)
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=dtype, device=device  # format is (xyxy)
         )
 
         def func(z):
@@ -173,9 +196,25 @@ def func(z):
         script_func = self.get_script_fn(rois, pool_size)
 
         with DeterministicGuard(deterministic):
-            gradcheck(func, (x,))
+            gradcheck(func, (x,), atol=atol)
+
+        gradcheck(script_func, (x,), atol=atol)
 
-        gradcheck(script_func, (x,))
+    @needs_mps
+    def test_mps_error_inputs(self):
+        pool_size = 2
+        x = torch.rand(1, 2 * (pool_size**2), 5, 5, dtype=torch.float16, device="mps", requires_grad=True)
+        rois = torch.tensor(
+            [[0, 0, 0, 4, 4], [0, 0, 2, 3, 4], [0, 2, 2, 4, 4]], dtype=torch.float16, device="mps"  # format is (xyxy)
+        )
+
+        def func(z):
+            return self.fn(z, rois, pool_size, pool_size, spatial_scale=1, sampling_ratio=1)
+
+        with pytest.raises(
+            RuntimeError, match="MPS does not support (?:ps_)?roi_(?:align|pool)? backward with float16 inputs."
+        ):
+            gradcheck(func, (x,))
 
     @needs_cuda
     @pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@@ -271,6 +310,8 @@ def test_jit_boxes_list(self):
 
 
 class TestPSRoIPool(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIPool((pool_h, pool_w), 1)(x, rois)
 
@@ -352,6 +393,8 @@ def bilinear_interpolate(data, y, x, snap_border=False):
 
 
 class TestRoIAlign(RoIOpTester):
+    mps_backward_atol = 6e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, aligned=False, **kwargs):
         return ops.RoIAlign(
             (pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio, aligned=aligned
@@ -418,10 +461,11 @@ def test_boxes_shape(self):
         self._helper_boxes_shape(ops.roi_align)
 
     @pytest.mark.parametrize("aligned", (True, False))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
+    @pytest.mark.parametrize("x_dtype", (torch.float16, torch.float32, torch.float64), ids=str)
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
-    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype=None, rois_dtype=None):
+    def test_forward(self, device, contiguous, deterministic, aligned, x_dtype, rois_dtype=None):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
         super().test_forward(
@@ -450,7 +494,7 @@ def test_autocast(self, aligned, deterministic, x_dtype, rois_dtype):
         )
 
     @pytest.mark.parametrize("seed", range(10))
-    @pytest.mark.parametrize("device", cpu_and_cuda())
+    @pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
     @pytest.mark.parametrize("contiguous", (True, False))
     @pytest.mark.parametrize("deterministic", (True, False))
     def test_backward(self, seed, device, contiguous, deterministic):
@@ -537,6 +581,8 @@ def test_jit_boxes_list(self):
 
 
 class TestPSRoIAlign(RoIOpTester):
+    mps_backward_atol = 5e-2
+
     def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
         return ops.PSRoIAlign((pool_h, pool_w), spatial_scale=spatial_scale, sampling_ratio=sampling_ratio)(x, rois)
 
@@ -705,40 +751,53 @@ def test_qnms(self, iou, scale, zero_point):
 
         torch.testing.assert_close(qkeep, keep, msg=err_msg.format(iou))
 
-    @needs_cuda
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
-    def test_nms_cuda(self, iou, dtype=torch.float64):
+    def test_nms_gpu(self, iou, device, dtype=torch.float64):
+        dtype = torch.float32 if device == "mps" else dtype
         tol = 1e-3 if dtype is torch.half else 1e-5
         err_msg = "NMS incompatible between CPU and CUDA for IoU={}"
 
         boxes, scores = self._create_tensors_with_iou(1000, iou)
         r_cpu = ops.nms(boxes, scores, iou)
-        r_cuda = ops.nms(boxes.cuda(), scores.cuda(), iou)
+        r_gpu = ops.nms(boxes.to(device), scores.to(device), iou)
 
-        is_eq = torch.allclose(r_cpu, r_cuda.cpu())
+        is_eq = torch.allclose(r_cpu, r_gpu.cpu())
         if not is_eq:
             # if the indices are not the same, ensure that it's because the scores
             # are duplicate
-            is_eq = torch.allclose(scores[r_cpu], scores[r_cuda.cpu()], rtol=tol, atol=tol)
+            is_eq = torch.allclose(scores[r_cpu], scores[r_gpu.cpu()], rtol=tol, atol=tol)
         assert is_eq, err_msg.format(iou)
 
     @needs_cuda
     @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
     def test_autocast(self, iou, dtype):
         with torch.cuda.amp.autocast():
-            self.test_nms_cuda(iou=iou, dtype=dtype)
+            self.test_nms_gpu(iou=iou, dtype=dtype, device="cuda")
 
-    @needs_cuda
-    def test_nms_cuda_float16(self):
+    @pytest.mark.parametrize(
+        "device",
+        (
+            pytest.param("cuda", marks=pytest.mark.needs_cuda),
+            pytest.param("mps", marks=pytest.mark.needs_mps),
+        ),
+    )
+    def test_nms_float16(self, device):
         boxes = torch.tensor(
             [
                 [285.3538, 185.5758, 1193.5110, 851.4551],
                 [285.1472, 188.7374, 1192.4984, 851.0669],
                 [279.2440, 197.9812, 1189.4746, 849.2019],
             ]
-        ).cuda()
-        scores = torch.tensor([0.6370, 0.7569, 0.3966]).cuda()
+        ).to(device)
+        scores = torch.tensor([0.6370, 0.7569, 0.3966]).to(device)
 
         iou_thres = 0.2
         keep32 = ops.nms(boxes, scores, iou_thres)
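Note: the cpu_and_cuda_and_mps and needs_mps helpers imported at the top of this diff are defined in test/common_utils.py, which is not part of the change shown here. The following is only a rough sketch of what they could plausibly look like, for context; the actual definitions, and how the needs_mps marker gets turned into a skip (presumably in conftest.py), are assumptions rather than part of this diff.

# Hypothetical sketch only -- the real helpers live in test/common_utils.py and may differ.
import pytest


def cpu_and_cuda():
    return ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda))


def cpu_and_cuda_and_mps():
    # Same device list as cpu_and_cuda(), plus an "mps" entry gated by the needs_mps marker.
    return cpu_and_cuda() + (pytest.param("mps", marks=pytest.mark.needs_mps),)


# Marker applied to MPS-only tests such as test_mps_error_inputs; a conftest.py hook is
# assumed to skip marked tests when torch.backends.mps.is_available() returns False.
needs_mps = pytest.mark.needs_mps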