diff --git a/tensorflow_addons/losses/BUILD b/tensorflow_addons/losses/BUILD index 0c58d96d59..ba62f8eada 100644 --- a/tensorflow_addons/losses/BUILD +++ b/tensorflow_addons/losses/BUILD @@ -10,6 +10,7 @@ py_library( "focal_loss.py", "lifted.py", "metric_learning.py", + "npairs.py", "sparsemax_loss.py", "triplet.py", ], @@ -46,6 +47,19 @@ py_test( ], ) +py_test( + name = "npairs_test", + size = "small", + srcs = [ + "npairs_test.py", + ], + main = "npairs_test.py", + srcs_version = "PY2AND3", + deps = [ + ":losses", + ], +) + py_test( name = "sparsemax_loss_test", size = "small", diff --git a/tensorflow_addons/losses/README.md b/tensorflow_addons/losses/README.md index 03037e2b4d..c6ba6d32e2 100644 --- a/tensorflow_addons/losses/README.md +++ b/tensorflow_addons/losses/README.md @@ -6,6 +6,7 @@ | contrastive | @WindQAQ | windqaq@gmail.com | | focal_loss | | | | lifted | | | +| npairs | @WindQAQ | windqaq@gmail.com | | sparsemax_loss | @AndreasMadsen | amwwebdk+github@gmail.com | | triplet | | | @@ -15,6 +16,7 @@ | contrastive | ContrastiveLoss | http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf | | focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 | | lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 | +| npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf | | sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 | | triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 | diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py index a552c69d67..ce94d7b91e 100644 --- a/tensorflow_addons/losses/__init__.py +++ b/tensorflow_addons/losses/__init__.py @@ -21,5 +21,6 @@ from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss +from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss diff --git a/tensorflow_addons/losses/npairs.py b/tensorflow_addons/losses/npairs.py new file mode 100644 index 0000000000..adba81566e --- /dev/null +++ b/tensorflow_addons/losses/npairs.py @@ -0,0 +1,95 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Implements npairs loss."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.utils import keras_utils
+
+
+@keras_utils.register_keras_custom_object
+@tf.function
+def npairs_loss(y_true, y_pred):
+    """Computes the npairs loss between `y_true` and `y_pred`.
+
+    Npairs loss expects paired data where a pair is composed of samples from
+    the same label, and each pair in the minibatch has a different label.
+    The loss takes each row of the pair-wise similarity matrix, `y_pred`,
+    as logits and the remapped multi-class labels, `y_true`, as labels.
+
+    The similarity matrix `y_pred` between two embedding matrices `a` and `b`
+    with shape `[batch_size, hidden_size]` can be computed as follows:
+
+    ```python
+    # y_pred = a * b^T
+    y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
+    ```
+
+    See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
+
+    Args:
+      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
+        multi-class labels.
+      y_pred: 2-D float `Tensor` with shape `[batch_size, batch_size]` of
+        similarity matrix between embedding matrices.
+
+    Returns:
+      npairs_loss: float scalar.
+    """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = tf.cast(y_true, y_pred.dtype)
+
+    # Expand [batch_size] label tensor to a [batch_size, 1] label tensor.
+    y_true = tf.expand_dims(y_true, -1)
+    y_true = tf.cast(tf.equal(y_true, tf.transpose(y_true)), y_pred.dtype)
+    y_true /= tf.math.reduce_sum(y_true, 1, keepdims=True)
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        logits=y_pred, labels=y_true)
+
+    return tf.math.reduce_mean(loss)
+
+
+@keras_utils.register_keras_custom_object
+class NpairsLoss(tf.keras.losses.Loss):
+    """Computes the npairs loss between `y_true` and `y_pred`.
+
+    Npairs loss expects paired data where a pair is composed of samples from
+    the same label, and each pair in the minibatch has a different label.
+    The loss takes each row of the pair-wise similarity matrix, `y_pred`,
+    as logits and the remapped multi-class labels, `y_true`, as labels.
+
+    The similarity matrix `y_pred` between two embedding matrices `a` and `b`
+    with shape `[batch_size, hidden_size]` can be computed as follows:
+
+    ```python
+    # y_pred = a * b^T
+    y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
+    ```
+
+    See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
+
+    Args:
+      name: (Optional) name for the loss.
+    """
+
+    def __init__(self, name="npairs_loss"):
+        super(NpairsLoss, self).__init__(
+            reduction=tf.keras.losses.Reduction.NONE, name=name)
+
+    def call(self, y_true, y_pred):
+        return npairs_loss(y_true, y_pred)
diff --git a/tensorflow_addons/losses/npairs_test.py b/tensorflow_addons/losses/npairs_test.py
new file mode 100644
index 0000000000..0f0ecc12b3
--- /dev/null
+++ b/tensorflow_addons/losses/npairs_test.py
@@ -0,0 +1,58 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for npairs loss."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow_addons.losses import npairs
+from tensorflow_addons.utils import test_utils
+
+
+@test_utils.run_all_in_graph_and_eager_modes
+class NpairsLossTest(tf.test.TestCase):
+    def test_config(self):
+        nl_obj = npairs.NpairsLoss(name="nl")
+        self.assertEqual(nl_obj.name, "nl")
+        self.assertEqual(nl_obj.reduction, tf.keras.losses.Reduction.NONE)
+
+    def test_unweighted(self):
+        nl_obj = npairs.NpairsLoss()
+        # batch size = 4, hidden size = 2
+        y_true = tf.constant([0, 1, 2, 3], dtype=tf.int64)
+        # features of anchors
+        f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                        dtype=tf.float32)
+        # features of positive samples
+        fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                         dtype=tf.float32)
+        # similarity matrix
+        y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
+        loss = nl_obj(y_true, y_pred)
+
+        # Loss = 1/4 * \sum_i log(1 + \sum_{j != i} exp(f_i*fp_j^T - f_i*fp_i^T))
+        # Compute the loss for i = 0, 1, 2, 3 without the 1/4 multiplier:
+        # i = 0 => log(1 + sum([exp(-2), exp(-2), exp(-4)])) = 0.253856
+        # i = 1 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
+        # i = 2 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
+        # i = 3 => log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 0.253856
+        # Loss = (0.253856 + 0.253856 + 0.253856 + 0.253856) / 4 = 0.253856
+
+        self.assertAllClose(loss, 0.253856)
+
+
+if __name__ == "__main__":
+    tf.test.main()
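
For reviewers, a minimal usage sketch of the new API (not part of the diff itself). The toy embeddings, batch layout, and dimensions below are illustrative placeholders; only `tfa.losses.NpairsLoss` and `tfa.losses.npairs_loss` come from this change:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# One anchor/positive pair per class; labels identify the four pairs.
y_true = tf.constant([0, 1, 2, 3], dtype=tf.int64)

# Placeholder embeddings standing in for the outputs of two towers
# of an embedding network, each [batch_size, hidden_size].
anchors = tf.random.normal([4, 2])
positives = tf.random.normal([4, 2])

# The loss consumes the pairwise similarity matrix, not raw embeddings.
y_pred = tf.matmul(anchors, positives, transpose_b=True)

# Keras-style class form.
loss = tfa.losses.NpairsLoss()(y_true, y_pred)

# Equivalent functional form.
loss_value = tfa.losses.npairs_loss(y_true, y_pred)
```

Both forms return a float scalar; `NpairsLoss` is constructed with `Reduction.NONE` because `npairs_loss` already averages over the batch via `tf.math.reduce_mean`.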