Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7da0a41
Refactor tests
oke-aditya May 16, 2022
ae5c346
Merge branch 'main' of https://github.com/pytorch/vision into refacto…
oke-aditya May 29, 2022
fa27931
Merge branch 'main' of https://github.com/pytorch/vision into refacto…
oke-aditya Jun 2, 2022
9b8df92
Merge branch 'main' of https://github.com/pytorch/vision into refacto…
oke-aditya Jun 6, 2022
7f788f1
Remove tol, fix comments
oke-aditya Jun 6, 2022
6ab501a
Add tolerance only where necessary
oke-aditya Jun 6, 2022
b83d745
Add tolerance only where necessary
oke-aditya Jun 6, 2022
7e49682
Add tolerance only where necessary
oke-aditya Jun 6, 2022
485d1fc
Refactor to adapt suggestions
oke-aditya Jun 6, 2022
108b247
Merge branch 'main' of https://github.com/pytorch/vision into refacto…
oke-aditya Jun 6, 2022
5c8f4fb
Refactor and add nits
oke-aditya Jun 7, 2022
1ed639f
Refactor box area
oke-aditya Jun 7, 2022
fd96c07
Refactor to one file
oke-aditya Jun 7, 2022
aa854ca
Merge branch 'main' of https://github.com/pytorch/vision into refacto…
oke-aditya Jun 8, 2022
5c00ebc
Adapt almost all except area
oke-aditya Jun 8, 2022
141bb68
final update
oke-aditya Jun 10, 2022
395a024
Merge branch 'main' into refactor_ops_tests
oke-aditya Jun 10, 2022
1f183c5
Tighten for jit
oke-aditya Jun 10, 2022
481ba20
Merge branch 'refactor_ops_tests' of github.com:oke-aditya/vision int…
oke-aditya Jun 10, 2022
5171187
Refactor slightly
oke-aditya Jun 21, 2022
bb93929
Merge branch 'main' into refactor_ops_tests
oke-aditya Jun 21, 2022
d7092a2
Merge branch 'main' into refactor_ops_tests
datumbox Jul 25, 2022
c0b4704
Merge branch 'main' into refactor_ops_tests
oke-aditya Jul 26, 2022
8f73645
Fix tests
oke-aditya Jul 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions test/test_ious.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from typing import List, Callable

import pytest
import torch
import torch.fx
from torch import Tensor
from torchvision import ops


class IouTestBase:
@staticmethod
def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List):
def assert_close(box: Tensor, expected: Tensor, tolerance):
out = target_fn(box, box)
torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance)

for dtype in dtypes:
actual_box = torch.tensor(test_input, dtype=dtype)
expected_box = torch.tensor(expected)
assert_close(actual_box, expected_box, tolerance)

@staticmethod
def _run_jit_test(target_fn: Callable, test_input: List):
box_tensor = torch.tensor(test_input, dtype=torch.float)
expected = target_fn(box_tensor, box_tensor)
scripted_fn = torch.jit.script(target_fn)
scripted_out = scripted_fn(box_tensor, box_tensor)
torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3)


def _generate_int_input():
return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]


def _generate_float_input():
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]


class TestBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_input():
return [
[285.3538, 185.5758, 1193.5110, 851.4551],
[285.1472, 188.7374, 1192.4984, 851.0669],
[279.2440, 197.9812, 1189.4746, 849.2019],
]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestGenBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestDistanceBoxIoU(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


class TestCompleteBoxIou(IouTestBase):
def _generate_int_expected():
return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]]

def _generate_float_expected():
return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]]

@pytest.mark.parametrize(
"test_input, dtypes, tolerance, expected",
[
pytest.param(
_generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected()
),
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()),
pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()),
],
)
def test_iou(self, test_input, dtypes, tolerance, expected):
self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected)

def test_iou_jit(self):
self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]])


if __name__ == "__main__":
pytest.main([__file__])
226 changes: 226 additions & 0 deletions test/test_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
import pytest
import torch
import torch.nn.functional as F
from common_utils import cpu_and_gpu
from torchvision import ops


def get_boxes(dtype, device):
box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device)
box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device)
box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device)
box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device)

box1s = torch.stack([box2, box2], dim=0)
box2s = torch.stack([box3, box4], dim=0)

return box1, box2, box3, box4, box1s, box2s


def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"):
computed_loss = iou_fn(box1, box2, reduction=reduction)
expected_loss = torch.tensor(expected_loss, device=device)
torch.testing.assert_close(computed_loss, expected_loss)


