1 change: 1 addition & 0 deletions tensorflow_addons/losses/README.md
@@ -17,6 +17,7 @@
| focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 |
| lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
| npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
| npairs | NpairsMultilabelLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
| sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
| triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |

2 changes: 1 addition & 1 deletion tensorflow_addons/losses/__init__.py
@@ -21,6 +21,6 @@
from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss
from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss
from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss, npairs_multilabel_loss, NpairsMultilabelLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
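As a quick sanity check that the re-export works, a minimal sketch (assuming a standard `tensorflow_addons` install; the `tfa` alias is just convention):

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Both the class and functional forms are now exported at the package level.
loss_obj = tfa.losses.NpairsMultilabelLoss()

# Keras-level reduction is disabled because the loss already
# averages over the batch internally.
assert loss_obj.reduction == tf.keras.losses.Reduction.NONE
```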
110 changes: 110 additions & 0 deletions tensorflow_addons/losses/npairs.py
@@ -64,6 +64,71 @@ def npairs_loss(y_true, y_pred):
return tf.math.reduce_mean(loss)


@keras_utils.register_keras_custom_object
@tf.function
def npairs_multilabel_loss(y_true, y_pred):
"""Computes the npairs loss between multilabel data `y_true` and `y_pred`.

    Npairs loss expects paired data where a pair is composed of samples from
    the same labels and each pair in the minibatch has different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.

To deal with multilabel inputs, the count of label intersection
is computed as follows:

```
L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) |
```

    Each row of the count-based label matrix is further normalized so that
    each row sums to one.

    `y_true` should be a binary indicator for classes.
    That is, if `y_true[i, j] = 1`, then the `i`th sample is in the `j`th class;
    if `y_true[i, j] = 0`, then the `i`th sample is not in the `j`th class.

The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Args:
y_true: Either 2-D integer `Tensor` with shape
`[batch_size, num_classes]`, or `SparseTensor` with dense shape
`[batch_size, num_classes]`. If `y_true` is a `SparseTensor`, then
it will be converted to `Tensor` via `tf.sparse.to_dense` first.

y_pred: 2-D float `Tensor` with shape `[batch_size, batch_size]` of
similarity matrix between embedding matrices.

Returns:
npairs_multilabel_loss: float scalar.
"""
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

# Convert to dense tensor if `y_true` is a `SparseTensor`
if isinstance(y_true, tf.SparseTensor):
y_true = tf.sparse.to_dense(y_true)

# Enable efficient multiplication because y_true contains lots of zeros
# https://www.tensorflow.org/api_docs/python/tf/linalg/matmul
y_true = tf.linalg.matmul(
y_true, y_true, transpose_b=True, a_is_sparse=True, b_is_sparse=True)
y_true /= tf.math.reduce_sum(y_true, 1, keepdims=True)

loss = tf.nn.softmax_cross_entropy_with_logits(
logits=y_pred, labels=y_true)

return tf.math.reduce_mean(loss)


@keras_utils.register_keras_custom_object
class NpairsLoss(tf.keras.losses.Loss):
"""Computes the npairs loss between `y_true` and `y_pred`.
@@ -93,3 +158,48 @@ def __init__(self, name="npairs_loss"):

def call(self, y_true, y_pred):
return npairs_loss(y_true, y_pred)


@keras_utils.register_keras_custom_object
class NpairsMultilabelLoss(tf.keras.losses.Loss):
"""Computes the npairs loss between multilabel data `y_true` and `y_pred`.

    Npairs loss expects paired data where a pair is composed of samples from
    the same labels and each pair in the minibatch has different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.

To deal with multilabel inputs, the count of label intersection
is computed as follows:

```
L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) |
```

    Each row of the count-based label matrix is further normalized so that
    each row sums to one.

    `y_true` should be a binary indicator for classes.
    That is, if `y_true[i, j] = 1`, then the `i`th sample is in the `j`th class;
    if `y_true[i, j] = 0`, then the `i`th sample is not in the `j`th class.

The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Args:
name: (Optional) name for the loss.
"""

def __init__(self, name="npairs_multilabel_loss"):
super(NpairsMultilabelLoss, self).__init__(
reduction=tf.keras.losses.Reduction.NONE, name=name)

def call(self, y_true, y_pred):
return npairs_multilabel_loss(y_true, y_pred)
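To make the expected data flow concrete, here is a hedged end-to-end sketch: anchor and positive embeddings become a similarity matrix, and the multi-hot indicator matrix is passed as `y_true`. The last two lines mirror the label normalization that `npairs_multilabel_loss` performs internally; the names `anchors` and `positives` are illustrative, not part of the API.

```python
import tensorflow as tf
from tensorflow_addons.losses import npairs_multilabel_loss

batch_size, hidden_size = 4, 2

# Embeddings for the anchors and their positive counterparts.
anchors = tf.random.normal([batch_size, hidden_size])
positives = tf.random.normal([batch_size, hidden_size])

# Multi-hot indicators: y_true[i, j] == 1 iff sample i carries class j.
y_true = tf.constant(
    [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
    dtype=tf.float32)

# Pairwise similarity matrix used as logits: y_pred = a * b^T.
y_pred = tf.matmul(anchors, positives, transpose_b=True)

loss = npairs_multilabel_loss(y_true, y_pred)  # float scalar

# For reference, the soft labels built internally: intersection
# counts L_{i,j} = y_true @ y_true^T, then row normalization.
label_matrix = tf.matmul(y_true, y_true, transpose_b=True)
soft_labels = label_matrix / tf.reduce_sum(label_matrix, 1, keepdims=True)
```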
79 changes: 79 additions & 0 deletions tensorflow_addons/losses/npairs_test.py
@@ -54,5 +54,84 @@ def test_unweighted(self):
self.assertAllClose(loss, 0.253856)


@test_utils.run_all_in_graph_and_eager_modes
class NpairsMultilabelLossTest(tf.test.TestCase):
    def test_config(self):
nml_obj = npairs.NpairsMultilabelLoss(name="nml")
self.assertEqual(nml_obj.name, "nml")
self.assertEqual(nml_obj.reduction, tf.keras.losses.Reduction.NONE)

def test_single_label(self):
"""Test single label, which is the same with `NpairsLoss`."""
nml_obj = npairs.NpairsMultilabelLoss()
# batch size = 4, hidden size = 2
y_true = tf.constant(
[[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
dtype=tf.int64)
# features of anchors
f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# features of positive samples
fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# similarity matrix
y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
loss = nml_obj(y_true, y_pred)

        # Loss = 1/4 * \sum_i log(1 + \sum_{j != i} exp(f_i*fp_j^T - f_i*fp_i^T))
        # Compute loss for i = 0, 1, 2, 3 without multiplier 1/4
        # i = 0 => log(1 + sum([exp(-2), exp(-2), exp(-4)])) = 0.253856
        # i = 1 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
        # i = 2 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
        # i = 3 => log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 0.253856
        # Loss = (0.253856 + 0.253856 + 0.253856 + 0.253856) / 4 = 0.253856

self.assertAllClose(loss, 0.253856)

# Test sparse tensor
y_true = tf.sparse.from_dense(y_true)
loss = nml_obj(y_true, y_pred)
self.assertAllClose(loss, 0.253856)

def test_multilabel(self):
nml_obj = npairs.NpairsMultilabelLoss()
# batch size = 4, hidden size = 2
y_true = tf.constant(
[[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
dtype=tf.int64)
# features of anchors
f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# features of positive samples
fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# similarity matrix
y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
loss = nml_obj(y_true, y_pred)

        # Loss = 1/4 * \sum_i \sum_j L_{i,j} *
        #        log(1 + \sum_{k != j} exp(f_i*fp_k^T - f_i*fp_j^T))
        # Because of multilabel, the label matrix is normalized so that each
        # row sums to one; the normalized entries L_{i,j} are the multipliers
        # before each log below.
        # Compute loss for i = 0, 1, 2, 3 without multiplier 1/4
        # i = 0 => 2/3 * log(1 + sum([exp(-2), exp(-2), exp(-4)])) +
        #          1/3 * log(1 + sum([exp(2) , exp(0) , exp(-2)])) = 0.920522
        # i = 1 => 1/4 * log(1 + sum([exp(2) , exp(-2), exp(0) ])) +
        #          1/2 * log(1 + sum([exp(-2), exp(-4), exp(-2)])) +
        #          1/4 * log(1 + sum([exp(2) , exp(4) , exp(2) ])) = 1.753856
        # i = 2 => 1/4 * log(1 + sum([exp(2) , exp(4) , exp(2) ])) +
        #          1/2 * log(1 + sum([exp(-2), exp(-4), exp(-2)])) +
        #          1/4 * log(1 + sum([exp(0) , exp(-2), exp(2) ])) = 1.753856
        # i = 3 => 1/2 * log(1 + sum([exp(-2), exp(0) , exp(2) ])) +
        #          1/2 * log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 1.253856
        # Loss = (0.920522 + 1.753856 + 1.753856 + 1.253856) / 4 = 1.420522

self.assertAllClose(loss, 1.420522)

# Test sparse tensor
y_true = tf.sparse.from_dense(y_true)
loss = nml_obj(y_true, y_pred)
self.assertAllClose(loss, 1.420522)


if __name__ == "__main__":
tf.test.main()
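The expected values above can be cross-checked outside TensorFlow. A NumPy sketch of the same computation for the multilabel case (illustrative only; it reimplements the row normalization and softmax cross entropy by hand):

```python
import numpy as np

f = np.array([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]])  # f == fp here
y_true = np.array(
    [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]], dtype=float)

logits = f @ f.T                              # similarity matrix
labels = y_true @ y_true.T                    # label intersection counts
labels /= labels.sum(axis=1, keepdims=True)   # row-normalize

# Softmax cross entropy per row, averaged over the batch.
log_softmax = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
loss = -(labels * log_softmax).sum(axis=1).mean()
assert np.isclose(loss, 1.420522)  # matches test_multilabel
```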