@@ -1111,14 +1111,6 @@ def test_bbox_convert_jit(self):
11111111 torch .testing .assert_close (scripted_cxcywh , box_cxcywh )
11121112
11131113
1114- INT_BOXES = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]]
1115- FLOAT_BOXES = [
1116- [285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
1117- [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
1118- [279.2440 , 197.9812 , 1189.4746 , 849.2019 ],
1119- ]
1120-
1121-
11221114class TestBoxArea :
11231115 def area_check (self , box , expected , atol = 1e-4 ):
11241116 out = ops .box_area (box )
@@ -1152,99 +1144,155 @@ def test_box_area_jit(self):
11521144 torch .testing .assert_close (scripted_area , expected )
11531145
11541146
1147+ INT_BOXES = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ], [0 , 0 , 25 , 25 ]]
1148+ INT_BOXES2 = [[0 , 0 , 100 , 100 ], [0 , 0 , 50 , 50 ], [200 , 200 , 300 , 300 ]]
1149+ FLOAT_BOXES = [
1150+ [285.3538 , 185.5758 , 1193.5110 , 851.4551 ],
1151+ [285.1472 , 188.7374 , 1192.4984 , 851.0669 ],
1152+ [279.2440 , 197.9812 , 1189.4746 , 849.2019 ],
1153+ ]
1154+
1155+
1156+ def gen_box (size , dtype = torch .float ):
1157+ xy1 = torch .rand ((size , 2 ), dtype = dtype )
1158+ xy2 = xy1 + torch .rand ((size , 2 ), dtype = dtype )
1159+ return torch .cat ([xy1 , xy2 ], axis = - 1 )
1160+
1161+
11551162class TestIouBase :
11561163 @staticmethod
1157- def _run_test (target_fn : Callable , test_input : List , dtypes : List [ torch . dtype ] , atol : float , expected : List ):
1164+ def _run_test (target_fn : Callable , actual_box1 , actual_box2 , dtypes , atol , expected ):
11581165 for dtype in dtypes :
1159- actual_box = torch .tensor (test_input , dtype = dtype )
1166+ actual_box1 = torch .tensor (actual_box1 , dtype = dtype )
1167+ actual_box2 = torch .tensor (actual_box2 , dtype = dtype )
11601168 expected_box = torch .tensor (expected )
1161- out = target_fn (actual_box , actual_box )
1169+ out = target_fn (actual_box1 , actual_box2 )
11621170 torch .testing .assert_close (out , expected_box , rtol = 0.0 , check_dtype = False , atol = atol )
11631171
11641172 @staticmethod
1165- def _run_jit_test (target_fn : Callable , test_input : List ):
1166- box_tensor = torch .tensor (test_input , dtype = torch .float )
1173+ def _run_jit_test (target_fn : Callable , actual_box : List ):
1174+ box_tensor = torch .tensor (actual_box , dtype = torch .float )
11671175 expected = target_fn (box_tensor , box_tensor )
11681176 scripted_fn = torch .jit .script (target_fn )
11691177 scripted_out = scripted_fn (box_tensor , box_tensor )
11701178 torch .testing .assert_close (scripted_out , expected )
11711179
1180+ @staticmethod
1181+ def _cartesian_product (boxes1 , boxes2 , target_fn : Callable ):
1182+ N = boxes1 .size (0 )
1183+ M = boxes2 .size (0 )
1184+ result = torch .zeros ((N , M ))
1185+ for i in range (N ):
1186+ for j in range (M ):
1187+ result [i , j ] = target_fn (boxes1 [i ].unsqueeze (0 ), boxes2 [j ].unsqueeze (0 ))
1188+ return result
1189+
1190+ @staticmethod
1191+ def _run_cartesian_test (target_fn : Callable ):
1192+ boxes1 = gen_box (5 )
1193+ boxes2 = gen_box (7 )
1194+ a = TestIouBase ._cartesian_product (boxes1 , boxes2 , target_fn )
1195+ b = target_fn (boxes1 , boxes2 )
1196+ assert torch .allclose (a , b )
1197+
11721198
11731199class TestBoxIou (TestIouBase ):
1174- int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]]
1200+ int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ], [ 0.0625 , 0.25 , 0.0 ] ]
11751201 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
11761202
11771203 @pytest .mark .parametrize (
1178- "test_input , dtypes, atol, expected" ,
1204+ "actual_box1, actual_box2 , dtypes, atol, expected" ,
11791205 [
1180- pytest .param (INT_BOXES , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1181- pytest .param (FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1182- pytest .param (FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1206+ pytest .param (INT_BOXES , INT_BOXES2 , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1207+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1208+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
11831209 ],
11841210 )
1185- def test_iou (self , test_input , dtypes , atol , expected ):
1186- self ._run_test (ops .box_iou , test_input , dtypes , atol , expected )
1211+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1212+ self ._run_test (ops .box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
11871213
11881214 def test_iou_jit (self ):
11891215 self ._run_jit_test (ops .box_iou , INT_BOXES )
11901216
1217+ def test_iou_cartesian (self ):
1218+ self ._run_cartesian_test (ops .box_iou )
1219+
11911220
11921221class TestGeneralizedBoxIou (TestIouBase ):
1193- int_expected = [[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ], [- 0.7778 , - 0.8611 , 1.0 ]]
1222+ int_expected = [[1.0 , 0.25 , - 0.7778 ], [0.25 , 1.0 , - 0.8611 ], [- 0.7778 , - 0.8611 , 1.0 ], [ 0.0625 , 0.25 , - 0.8819 ] ]
11941223 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
11951224
11961225 @pytest .mark .parametrize (
1197- "test_input , dtypes, atol, expected" ,
1226+ "actual_box1, actual_box2 , dtypes, atol, expected" ,
11981227 [
1199- pytest .param (INT_BOXES , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1200- pytest .param (FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1201- pytest .param (FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1228+ pytest .param (INT_BOXES , INT_BOXES2 , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1229+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1230+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
12021231 ],
12031232 )
1204- def test_iou (self , test_input , dtypes , atol , expected ):
1205- self ._run_test (ops .generalized_box_iou , test_input , dtypes , atol , expected )
1233+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1234+ self ._run_test (ops .generalized_box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
12061235
12071236 def test_iou_jit (self ):
12081237 self ._run_jit_test (ops .generalized_box_iou , INT_BOXES )
12091238
1239+ def test_iou_cartesian (self ):
1240+ self ._run_cartesian_test (ops .generalized_box_iou )
1241+
12101242
12111243class TestDistanceBoxIoU (TestIouBase ):
1212- int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]]
1244+ int_expected = [
1245+ [1.0000 , 0.1875 , - 0.4444 ],
1246+ [0.1875 , 1.0000 , - 0.5625 ],
1247+ [- 0.4444 , - 0.5625 , 1.0000 ],
1248+ [- 0.0781 , 0.1875 , - 0.6267 ],
1249+ ]
12131250 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
12141251
12151252 @pytest .mark .parametrize (
1216- "test_input , dtypes, atol, expected" ,
1253+ "actual_box1, actual_box2 , dtypes, atol, expected" ,
12171254 [
1218- pytest .param (INT_BOXES , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1219- pytest .param (FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1220- pytest .param (FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1255+ pytest .param (INT_BOXES , INT_BOXES2 , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1256+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1257+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
12211258 ],
12221259 )
1223- def test_iou (self , test_input , dtypes , atol , expected ):
1224- self ._run_test (ops .distance_box_iou , test_input , dtypes , atol , expected )
1260+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1261+ self ._run_test (ops .distance_box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
12251262
12261263 def test_iou_jit (self ):
12271264 self ._run_jit_test (ops .distance_box_iou , INT_BOXES )
12281265
1266+ def test_iou_cartesian (self ):
1267+ self ._run_cartesian_test (ops .distance_box_iou )
1268+
12291269
12301270class TestCompleteBoxIou (TestIouBase ):
1231- int_expected = [[1.0 , 0.25 , 0.0 ], [0.25 , 1.0 , 0.0 ], [0.0 , 0.0 , 1.0 ]]
1271+ int_expected = [
1272+ [1.0000 , 0.1875 , - 0.4444 ],
1273+ [0.1875 , 1.0000 , - 0.5625 ],
1274+ [- 0.4444 , - 0.5625 , 1.0000 ],
1275+ [- 0.0781 , 0.1875 , - 0.6267 ],
1276+ ]
12321277 float_expected = [[1.0 , 0.9933 , 0.9673 ], [0.9933 , 1.0 , 0.9737 ], [0.9673 , 0.9737 , 1.0 ]]
12331278
12341279 @pytest .mark .parametrize (
1235- "test_input , dtypes, atol, expected" ,
1280+ "actual_box1, actual_box2 , dtypes, atol, expected" ,
12361281 [
1237- pytest .param (INT_BOXES , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1238- pytest .param (FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1239- pytest .param (FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
1282+ pytest .param (INT_BOXES , INT_BOXES2 , [torch .int16 , torch .int32 , torch .int64 ], 1e-4 , int_expected ),
1283+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float16 ], 0.002 , float_expected ),
1284+ pytest .param (FLOAT_BOXES , FLOAT_BOXES , [torch .float32 , torch .float64 ], 1e-3 , float_expected ),
12401285 ],
12411286 )
1242- def test_iou (self , test_input , dtypes , atol , expected ):
1243- self ._run_test (ops .complete_box_iou , test_input , dtypes , atol , expected )
1287+ def test_iou (self , actual_box1 , actual_box2 , dtypes , atol , expected ):
1288+ self ._run_test (ops .complete_box_iou , actual_box1 , actual_box2 , dtypes , atol , expected )
12441289
12451290 def test_iou_jit (self ):
12461291 self ._run_jit_test (ops .complete_box_iou , INT_BOXES )
12471292
1293+ def test_iou_cartesian (self ):
1294+ self ._run_cartesian_test (ops .complete_box_iou )
1295+
12481296
12491297def get_boxes (dtype , device ):
12501298 box1 = torch .tensor ([- 1 , - 1 , 1 , 1 ], dtype = dtype , device = device )
0 commit comments