Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions kerascv/layers/iou_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import tensorflow as tf


class IOUSimilarity(tf.keras.layers.Layer):
Copy link
Contributor

@fchollet fchollet Jun 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Big question here is why a layer -- what workflows will it be used in? We do already have a IoU metric -- would it make sense for IoU to be a metric and a loss instead of a layer?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The more I think about it, this is actually pretty different from what 1) giou loss in addons, 2) iou in keras/metrics is doing.
Both 1) and 2) require y_true and y_pred, which needs to be the same shape, i.e.,:
y_pred: [n_boxes, 4]
y_true: [n_boxes, 4]
And the metrics/loss from it is [n_boxes] before reduction

However what we are having here is:
box1: [n_boxes_1, 4]
box2: [n_boxes_2, 4]
and we're computing the bipartite similarities, so the result would be [n_boxes_1, n_boxes_2]

More specifically, we will need to call this layer with call(gt_boxes, anchors), where n_boxes_1 is usually 3-10 (which is also why I added ragged support because we don't know how many), but n_boxes_2 is always fixed (e.g., 8732 in SSD300).
This is really used for different things, so I believe it should be a layer.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect the output of this layer to be fed to other layers? Or is it a final output of the model? If the latter, then maybe it should be a metric instead (to be used instead another layer, for instance).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, this should be a metric

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that IOU and derivates could be still interesting also as loss:
https://arxiv.org/abs/1908.03851
https://arxiv.org/abs/1911.08287

"""Defines a IOUSimilarity that calculates the IOU between ground truth boxes and anchors.

Calling the layer with `ground_truth_boxes` and `anchors`, `ground_truth_boxes` can be a batched
`tf.Tensor` or `tf.RaggedTensor`, while `anchors` can be a batched or un-batched `tf.Tensor`.
"""

def __init__(self, name=None, **kwargs):
super(IOUSimilarity, self).__init__(name=name, **kwargs)

def call(self, ground_truth_boxes, anchors):
# ground_truth_box [n_gt_boxes, box_dim] or [batch_size, n_gt_boxes, box_dim]
# anchor [n_anchors, box_dim]
def iou(ground_truth_box, anchor):
# [n_anchors, 1]
y_min_anchors, x_min_anchors, y_max_anchors, x_max_anchors = tf.split(
anchor, num_or_size_splits=4, axis=-1
)
# [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
y_min_gt, x_min_gt, y_max_gt, x_max_gt = tf.split(
ground_truth_box, num_or_size_splits=4, axis=-1
)
# [n_anchors]
anchor_areas = tf.squeeze(
(y_max_anchors - y_min_anchors) * (x_max_anchors - x_min_anchors), [1]
)
# [n_gt_boxes, 1] or [batch_size, n_gt_boxes, 1]
gt_areas = (y_max_gt - y_min_gt) * (x_max_gt - x_min_gt)

# [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
max_y_min = tf.maximum(y_min_gt, tf.transpose(y_min_anchors))
min_y_max = tf.minimum(y_max_gt, tf.transpose(y_max_anchors))
intersect_heights = tf.maximum(
tf.constant(0, dtype=ground_truth_box.dtype), (min_y_max - max_y_min)
)

# [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
max_x_min = tf.maximum(x_min_gt, tf.transpose(x_min_anchors))
min_x_max = tf.minimum(x_max_gt, tf.transpose(x_max_anchors))
intersect_widths = tf.maximum(
tf.constant(0, dtype=ground_truth_box.dtype), (min_x_max - max_x_min)
)

# [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
intersections = intersect_heights * intersect_widths

# [n_gt_boxes, n_anchors] or [batch_size, n_gt_boxes, n_anchors]
unions = gt_areas + anchor_areas - intersections

return tf.cast(tf.truediv(intersections, unions), tf.float32)

if isinstance(ground_truth_boxes, tf.RaggedTensor):
if anchors.shape.ndims == 2:
return tf.map_fn(
lambda x: iou(x, anchors),
elems=ground_truth_boxes,
parallel_iterations=32,
back_prop=False,
fn_output_signature=tf.RaggedTensorSpec(
dtype=tf.float32, ragged_rank=0
),
)
else:
return tf.map_fn(
lambda x: iou(x[0], x[1]),
elems=[ground_truth_boxes, anchors],
parallel_iterations=32,
back_prop=False,
fn_output_signature=tf.RaggedTensorSpec(
dtype=tf.float32, ragged_rank=0
),
)
if anchors.shape.ndims == 2:
return iou(ground_truth_boxes, anchors)
elif anchors.shape.ndims == 3:
return tf.map_fn(
lambda x: iou(x[0], x[1]),
elems=[ground_truth_boxes, anchors],
dtype=tf.float32,
parallel_iterations=32,
back_prop=False,
)
178 changes: 178 additions & 0 deletions tests/kerascv/layers/iou_similarity_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
import numpy as np
import tensorflow as tf
from kerascv.layers.iou_similarity import IOUSimilarity


def test_iou_basic_absolute_coordinate():
# both gt box and two anchors are size 4
# the intersection between gt box and first anchor is 1 and union is 7
# the intersection between gt box and second anchor is 0 and union is 8
gt_boxes = tf.constant([[[0, 2, 2, 4]]])
anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
# batch_size = 1, n_gt_boxes = 1, n_anchors = 2
expected_out = np.asarray([1 / 7, 0]).astype(np.float32).reshape((1, 1, 2))
np.testing.assert_allclose(expected_out, similarity)


def test_iou_basic_normalized_coordinate():
# both gt box and two anchors are size 1
# the intersection between gt box and first anchor is 1 and union is 7
# the intersection between gt box and second anchor is 0 and union is 8
gt_boxes = tf.constant([[[0, 0.5, 0.5, 1.0]]])
anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0, 0.75]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = np.asarray([1 / 7, 0]).astype(np.float32).reshape((1, 1, 2))
np.testing.assert_allclose(expected_out, similarity)


