Skip to content
Merged
130 changes: 89 additions & 41 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,14 +1111,6 @@ def test_bbox_convert_jit(self):
torch.testing.assert_close(scripted_cxcywh, box_cxcywh)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


class TestBoxArea:
def area_check(self, box, expected, atol=1e-4):
out = ops.box_area(box)
Expand Down Expand Up @@ -1152,99 +1144,155 @@ def test_box_area_jit(self):
torch.testing.assert_close(scripted_area, expected)


INT_BOXES = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300], [0, 0, 25, 25]]
INT_BOXES2 = [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]
FLOAT_BOXES = [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


def gen_box(size, dtype=torch.float):
xy1 = torch.rand((size, 2), dtype=dtype)
xy2 = xy1 + torch.rand((size, 2), dtype=dtype)
return torch.cat([xy1, xy2], axis=-1)


class TestIouBase:
@staticmethod
def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], atol: float, expected: List):
def _run_test(target_fn: Callable, actual_box1, actual_box2, dtypes, atol, expected):
for dtype in dtypes:
actual_box = torch.tensor(test_input, dtype=dtype)
actual_box1 = torch.tensor(actual_box1, dtype=dtype)
actual_box2 = torch.tensor(actual_box2, dtype=dtype)
expected_box = torch.tensor(expected)
out = target_fn(actual_box, actual_box)
out = target_fn(actual_box1, actual_box2)
torch.testing.assert_close(out, expected_box, rtol=0.0, check_dtype=False, atol=atol)

@staticmethod
def _run_jit_test(target_fn: Callable, test_input: List):
box_tensor = torch.tensor(test_input, dtype=torch.float)
def _run_jit_test(target_fn: Callable, actual_box: List):
box_tensor = torch.tensor(actual_box, dtype=torch.float)
expected = target_fn(box_tensor, box_tensor)
scripted_fn = torch.jit.script(target_fn)
scripted_out = scripted_fn(box_tensor, box_tensor)
torch.testing.assert_close(scripted_out, expected)

@staticmethod
def _cartesian_product(boxes1, boxes2, target_fn: Callable):
N = boxes1.size(0)
M = boxes2.size(0)
result = torch.zeros((N, M))
for i in range(N):
for j in range(M):
result[i, j] = target_fn(boxes1[i].unsqueeze(0), boxes2[j].unsqueeze(0))
return result

@staticmethod
def _run_cartesian_test(target_fn: Callable):
boxes1 = gen_box(5)
boxes2 = gen_box(7)
a = TestIouBase._cartesian_product(boxes1, boxes2, target_fn)
b = target_fn(boxes1, boxes2)
assert torch.allclose(a, b)


class TestBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0], [0.0625, 0.25, 0.0]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.box_iou)


class TestGeneralizedBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]
int_expected = [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0], [0.0625, 0.25, -0.8819]]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.generalized_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.generalized_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.generalized_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.generalized_box_iou)


class TestDistanceBoxIoU(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [
[1.0000, 0.1875, -0.4444],
[0.1875, 1.0000, -0.5625],
[-0.4444, -0.5625, 1.0000],
[-0.0781, 0.1875, -0.6267],
]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.distance_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.distance_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.distance_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.distance_box_iou)


class TestCompleteBoxIou(TestIouBase):
int_expected = [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]
int_expected = [
[1.0000, 0.1875, -0.4444],
[0.1875, 1.0000, -0.5625],
[-0.4444, -0.5625, 1.0000],
[-0.0781, 0.1875, -0.6267],
]
float_expected = [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, atol, expected",
"actual_box1, actual_box2, dtypes, atol, expected",
[
pytest.param(INT_BOXES, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
pytest.param(INT_BOXES, INT_BOXES2, [torch.int16, torch.int32, torch.int64], 1e-4, int_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float16], 0.002, float_expected),
pytest.param(FLOAT_BOXES, FLOAT_BOXES, [torch.float32, torch.float64], 1e-3, float_expected),
],
)
def test_iou(self, test_input, dtypes, atol, expected):
self._run_test(ops.complete_box_iou, test_input, dtypes, atol, expected)
def test_iou(self, actual_box1, actual_box2, dtypes, atol, expected):
self._run_test(ops.complete_box_iou, actual_box1, actual_box2, dtypes, atol, expected)

def test_iou_jit(self):
self._run_jit_test(ops.complete_box_iou, INT_BOXES)

def test_iou_cartesian(self):
self._run_cartesian_test(ops.complete_box_iou)


def get_boxes(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
Expand Down
12 changes: 7 additions & 5 deletions torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,13 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso

diou, iou = _box_diou_iou(boxes1, boxes2, eps)

w_pred = boxes1[:, 2] - boxes1[:, 0]
h_pred = boxes1[:, 3] - boxes1[:, 1]
w_pred = boxes1[:, None, 2] - boxes1[:, None, 0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this is needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary for the broadcasting to work properly. See also:

lti = torch.min(boxes1[:, None, :2], boxes2[:, :2])
rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2]
diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps

h_pred = boxes1[:, None, 3] - boxes1[:, None, 1]

w_gt = boxes2[:, 2] - boxes2[:, 0]
h_gt = boxes2[:, 3] - boxes2[:, 1]

v = (4 / (torch.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2)
v = (4 / (torch.pi**2)) * torch.pow(torch.atan(w_pred / h_pred) - torch.atan(w_gt / h_gt), 2)
with torch.no_grad():
alpha = v / (1 - iou + v + eps)
return diou - alpha * v
Expand All @@ -358,7 +358,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso

boxes1 = _upcast(boxes1)
boxes2 = _upcast(boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2)
diou, _ = _box_diou_iou(boxes1, boxes2, eps=eps)
return diou


Expand All @@ -375,7 +375,9 @@ def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Te
x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2
y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2
# The distance between boxes' centers squared.
centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2)
centers_distance_squared = (_upcast((x_p[:, None] - x_g[None, :])) ** 2) + (
_upcast((y_p[:, None] - y_g[None, :])) ** 2
)
# The distance IoU is the IoU penalized by a normalized
# distance between boxes' centers squared.
return iou - (centers_distance_squared / diagonal_distance_squared), iou
Expand Down
6 changes: 3 additions & 3 deletions torchvision/ops/ciou_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def complete_box_iou_loss(

"""
Gradient-friendly IoU loss with an additional penalty that is non-zero when the
boxes do not overlap overlap area, This loss function considers important geometrical
factors such as overlap area, normalized central point distance and aspect ratio.
boxes do not overlap. This loss function considers important geometrical
factors such as overlap area, normalized central point distance and aspect ratio.
This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable.

Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format with
Expand All @@ -35,7 +35,7 @@ def complete_box_iou_loss(
Tensor: Loss tensor with the reduction option applied.

Reference:
Zhaohui Zheng et. al: Complete Intersection over Union Loss:
Zhaohui Zheng et al.: Complete Intersection over Union Loss:
https://arxiv.org/abs/1911.08287

"""
Expand Down