def assert_empty_loss(iou_fn, dtype, device):
box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_()
loss = iou_fn(box1, box2, reduction="mean")
loss.backward()
torch.testing.assert_close(loss, torch.tensor(0.0, device=device))
assert box1.grad is not None, "box1.grad should not be None after backward is called"
assert box2.grad is not None, "box2.grad should not be None after backward is called"
loss = iou_fn(box1, box2, reduction="none")
assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty"


class TestGeneralizedBoxIouLoss:
# We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_giou_loss(self, dtype, device):

box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

# Identical boxes should have loss of 0
assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)

# quarter size box inside other box = IoU of 0.25
assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device)

# Two side by side boxes, area=union
# IoU=0 and GIoU=0 (loss 1.0)
assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device)

# Two diagonally adjacent boxes, area=2*union
# IoU=0 and GIoU=-0.5 (loss 1.5)
assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device)

# Test batched loss and reductions
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum")
assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
assert_empty_loss(ops.generalized_box_iou_loss, dtype, device)


class TestCIOULoss:
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_ciou_loss(self, dtype, device):
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device)
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean")
assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_inputs(self, dtype, device):
assert_empty_loss(ops.complete_box_iou_loss, dtype, device)


class TestDIouLoss:
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_distance_iou_loss(self, dtype, device):
box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device)

assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device)
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean")
assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum")

@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
def test_empty_distance_iou_inputs(self, dtype, device):
assert_empty_loss(ops.distance_box_iou_loss, dtype, device)


class TestFocalLoss:
def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs):
def logit(p):
return torch.log(p / (1 - p))

def generate_tensor_with_range_type(shape, range_type, **kwargs):
if range_type != "random_binary":
low, high = {
"small": (0.0, 0.2),
"big": (0.8, 1.0),
"zeros": (0.0, 0.0),
"ones": (1.0, 1.0),
"random": (0.0, 1.0),
}[range_type]
return torch.testing.make_tensor(shape, low=low, high=high, **kwargs)
else:
return torch.randint(0, 2, shape, **kwargs)

# This function will return inputs and targets with shape: (shape[0]*9, shape[1])
inputs = []
targets = []
for input_range_type, target_range_type in [
("small", "zeros"),
("small", "ones"),
("small", "random_binary"),
("big", "zeros"),
("big", "ones"),
("big", "random_binary"),
("random", "zeros"),
("random", "ones"),
("random", "random_binary"),
]:
inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs)))
targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs))

return torch.cat(inputs), torch.cat(targets)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [0, 1])
def test_correct_ratio(self, alpha, gamma, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# For testing the ratio with manual calculation, we require the reduction to be "none"
reduction = "none"
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction)

assert torch.all(
focal_loss <= ce_loss
), "focal loss must be less or equal to cross entropy loss with same input"

loss_ratio = (focal_loss / ce_loss).squeeze()
prob = torch.sigmoid(inputs)
p_t = prob * targets + (1 - prob) * (1 - targets)
correct_ratio = (1.0 - p_t) ** gamma
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
correct_ratio = correct_ratio * alpha_t

torch.testing.assert_close(correct_ratio, loss_ratio)

@pytest.mark.parametrize("reduction", ["mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [2, 3])
def test_equal_ce_loss(self, reduction, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
# focal loss should be equal ce_loss if alpha=-1 and gamma=0
alpha = -1
gamma = 0
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
inputs_fl = inputs.clone().requires_grad_()
targets_fl = targets.clone()
inputs_ce = inputs.clone().requires_grad_()
targets_ce = targets.clone()
focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction)
ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, ce_loss, atol=tol, rtol=tol)

focal_loss.backward()
ce_loss.backward()
torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad)

@pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0])
@pytest.mark.parametrize("gamma", [0, 2])
@pytest.mark.parametrize("reduction", ["none", "mean", "sum"])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dtype", [torch.float32, torch.half])
@pytest.mark.parametrize("seed", [4, 5])
def test_jit(self, alpha, gamma, reduction, device, dtype, seed):
if device == "cpu" and dtype is torch.half:
pytest.skip("Currently torch.half is not fully supported on cpu")
script_fn = torch.jit.script(ops.sigmoid_focal_loss)
torch.random.manual_seed(seed)
inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device)
focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
if device == "cpu":
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)
else:
with torch.jit.fuser("fuser2"):
# Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476
# We may remove this condition once the bug is resolved
scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction)

tol = 1e-3 if dtype is torch.half else 1e-5
torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol)


if __name__ == "__main__":
pytest.main([__file__])
Loading