def test_iou_multi_gt_multi_anchor_absolute_coordinate():
# batch_size = 1, n_gt_boxes = 2
# [1, 2, 4]
gt_boxes = tf.constant([[[0, 2, 2, 4], [-1, 1, 1, 3]]])
# [2, 4]
anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = (
np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((1, 2, 2))
)
np.testing.assert_allclose(expected_out, similarity)


def test_iou_batched_gt_multi_anchor_absolute_coordinate():
# batch_size = 2, n_gt_boxes = 1
# [2, 1, 4]
gt_boxes = tf.constant([[[0, 2, 2, 4]], [[-1, 1, 1, 3]]])
# [2, 4]
anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = (
np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((2, 1, 2))
)
np.testing.assert_allclose(expected_out, similarity)


def test_iou_batched_gt_batched_anchor_absolute_coordinate():
# batch_size = 2, n_gt_boxes = 1
# [2, 1, 4]
gt_boxes = tf.constant([[[0, 2, 2, 4]], [[-1, 1, 1, 3]]])
# [2, 1, 4]
anchors = tf.constant([[[1, 1, 3, 3]], [[-2, 1, 0, 3]]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = np.asarray([[1 / 7], [1 / 3]]).astype(np.float32).reshape((2, 1, 1))
np.testing.assert_allclose(expected_out, similarity)


def test_iou_multi_gt_multi_anchor_normalized_coordinate():
# batch_size = 1, n_gt_boxes = 2
# [1, 2, 4]
gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0], [-0.25, 0.25, 0.25, 0.75]]])
# [2, 4]
anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = (
np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((1, 2, 2))
)
np.testing.assert_allclose(expected_out, similarity)


def test_iou_batched_gt_multi_anchor_normalized_coordinate():
# batch_size = 2, n_gt_boxes = 1
# [2, 1, 4]
gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75]]])
# [2, 4]
anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = (
np.asarray([[1 / 7, 0], [0, 1 / 3]]).astype(np.float32).reshape((2, 1, 2))
)
np.testing.assert_allclose(expected_out, similarity)


def test_iou_batched_gt_batched_anchor_normalized_coordinate():
# batch_size = 2, n_gt_boxes = 1
# [2, 1, 4]
gt_boxes = tf.constant([[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75]]])
# [2, 1, 4]
anchors = tf.constant([[[0.25, 0.25, 0.75, 0.75]], [[-0.5, 0.25, 0.0, 0.75]]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = np.asarray([[1 / 7], [1 / 3]]).astype(np.float32).reshape((2, 1, 1))
np.testing.assert_allclose(expected_out, similarity)


def test_iou_large():
# [2, 4]
gt_boxes = tf.constant([[4.0, 3.0, 7.0, 5.0], [5.0, 6.0, 10.0, 7.0]])
# [3, 4]
anchors = tf.constant(
[[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0], [0.0, 0.0, 20.0, 20.0]]
)
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = np.asarray([[2 / 16, 0, 6 / 400], [1 / 16, 0.0, 5 / 400]]).astype(
np.float32
)
np.testing.assert_allclose(expected_out, similarity)


def test_ragged_gt_boxes_multi_anchor_absolute_coordinate():
# [2, ragged, 4]
gt_boxes = tf.ragged.constant(
[[[0, 2, 2, 4]], [[-1, 1, 1, 3], [-1, 1, 2, 3]]], ragged_rank=1
)
# [2, 4]
anchors = tf.constant([[1, 1, 3, 3], [-2, 1, 0, 3]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = tf.ragged.constant([[[1 / 7, 0.0]], [[0.0, 1 / 3], [1 / 4, 1 / 4]]])
np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())


def test_ragged_gt_boxes_multi_anchor_normalized_coordinate():
# [2, ragged, 4]
gt_boxes = tf.ragged.constant(
[[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
ragged_rank=1,
)
# [2, 4]
anchors = tf.constant([[0.25, 0.25, 0.75, 0.75], [-0.5, 0.25, 0.0, 0.75]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = tf.ragged.constant([[[1 / 7, 0.0]], [[0.0, 1 / 3], [1 / 4, 1 / 4]]])
np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())


def test_ragged_gt_boxes_batched_anchor_normalized_coordinate():
# [2, ragged, 4]
gt_boxes = tf.ragged.constant(
[[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
ragged_rank=1,
)
# [2, 1, 4]
anchors = tf.constant([[[0.25, 0.25, 0.75, 0.75]], [[-0.5, 0.25, 0.0, 0.75]]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = tf.ragged.constant([[[1 / 7]], [[1 / 3], [1 / 4]]])
np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())


def test_ragged_gt_boxes_empty_anchor():
# [2, ragged, 4]
gt_boxes = tf.ragged.constant(
[[[0.0, 0.5, 0.5, 1.0]], [[-0.25, 0.25, 0.25, 0.75], [-0.25, 0.25, 0.5, 0.75]]],
ragged_rank=1,
)
# [2, 4]
anchors = tf.constant([[0.25, 0.25, 0.25, 0.25], [-0.5, 0.25, 0.0, 0.75]])
iou_layer = IOUSimilarity()
similarity = iou_layer(gt_boxes, anchors)
expected_out = tf.ragged.constant([[[0.0, 0.0]], [[0.0, 1 / 3], [0.0, 1 / 4]]])
np.testing.assert_allclose(expected_out.values.numpy(), similarity.values.numpy())