Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions test/test_ious.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from typing import List, Callable

import pytest
import torch
import torch.fx
from torch import Tensor
from torchvision import ops


class IouTestBase:
@staticmethod
def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List):
def assert_close(box: Tensor, expected: Tensor, tolerance):
out = target_fn(box, box)
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

for dtype in dtypes:
actual_box = torch.tensor(test_input, dtype=dtype)
expected_box = torch.tensor(expected)
assert_close(actual_box, expected_box, tolerance)

@staticmethod
def _run_jit_test(target_fn: Callable, test_input: List):
box_tensor = torch.tensor(test_input, dtype=torch.float)
expected = target_fn(box_tensor, box_tensor)
scripted_fn = torch.jit.script(target_fn)
scripted_out = scripted_fn(box_tensor, box_tensor)
torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3)


def _generate_int_input():
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]


def _generate_float_input():
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


class TestBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_input():
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestGenBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestDistanceBoxIoU(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestCompleteBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


if __name__ == "__main__":
pytest.main([__file__])
229 changes: 229 additions & 0 deletions test/test_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
import pytest
import torch
import torch.nn.functional as F
from common_utils import cpu_and_gpu
from torchvision import ops


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making a class, inheriting and the calling method is also possible. For now is this fine?

def get_boxes(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super happy with this choice of box. Since this is actually invalid input

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with the concern.

Detectron2 used the same set of boxes. see this.

I think we should use valid input boxes.

That being said, should we also check if the input boxes have non-negative values?
What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot assert here as it will lead to cuda call and cause trouble.

I'm not sure if we can use torch._assert_async either.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm, ig this situation is similar to #5776 (comment)

box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

box1s = torch.stack([box2, box2], dim=0)
box2s = torch.stack([box3, box4], dim=0)

return box1, box2, box3, box4, box1s, box2s


def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"):
tol = 1e-3 if dtype is torch.half else 1e-5
computed_loss = iou_fn(box1, box2, reduction=reduction)
expected_loss = torch.tensor(expected_loss, device=device)
torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol)


def assert_empty_loss(iou_fn, dtype, device):
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
loss = iou_fn(box1, box2, reduction="mean")
loss.backward()
tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol)
assert box1.grad is not None, "box1.grad should not be None after backward is called"
assert box2.grad is not None, "box2.grad should not be None after backward is called"
loss = iou_fn(box1, box2, reduction="none")
assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"


class TestGeneralizedBoxIouLoss:
# We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_giou_loss(self, dtype, device):

box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

# Identical boxes should have loss of 0
assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)

# quarter size box inside other box = IoU of 0.25
assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device)

# Two side by side boxes, area=union
# IoU=0 and GIoU=0 (loss 1.0)
assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device)

# Two diagonally adjacent boxes, area=2*union
# IoU=0 and GIoU=-0.5 (loss 1.5)
assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device)

# Test batched loss and reductions
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum")
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)


class TestCIOULoss:
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_ciou_loss(self, dtype, device):
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean")
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
assert_empty_loss(ops.complete_box_iou_loss, dtype, device)


class TestDIouLoss:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_distance_iou_loss(self, dtype, device):
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean")
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(self, dtype, device):
assert_empty_loss(ops.distance_box_iou_loss, dtype, device)


class TestFocalLoss:
Copy link
Contributor Author

@oke-aditya oke-aditya May 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are testing losses here, I felt to add this here . To avoid confusion between files.

def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
def logit(p):
return torch.log(p / (1 - p))

def generate_tensor_with_range_type(shape, range_type, **kwargs):
if range_type != "random_binary":
low, high = {
"small": (0.0, 0.2),
"big": (0.8, 1.0),
"zeros": (0.0, 0.0),
"ones": (1.0, 1.0),
"random": (0.0, 1.0),
}[range_type]
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
else:
return torch.randint(0, 2, shape, **kwargs)

# This function will return inputs and targets with shape: (shape[0]*9, shape[1])
inputs = []
targets = []
for input_range_type, target_range_type in [
("small", "zeros"),
("small", "ones"),
("small", "random_binary"),
("big", "zeros"),
("big", "ones"),
("big", "random_binary"),
("random", "zeros"),
("random", "ones"),
("random", "random_binary"),
]:
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

return torch.cat(inputs), torch.cat(targets)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [0, 1])
def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# For testing the ratio with manual calculation, we require the reduction to be "none"
reduction = "none"
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

assert torch.all(
focal_loss <= ce_loss
), "focal loss must be less or equal to cross entropy loss with same input"

loss_ratio = (focal_loss / ce_loss).squeeze()
prob = torch.sigmoid(inputs)
p_t = prob * targets + (1 - prob) * (1 - targets)
correct_ratio = (1.0 - p_t) ** gamma
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
correct_ratio = correct_ratio * alpha_t

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol)

@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [2, 3])
def test_equal_ce_loss(self, reduction, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# focal loss should be equal ce_loss if alpha=-1 and gamma=0
alpha = -1
gamma = 0
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
inputs_fl = inputs.clone().requires_grad_()
targets_fl = targets.clone()
inputs_ce = inputs.clone().requires_grad_()
targets_ce = targets.clone()
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol)

focal_loss.backward()
ce_loss.backward()
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [4, 5])
def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
if device == "cpu":
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
else:
with torch.jit.fuser("fuser2"):
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
# We may remove this condition once the bug is resolved
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)


if __name__ == "__main__":
pytest.main([__file__])
Loading