@@ -82,16 +82,30 @@ def to_numpy(tensor):
8282 raise
8383
8484 def test_nms (self ):
85- boxes = torch .rand (5 , 4 )
86- boxes [:, 2 :] += torch .rand (5 , 2 )
87- scores = torch .randn (5 )
85+ num_boxes = 100
86+ boxes = torch .rand (num_boxes , 4 )
87+ boxes [:, 2 :] += boxes [:, :2 ]
88+ scores = torch .randn (num_boxes )
8889
8990 class Module (torch .nn .Module ):
9091 def forward (self , boxes , scores ):
9192 return ops .nms (boxes , scores , 0.5 )
9293
9394 self .run_model (Module (), [(boxes , scores )])
9495
96+ def test_batched_nms (self ):
97+ num_boxes = 100
98+ boxes = torch .rand (num_boxes , 4 )
99+ boxes [:, 2 :] += boxes [:, :2 ]
100+ scores = torch .randn (num_boxes )
101+ idxs = torch .randint (0 , 5 , size = (num_boxes ,))
102+
103+ class Module (torch .nn .Module ):
104+ def forward (self , boxes , scores , idxs ):
105+ return ops .batched_nms (boxes , scores , idxs , 0.5 )
106+
107+ self .run_model (Module (), [(boxes , scores , idxs )])
108+
95109 def test_clip_boxes_to_image (self ):
96110 boxes = torch .randn (5 , 4 ) * 500
97111 boxes [:, 2 :] += boxes [:, :2 ]
0 commit comments