Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions tensorflow_addons/losses/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ py_library(
"focal_loss.py",
"lifted.py",
"metric_learning.py",
"npairs.py",
"sparsemax_loss.py",
"triplet.py",
],
Expand Down Expand Up @@ -46,6 +47,19 @@ py_test(
],
)

py_test(
name = "npairs_test",
size = "small",
srcs = [
"npairs_test.py",
],
main = "npairs_test.py",
srcs_version = "PY2AND3",
deps = [
":losses",
],
)

py_test(
name = "sparsemax_loss_test",
size = "small",
Expand Down
2 changes: 2 additions & 0 deletions tensorflow_addons/losses/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
| contrastive | @WindQAQ | [email protected] |
| focal_loss | | |
| lifted | | |
| npairs | @WindQAQ | [email protected] |
| sparsemax_loss | @AndreasMadsen | [email protected] |
| triplet | | |

Expand All @@ -15,6 +16,7 @@
| contrastive | ContrastiveLoss | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf |
| focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 |
| lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
| npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
| sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
| triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |

Expand Down
1 change: 1 addition & 0 deletions tensorflow_addons/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@
from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss
from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
95 changes: 95 additions & 0 deletions tensorflow_addons/losses/npairs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implements npairs loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.utils import keras_utils


@keras_utils.register_keras_custom_object
@tf.function
def npairs_loss(y_true, y_pred):
"""Computes the npairs loss between `y_true` and `y_pred`.

Npairs loss expects paired data where a pair is composed of samples from
the same labels and each pairs in the minibatch have different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.

The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Args:
y_true: 1-D integer `Tensor` with shape `[batch_size]` of
multi-class labels.
y_pred: 2-D float `Tensor` with shape `[batch_size, batch_size]` of
similarity matrix between embedding matrices.

Returns:
npairs_loss: float scalar.
"""
y_pred = tf.convert_to_tensor(y_pred)
y_true = tf.cast(y_true, y_pred.dtype)

# Expand to [batch_size, 1]
y_true = tf.expand_dims(y_true, -1)
y_true = tf.cast(tf.equal(y_true, tf.transpose(y_true)), y_pred.dtype)
y_true /= tf.math.reduce_sum(y_true, 1, keepdims=True)

loss = tf.nn.softmax_cross_entropy_with_logits(
logits=y_pred, labels=y_true)

return tf.math.reduce_mean(loss)


@keras_utils.register_keras_custom_object
class NpairsLoss(tf.keras.losses.Loss):
"""Computes the npairs loss between `y_true` and `y_pred`.

Npairs loss expects paired data where a pair is composed of samples from
the same labels and each pairs in the minibatch have different labels.
The loss takes each row of the pair-wise similarity matrix, `y_pred`,
as logits and the remapped multi-class labels, `y_true`, as labels.

The similarity matrix `y_pred` between two embedding matrices `a` and `b`
with shape `[batch_size, hidden_size]` can be computed as follows:

```python
# y_pred = a * b^T
y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
```

See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf

Args:
name: (Optional) name for the loss.
"""

def __init__(self, name="npairs_loss"):
super(NpairsLoss, self).__init__(
reduction=tf.keras.losses.Reduction.NONE, name=name)

def call(self, y_true, y_pred):
return npairs_loss(y_true, y_pred)
58 changes: 58 additions & 0 deletions tensorflow_addons/losses/npairs_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for npairs loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow_addons.losses import npairs
from tensorflow_addons.utils import test_utils


@test_utils.run_all_in_graph_and_eager_modes
class NpairsLossTest(tf.test.TestCase):
def test_config(self):
nl_obj = npairs.NpairsLoss(name="nl")
self.assertEqual(nl_obj.name, "nl")
self.assertEqual(nl_obj.reduction, tf.keras.losses.Reduction.NONE)

def test_unweighted(self):
nl_obj = npairs.NpairsLoss()
# batch size = 4, hidden size = 2
y_true = tf.constant([0, 1, 2, 3], dtype=tf.int64)
# features of anchors
f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# features of positive samples
fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
dtype=tf.float32)
# similarity matrix
y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
loss = nl_obj(y_true, y_pred)

# Loss = 1/4 * \sum_i log(1 + \sum_{j != i} exp(f_i*fp_j^T-f_i*f_i^T))
# Compute loss for i = 0, 1, 2, 3 without multiplier 1/4
# i = 0 => log(1 + sum([exp(-2), exp(-2), exp(-4)])) = 0.253846
# i = 1 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253846
# i = 2 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253846
# i = 3 => log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 0.253846
# Loss = (0.253856 + 0.253856 + 0.253856 + 0.253856) / 4 = 0.253856

self.assertAllClose(loss, 0.253856)


if __name__ == "__main__":
tf.test.main()