From 3bacab35f499e37c6b6eaa3379eebb3577a86736 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Fri, 30 Aug 2019 20:59:28 +0800
Subject: [PATCH 1/2] migrate npairs multilabel loss

---
 tensorflow_addons/losses/README.md      |   1 +
 tensorflow_addons/losses/__init__.py    |   2 +-
 tensorflow_addons/losses/npairs.py      | 110 ++++++++++++++++++++++++
 tensorflow_addons/losses/npairs_test.py |  79 +++++++++++++++++
 4 files changed, 191 insertions(+), 1 deletion(-)

diff --git a/tensorflow_addons/losses/README.md b/tensorflow_addons/losses/README.md
index e0951d41c1..c1ed11c35d 100644
--- a/tensorflow_addons/losses/README.md
+++ b/tensorflow_addons/losses/README.md
@@ -17,6 +17,7 @@
 | focal_loss | SigmoidFocalCrossEntropy | https://arxiv.org/abs/1708.02002 |
 | lifted | LiftedStructLoss | https://arxiv.org/abs/1511.06452 |
 | npairs | NpairsLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
+| npairs | NpairsMultilabelLoss | http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf |
 | sparsemax_loss | SparsemaxLoss | https://arxiv.org/abs/1602.02068 |
 | triplet | TripletSemiHardLoss | https://arxiv.org/abs/1503.03832 |
 
diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py
index ce94d7b91e..ff8e5094fa 100644
--- a/tensorflow_addons/losses/__init__.py
+++ b/tensorflow_addons/losses/__init__.py
@@ -21,6 +21,6 @@
 from tensorflow_addons.losses.contrastive import contrastive_loss, ContrastiveLoss
 from tensorflow_addons.losses.focal_loss import sigmoid_focal_crossentropy, SigmoidFocalCrossEntropy
 from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
-from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss
+from tensorflow_addons.losses.npairs import npairs_loss, NpairsLoss, npairs_multilabel_loss, NpairsMultilabelLoss
 from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
 from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
diff --git a/tensorflow_addons/losses/npairs.py b/tensorflow_addons/losses/npairs.py
index adba81566e..35c534eb55 100644
--- a/tensorflow_addons/losses/npairs.py
+++ b/tensorflow_addons/losses/npairs.py
@@ -64,6 +64,71 @@ def npairs_loss(y_true, y_pred):
     return tf.math.reduce_mean(loss)
 
 
+@keras_utils.register_keras_custom_object
+@tf.function
+def npairs_multilabel_loss(y_true, y_pred):
+    """Computes the npairs loss between multilabel data `y_true` and `y_pred`.
+
+    Npairs loss expects paired data where a pair is composed of samples from
+    the same labels and each pair in the minibatch has different labels.
+    The loss takes each row of the pair-wise similarity matrix, `y_pred`,
+    as logits and the remapped multi-class labels, `y_true`, as labels.
+
+    To deal with multilabel inputs, the size of the label intersection
+    is computed as follows:
+
+    ```
+    L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) |
+    ```
+
+    Each row of the count-based label matrix is further normalized so that
+    it sums to one.
+
+    `y_true` should be a binary indicator for classes.
+    That is, if `y_true[i, j] = 1`, then the `i`th sample is in the `j`th
+    class; if `y_true[i, j] = 0`, then the `i`th sample is not in the
+    `j`th class.
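+
+    For example, a sketch of the remapping with a hypothetical 2-sample,
+    3-class batch (the numbers follow from the definitions above):
+
+    ```python
+    y_true = [[1, 1, 0],
+              [0, 1, 1]]
+    # intersection counts y_true * y_true^T: [[2, 1],
+    #                                         [1, 2]]
+    # row-normalized labels:                 [[2/3, 1/3],
+    #                                         [1/3, 2/3]]
+    ```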
+
+    The similarity matrix `y_pred` between two embedding matrices `a` and `b`
+    with shape `[batch_size, hidden_size]` can be computed as follows:
+
+    ```python
+    # y_pred = a * b^T
+    y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
+    ```
+
+    See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
+
+    Args:
+      y_true: Either 2-D integer `Tensor` with shape
+        `[batch_size, num_classes]`, or `SparseTensor` with dense shape
+        `[batch_size, num_classes]`. If `y_true` is a `SparseTensor`, then
+        it will be converted to `Tensor` via `tf.sparse.to_dense` first.
+
+      y_pred: 2-D float `Tensor` with shape `[batch_size, batch_size]` of
+        similarity matrix between embedding matrices.
+
+    Returns:
+      npairs_mutlilabel_loss: float scalar.
+    """
+    y_pred = tf.convert_to_tensor(y_pred)
+    y_true = tf.cast(y_true, y_pred.dtype)
+
+    # Convert to dense tensor if `y_true` is a `SparseTensor`
+    if isinstance(y_true, tf.SparseTensor):
+        y_true = tf.sparse.to_dense(y_true)
+
+    # Enable efficient multiplication because `y_true` contains many zeros
+    # https://www.tensorflow.org/api_docs/python/tf/linalg/matmul
+    y_true = tf.linalg.matmul(
+        y_true, y_true, transpose_b=True, a_is_sparse=True, b_is_sparse=True)
+    y_true /= tf.math.reduce_sum(y_true, 1, keepdims=True)
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(
+        logits=y_pred, labels=y_true)
+
+    return tf.math.reduce_mean(loss)
+
+
 @keras_utils.register_keras_custom_object
 class NpairsLoss(tf.keras.losses.Loss):
     """Computes the npairs loss between `y_true` and `y_pred`.
@@ -93,3 +158,48 @@ def __init__(self, name="npairs_loss"):
 
     def call(self, y_true, y_pred):
         return npairs_loss(y_true, y_pred)
+
+
+@keras_utils.register_keras_custom_object
+class NpairsMultilabelLoss(tf.keras.losses.Loss):
+    """Computes the npairs loss between multilabel data `y_true` and `y_pred`.
+
+    Npairs loss expects paired data where a pair is composed of samples from
+    the same labels and each pair in the minibatch has different labels.
+    The loss takes each row of the pair-wise similarity matrix, `y_pred`,
+    as logits and the remapped multi-class labels, `y_true`, as labels.
+
+    To deal with multilabel inputs, the size of the label intersection
+    is computed as follows:
+
+    ```
+    L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) |
+    ```
+
+    Each row of the count-based label matrix is further normalized so that
+    it sums to one.
+
+    `y_true` should be a binary indicator for classes.
+    That is, if `y_true[i, j] = 1`, then the `i`th sample is in the `j`th
+    class; if `y_true[i, j] = 0`, then the `i`th sample is not in the
+    `j`th class.
+
+    The similarity matrix `y_pred` between two embedding matrices `a` and `b`
+    with shape `[batch_size, hidden_size]` can be computed as follows:
+
+    ```python
+    # y_pred = a * b^T
+    y_pred = tf.matmul(a, b, transpose_a=False, transpose_b=True)
+    ```
+
+    See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf
+
+    Args:
+      name: (Optional) name for the loss.
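+
+    Usage (a minimal sketch; the inputs mirror `test_multilabel` in
+    `npairs_test.py`):
+
+    ```python
+    loss_fn = NpairsMultilabelLoss()
+    y_true = tf.constant([[1, 1, 0, 0], [0, 1, 1, 0],
+                          [0, 0, 1, 1], [0, 0, 0, 1]], dtype=tf.int64)
+    f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]])
+    y_pred = tf.matmul(f, f, transpose_b=True)  # pairwise similarities
+    loss = loss_fn(y_true, y_pred)  # scalar, ~1.420522 for these inputs
+    ```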
+    """
+
+    def __init__(self, name="npairs_multilabel_loss"):
+        super(NpairsMultilabelLoss, self).__init__(
+            reduction=tf.keras.losses.Reduction.NONE, name=name)
+
+    def call(self, y_true, y_pred):
+        return npairs_multilabel_loss(y_true, y_pred)
diff --git a/tensorflow_addons/losses/npairs_test.py b/tensorflow_addons/losses/npairs_test.py
index 0f0ecc12b3..043c7d983d 100644
--- a/tensorflow_addons/losses/npairs_test.py
+++ b/tensorflow_addons/losses/npairs_test.py
@@ -54,5 +54,84 @@ def test_unweighted(self):
         self.assertAllClose(loss, 0.253856)
 
 
+@test_utils.run_all_in_graph_and_eager_modes
+class NpairsMultilabelLossTest(tf.test.TestCase):
+    def test_config(self):
+        nml_obj = npairs.NpairsMultilabelLoss(name="nml")
+        self.assertEqual(nml_obj.name, "nml")
+        self.assertEqual(nml_obj.reduction, tf.keras.losses.Reduction.NONE)
+
+    def test_single_label(self):
+        """Test single-label input, which gives the same result as `NpairsLoss`."""
+        nml_obj = npairs.NpairsMultilabelLoss()
+        # batch size = 4, hidden size = 2
+        y_true = tf.constant(
+            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
+            dtype=tf.int64)
+        # features of anchors
+        f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                        dtype=tf.float32)
+        # features of positive samples
+        fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                         dtype=tf.float32)
+        # similarity matrix
+        y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
+        loss = nml_obj(y_true, y_pred)
+
+        # Loss = 1/4 * \sum_i log(1 + \sum_{j != i} exp(f_i*fp_j^T - f_i*fp_i^T))
+        # Compute loss for i = 0, 1, 2, 3 without multiplier 1/4
+        # i = 0 => log(1 + sum([exp(-2), exp(-2), exp(-4)])) = 0.253856
+        # i = 1 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
+        # i = 2 => log(1 + sum([exp(-2), exp(-4), exp(-2)])) = 0.253856
+        # i = 3 => log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 0.253856
+        # Loss = (0.253856 + 0.253856 + 0.253856 + 0.253856) / 4 = 0.253856
+
+        self.assertAllClose(loss, 0.253856)
+
+        # Test sparse tensor
+        y_true = tf.sparse.from_dense(y_true)
+        loss = nml_obj(y_true, y_pred)
+        self.assertAllClose(loss, 0.253856)
+
+    def test_multilabel(self):
+        nml_obj = npairs.NpairsMultilabelLoss()
+        # batch size = 4, hidden size = 2
+        y_true = tf.constant(
+            [[1, 1, 0, 0], [0, 1, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]],
+            dtype=tf.int64)
+        # features of anchors
+        f = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                        dtype=tf.float32)
+        # features of positive samples
+        fp = tf.constant([[1., 1.], [1., -1.], [-1., 1.], [-1., -1.]],
+                         dtype=tf.float32)
+        # similarity matrix
+        y_pred = tf.matmul(f, fp, transpose_a=False, transpose_b=True)
+        loss = nml_obj(y_true, y_pred)
+
+        # Loss = 1/4 * \sum_i \sum_k w_ik *
+        #        log(1 + \sum_{j != k} exp(f_i*fp_j^T - f_i*fp_k^T))
+        # Because of the multilabel input, the label matrix is normalized so
+        # that each row sums to one; the normalized weights w_ik are the
+        # fractional multipliers before each log term below.
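+        # For the `y_true` above, the row-normalized label matrix
+        # y_true * y_true^T (each row divided by its sum) is:
+        #   [[2/3, 1/3, 0,   0  ],
+        #    [1/4, 1/2, 1/4, 0  ],
+        #    [0,   1/4, 1/2, 1/4],
+        #    [0,   0,   1/2, 1/2]]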
+        # Compute loss for i = 0, 1, 2, 3 without multiplier 1/4
+        # i = 0 => 2/3 * log(1 + sum([exp(-2), exp(-2), exp(-4)])) +
+        #          1/3 * log(1 + sum([exp(2) , exp(0) , exp(-2)])) = 0.920522
+        # i = 1 => 1/4 * log(1 + sum([exp(2) , exp(-2), exp(0) ])) +
+        #          1/2 * log(1 + sum([exp(-2), exp(-4), exp(-2)])) +
+        #          1/4 * log(1 + sum([exp(2) , exp(4) , exp(2) ])) = 1.753856
+        # i = 2 => 1/4 * log(1 + sum([exp(2) , exp(4) , exp(2) ])) +
+        #          1/2 * log(1 + sum([exp(-2), exp(-4), exp(-2)])) +
+        #          1/4 * log(1 + sum([exp(0) , exp(-2), exp(2) ])) = 1.753856
+        # i = 3 => 1/2 * log(1 + sum([exp(-2), exp(0) , exp(2) ])) +
+        #          1/2 * log(1 + sum([exp(-4), exp(-2), exp(-2)])) = 1.253856
+        # Loss = (0.920522 + 1.753856 + 1.753856 + 1.253856) / 4 = 1.420522
+
+        self.assertAllClose(loss, 1.420522)
+
+        # Test sparse tensor
+        y_true = tf.sparse.from_dense(y_true)
+        loss = nml_obj(y_true, y_pred)
+        self.assertAllClose(loss, 1.420522)
+
+
 if __name__ == "__main__":
     tf.test.main()

From 31af3c5011466439f32f0be8d7cfa636e85ff7f9 Mon Sep 17 00:00:00 2001
From: Tzu-Wei Sung
Date: Fri, 30 Aug 2019 21:22:11 +0800
Subject: [PATCH 2/2] fix typo

---
 tensorflow_addons/losses/npairs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/losses/npairs.py b/tensorflow_addons/losses/npairs.py
index 35c534eb55..319fba0fc9 100644
--- a/tensorflow_addons/losses/npairs.py
+++ b/tensorflow_addons/losses/npairs.py
@@ -108,7 +108,7 @@ def npairs_multilabel_loss(y_true, y_pred):
         similarity matrix between embedding matrices.
 
     Returns:
-      npairs_mutlilabel_loss: float scalar.
+      npairs_multilabel_loss: float scalar.
     """
     y_pred = tf.convert_to_tensor(y_pred)
     y_true = tf.cast(y_true, y_pred.dtype)