Skip to content

Commit ae1d707

Browse files
oke-adityapmeierdatumbox
authored
Cleanup ops (#6024)
* Cleanup ops * Address nits Co-authored-by: Philip Meier <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 49ee65f commit ae1d707

File tree

6 files changed

+97
-114
lines changed

6 files changed

+97
-114
lines changed

torchvision/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
remove_small_boxes,
66
clip_boxes_to_image,
77
box_area,
8+
box_convert,
89
box_iou,
910
generalized_box_iou,
1011
distance_box_iou,
1112
complete_box_iou,
1213
masks_to_boxes,
1314
)
14-
from .boxes import box_convert
1515
from .ciou_loss import complete_box_iou_loss
1616
from .deform_conv import deform_conv2d, DeformConv2d
1717
from .diou_loss import distance_box_iou_loss

torchvision/ops/_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,3 +67,40 @@ def split_normalization_params(
6767
else:
6868
other_params.extend(p for p in module.parameters() if p.requires_grad)
6969
return norm_params, other_params
70+
71+
72+
def _upcast(t: Tensor) -> Tensor:
73+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
74+
if t.is_floating_point():
75+
return t if t.dtype in (torch.float32, torch.float64) else t.float()
76+
else:
77+
return t if t.dtype in (torch.int32, torch.int64) else t.int()
78+
79+
80+
def _upcast_non_float(t: Tensor) -> Tensor:
81+
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
82+
if t.dtype not in (torch.float32, torch.float64):
83+
return t.float()
84+
return t
85+
86+
87+
def _loss_inter_union(
88+
boxes1: torch.Tensor,
89+
boxes2: torch.Tensor,
90+
) -> Tuple[torch.Tensor, torch.Tensor]:
91+
92+
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
93+
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
94+
95+
# Intersection keypoints
96+
xkis1 = torch.max(x1, x1g)
97+
ykis1 = torch.max(y1, y1g)
98+
xkis2 = torch.min(x2, x2g)
99+
ykis2 = torch.min(y2, y2g)
100+
101+
intsctk = torch.zeros_like(x1)
102+
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
103+
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
104+
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
105+
106+
return intsctk, unionk

torchvision/ops/boxes.py

Lines changed: 9 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from ..utils import _log_api_usage_once
99
from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh
10+
from ._utils import _upcast
1011

1112

1213
def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor:
@@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
215216
return boxes
216217

217218

218-
def _upcast(t: Tensor) -> Tensor:
219-
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
220-
if t.is_floating_point():
221-
return t if t.dtype in (torch.float32, torch.float64) else t.float()
222-
else:
223-
return t if t.dtype in (torch.int32, torch.int64) else t.int()
224-
225-
226219
def box_area(boxes: Tensor) -> Tensor:
227220
"""
228221
Computes the area of a set of bounding boxes, which are specified by their
@@ -330,22 +323,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
330323
boxes1 = _upcast(boxes1)
331324
boxes2 = _upcast(boxes2)
332325

333-
inter, union = _box_inter_union(boxes1, boxes2)
334-
iou = inter / union
335-
336-
lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2])
337-
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:])
338-
339-
whi = (rbi - lti).clamp(min=0) # [N,M,2]
340-
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
341-
342-
# centers of boxes
343-
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
344-
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
345-
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
346-
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
347-
# The distance between boxes' centers squared.
348-
centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2
326+
diou, iou = _box_diou_iou(boxes1, boxes2, eps)
349327

350328
w_pred = boxes1[:, 2] - boxes1[:, 0]
351329
h_pred = boxes1[:, 3] - boxes1[:, 1]
@@ -356,7 +334,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
356334
v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
357335
with torch.no_grad():
358336
alpha = v / (1 - iou + v + eps)
359-
return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v
337+
return diou - alpha * v
360338

361339

362340
def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor:
@@ -380,27 +358,27 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso
380358

381359
boxes1 = _upcast(boxes1)
382360
boxes2 = _upcast(boxes2)
361+
diou, _ = _box_diou_iou(boxes1, boxes2)
362+
return diou
383363

384-
inter, union = _box_inter_union(boxes1, boxes2)
385-
iou = inter / union
386364

365+
def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]:
366+
367+
iou = box_iou(boxes1, boxes2)
387368
lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
388369
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
389-
390370
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
391371
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps
392-
393372
# centers of boxes
394373
x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2
395374
y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2
396375
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
397376
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
398377
# The distance between boxes' centers squared.
399378
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
400-
401379
# The distance IoU is the IoU penalized by a normalized
402380
# distance between boxes' centers squared.
403-
return iou - (centers_distance_squared / diagonal_distance_squared)
381+
return iou - (centers_distance_squared / diagonal_distance_squared), iou
404382

405383

406384
def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor:

torchvision/ops/ciou_loss.py

Lines changed: 13 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import torch
22

33
from ..utils import _log_api_usage_once
4-
from .giou_loss import _upcast
4+
from ._utils import _upcast_non_float
5+
from .diou_loss import _diou_iou_loss
56

67

78
def complete_box_iou_loss(
@@ -30,50 +31,28 @@ def complete_box_iou_loss(
3031
``'sum'``: The output will be summed. Default: ``'none'``
3132
eps : (float): small number to prevent division by zero. Default: 1e-7
3233
33-
Reference:
34+
Returns:
35+
Tensor: Loss tensor with the reduction option applied.
3436
35-
Complete Intersection over Union Loss (Zhaohui Zheng et. al)
36-
https://arxiv.org/abs/1911.08287
37+
Reference:
38+
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
39+
https://arxiv.org/abs/1911.08287
3740
3841
"""
3942

40-
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
43+
# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
4144

4245
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
4346
_log_api_usage_once(complete_box_iou_loss)
4447

45-
boxes1 = _upcast(boxes1)
46-
boxes2 = _upcast(boxes2)
48+
boxes1 = _upcast_non_float(boxes1)
49+
boxes2 = _upcast_non_float(boxes2)
50+
51+
diou_loss, iou = _diou_iou_loss(boxes1, boxes2)
4752

4853
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
4954
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
5055

51-
# Intersection keypoints
52-
xkis1 = torch.max(x1, x1g)
53-
ykis1 = torch.max(y1, y1g)
54-
xkis2 = torch.min(x2, x2g)
55-
ykis2 = torch.min(y2, y2g)
56-
57-
intsct = torch.zeros_like(x1)
58-
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
59-
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
60-
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
61-
iou = intsct / union
62-
63-
# smallest enclosing box
64-
xc1 = torch.min(x1, x1g)
65-
yc1 = torch.min(y1, y1g)
66-
xc2 = torch.max(x2, x2g)
67-
yc2 = torch.max(y2, y2g)
68-
diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
69-
70-
# centers of boxes
71-
x_p = (x2 + x1) / 2
72-
y_p = (y2 + y1) / 2
73-
x_g = (x1g + x2g) / 2
74-
y_g = (y1g + y2g) / 2
75-
distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
76-
7756
# width and height of boxes
7857
w_pred = x2 - x1
7958
h_pred = y2 - y1
@@ -83,7 +62,7 @@ def complete_box_iou_loss(
8362
with torch.no_grad():
8463
alpha = v / (1 - iou + v + eps)
8564

86-
loss = 1 - iou + (distance / diag_len) + alpha * v
65+
loss = diou_loss + alpha * v
8766
if reduction == "mean":
8867
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
8968
elif reduction == "sum":

torchvision/ops/diou_loss.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from typing import Tuple
2+
13
import torch
24

35
from ..utils import _log_api_usage_once
4-
from .boxes import _upcast
6+
from ._utils import _loss_inter_union, _upcast_non_float
57

68

79
def distance_box_iou_loss(
@@ -10,6 +12,7 @@ def distance_box_iou_loss(
1012
reduction: str = "none",
1113
eps: float = 1e-7,
1214
) -> torch.Tensor:
15+
1316
"""
1417
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
1518
distance between boxes' centers isn't zero. Indeed, for two exactly overlapping
@@ -37,50 +40,48 @@ def distance_box_iou_loss(
3740
https://arxiv.org/abs/1911.08287
3841
"""
3942

40-
# Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
43+
# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py
4144

4245
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
4346
_log_api_usage_once(distance_box_iou_loss)
4447

45-
boxes1 = _upcast(boxes1)
46-
boxes2 = _upcast(boxes2)
48+
boxes1 = _upcast_non_float(boxes1)
49+
boxes2 = _upcast_non_float(boxes2)
4750

48-
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
49-
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
51+
loss, _ = _diou_iou_loss(boxes1, boxes2, eps)
5052

51-
# Intersection keypoints
52-
xkis1 = torch.max(x1, x1g)
53-
ykis1 = torch.max(y1, y1g)
54-
xkis2 = torch.min(x2, x2g)
55-
ykis2 = torch.min(y2, y2g)
53+
if reduction == "mean":
54+
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
55+
elif reduction == "sum":
56+
loss = loss.sum()
57+
return loss
5658

57-
intsct = torch.zeros_like(x1)
58-
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
59-
intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
60-
union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps
61-
iou = intsct / union
6259

60+
def _diou_iou_loss(
61+
boxes1: torch.Tensor,
62+
boxes2: torch.Tensor,
63+
eps: float = 1e-7,
64+
) -> Tuple[torch.Tensor, torch.Tensor]:
65+
66+
intsct, union = _loss_inter_union(boxes1, boxes2)
67+
iou = intsct / (union + eps)
6368
# smallest enclosing box
69+
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
70+
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
6471
xc1 = torch.min(x1, x1g)
6572
yc1 = torch.min(y1, y1g)
6673
xc2 = torch.max(x2, x2g)
6774
yc2 = torch.max(y2, y2g)
6875
# The diagonal distance of the smallest enclosing box squared
6976
diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps
70-
7177
# centers of boxes
7278
x_p = (x2 + x1) / 2
7379
y_p = (y2 + y1) / 2
7480
x_g = (x1g + x2g) / 2
7581
y_g = (y1g + y2g) / 2
7682
# The distance between boxes' centers squared.
7783
centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2)
78-
7984
# The distance IoU is the IoU penalized by a normalized
8085
# distance between boxes' centers squared.
8186
loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared)
82-
if reduction == "mean":
83-
loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum()
84-
elif reduction == "sum":
85-
loss = loss.sum()
86-
return loss
87+
return loss, iou

torchvision/ops/giou_loss.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,7 @@
11
import torch
2-
from torch import Tensor
32

43
from ..utils import _log_api_usage_once
5-
6-
7-
def _upcast(t: Tensor) -> Tensor:
8-
# Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
9-
if t.dtype not in (torch.float32, torch.float64):
10-
return t.float()
11-
return t
4+
from ._utils import _upcast_non_float, _loss_inter_union
125

136

147
def generalized_box_iou_loss(
@@ -17,10 +10,8 @@ def generalized_box_iou_loss(
1710
reduction: str = "none",
1811
eps: float = 1e-7,
1912
) -> torch.Tensor:
20-
"""
21-
Original implementation from
22-
https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py
2313

14+
"""
2415
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
2516
boxes do not overlap and scales with the size of their smallest enclosing box.
2617
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.
@@ -38,31 +29,28 @@ def generalized_box_iou_loss(
3829
``'sum'``: The output will be summed. Default: ``'none'``
3930
eps (float): small number to prevent division by zero. Default: 1e-7
4031
32+
Returns:
33+
Tensor: Loss tensor with the reduction option applied.
34+
4135
Reference:
4236
Hamid Rezatofighi et. al: Generalized Intersection over Union:
4337
A Metric and A Loss for Bounding Box Regression:
4438
https://arxiv.org/abs/1902.09630
4539
"""
40+
41+
# Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py
42+
4643
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
4744
_log_api_usage_once(generalized_box_iou_loss)
4845

49-
boxes1 = _upcast(boxes1)
50-
boxes2 = _upcast(boxes2)
46+
boxes1 = _upcast_non_float(boxes1)
47+
boxes2 = _upcast_non_float(boxes2)
48+
intsctk, unionk = _loss_inter_union(boxes1, boxes2)
49+
iouk = intsctk / (unionk + eps)
50+
5151
x1, y1, x2, y2 = boxes1.unbind(dim=-1)
5252
x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1)
5353

54-
# Intersection keypoints
55-
xkis1 = torch.max(x1, x1g)
56-
ykis1 = torch.max(y1, y1g)
57-
xkis2 = torch.min(x2, x2g)
58-
ykis2 = torch.min(y2, y2g)
59-
60-
intsctk = torch.zeros_like(x1)
61-
mask = (ykis2 > ykis1) & (xkis2 > xkis1)
62-
intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask])
63-
unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk
64-
iouk = intsctk / (unionk + eps)
65-
6654
# smallest enclosing box
6755
xc1 = torch.min(x1, x1g)
6856
yc1 = torch.min(y1, y1g)

0 commit comments

Comments
 (0)