77
88from  ..utils  import  _log_api_usage_once 
99from  ._box_convert  import  _box_cxcywh_to_xyxy , _box_xyxy_to_cxcywh , _box_xywh_to_xyxy , _box_xyxy_to_xywh 
10+ from  ._utils  import  _upcast 
1011
1112
1213def  nms (boxes : Tensor , scores : Tensor , iou_threshold : float ) ->  Tensor :
@@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
215216    return  boxes 
216217
217218
218- def  _upcast (t : Tensor ) ->  Tensor :
219-     # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type 
220-     if  t .is_floating_point ():
221-         return  t  if  t .dtype  in  (torch .float32 , torch .float64 ) else  t .float ()
222-     else :
223-         return  t  if  t .dtype  in  (torch .int32 , torch .int64 ) else  t .int ()
224- 
225- 
226219def  box_area (boxes : Tensor ) ->  Tensor :
227220    """ 
228221    Computes the area of a set of bounding boxes, which are specified by their 
@@ -330,22 +323,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
330323    boxes1  =  _upcast (boxes1 )
331324    boxes2  =  _upcast (boxes2 )
332325
333-     inter , union  =  _box_inter_union (boxes1 , boxes2 )
334-     iou  =  inter  /  union 
335- 
336-     lti  =  torch .min (boxes1 [:, None , :2 ], boxes2 [:, None , :2 ])
337-     rbi  =  torch .max (boxes1 [:, None , 2 :], boxes2 [:, None , 2 :])
338- 
339-     whi  =  (rbi  -  lti ).clamp (min = 0 )  # [N,M,2] 
340-     diagonal_distance_squared  =  (whi [:, :, 0 ] **  2 ) +  (whi [:, :, 1 ] **  2 ) +  eps 
341- 
342-     # centers of boxes 
343-     x_p  =  (boxes1 [:, 0 ] +  boxes1 [:, 2 ]) /  2 
344-     y_p  =  (boxes1 [:, 1 ] +  boxes1 [:, 3 ]) /  2 
345-     x_g  =  (boxes2 [:, 0 ] +  boxes2 [:, 2 ]) /  2 
346-     y_g  =  (boxes2 [:, 1 ] +  boxes2 [:, 3 ]) /  2 
347-     # The distance between boxes' centers squared. 
348-     centers_distance_squared  =  (x_p  -  x_g ) **  2  +  (y_p  -  y_g ) **  2 
326+     diou , iou  =  _box_diou_iou (boxes1 , boxes2 , eps )
349327
350328    w_pred  =  boxes1 [:, 2 ] -  boxes1 [:, 0 ]
351329    h_pred  =  boxes1 [:, 3 ] -  boxes1 [:, 1 ]
@@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
356334    v  =  (4  /  (torch .pi  **  2 )) *  torch .pow ((torch .atan (w_gt  /  h_gt ) -  torch .atan (w_pred  /  h_pred )), 2 )
357335    with  torch .no_grad ():
358336        alpha  =  v  /  (1  -  iou  +  v  +  eps )
359-     return  iou   -  ( centers_distance_squared   /   diagonal_distance_squared )  -  alpha  *  v 
337+     return  diou  -  alpha  *  v 
360338
361339
362340def  distance_box_iou (boxes1 : Tensor , boxes2 : Tensor , eps : float  =  1e-7 ) ->  Tensor :
@@ -380,27 +358,27 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
380358
381359    boxes1  =  _upcast (boxes1 )
382360    boxes2  =  _upcast (boxes2 )
361+     diou , _  =  _box_diou_iou (boxes1 , boxes2 )
362+     return  diou 
383363
384-     inter , union  =  _box_inter_union (boxes1 , boxes2 )
385-     iou  =  inter  /  union 
386364
365+ def  _box_diou_iou (boxes1 : Tensor , boxes2 : Tensor , eps : float  =  1e-7 ) ->  Tuple [Tensor , Tensor ]:
366+ 
367+     iou  =  box_iou (boxes1 , boxes2 )
387368    lti  =  torch .min (boxes1 [:, None , :2 ], boxes2 [:, :2 ])
388369    rbi  =  torch .max (boxes1 [:, None , 2 :], boxes2 [:, 2 :])
389- 
390370    whi  =  _upcast (rbi  -  lti ).clamp (min = 0 )  # [N,M,2] 
391371    diagonal_distance_squared  =  (whi [:, :, 0 ] **  2 ) +  (whi [:, :, 1 ] **  2 ) +  eps 
392- 
393372    # centers of boxes 
394373    x_p  =  (boxes1 [:, 0 ] +  boxes1 [:, 2 ]) /  2 
395374    y_p  =  (boxes1 [:, 1 ] +  boxes1 [:, 3 ]) /  2 
396375    x_g  =  (boxes2 [:, 0 ] +  boxes2 [:, 2 ]) /  2 
397376    y_g  =  (boxes2 [:, 1 ] +  boxes2 [:, 3 ]) /  2 
398377    # The distance between boxes' centers squared. 
399378    centers_distance_squared  =  (_upcast (x_p  -  x_g ) **  2 ) +  (_upcast (y_p  -  y_g ) **  2 )
400- 
401379    # The distance IoU is the IoU penalized by a normalized 
402380    # distance between boxes' centers squared. 
403-     return  iou  -  (centers_distance_squared  /  diagonal_distance_squared )
381+     return  iou  -  (centers_distance_squared  /  diagonal_distance_squared ),  iou 
404382
405383
406384def  masks_to_boxes (masks : torch .Tensor ) ->  torch .Tensor :
0 commit comments