diff --git a/tensorflow_addons/losses/__init__.py b/tensorflow_addons/losses/__init__.py index fba1f3138a..c238285d00 100644 --- a/tensorflow_addons/losses/__init__.py +++ b/tensorflow_addons/losses/__init__.py @@ -19,7 +19,12 @@ from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss -from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss +from tensorflow_addons.losses.triplet import ( + triplet_semihard_loss, + triplet_hard_loss, + TripletSemiHardLoss, + TripletHardLoss, +) from tensorflow_addons.losses.quantiles import pinball_loss, PinballLoss # Temporarily disable for windows diff --git a/tensorflow_addons/losses/triplet.py b/tensorflow_addons/losses/triplet.py index 9ef4c704b2..0e375b08c0 100644 --- a/tensorflow_addons/losses/triplet.py +++ b/tensorflow_addons/losses/triplet.py @@ -16,6 +16,7 @@ import tensorflow as tf from tensorflow_addons.losses import metric_learning +from typeguard import typechecked def _masked_maximum(data, mask, dim=1): @@ -142,6 +143,57 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0): return triplet_loss +@tf.keras.utils.register_keras_serializable(package="Addons") +@tf.function +def triplet_hard_loss(y_true, y_pred, margin=1.0, soft=False): + """Computes the triplet loss with hard negative and hard positive mining. + + Args: + y_true: 1-D integer `Tensor` with shape [batch_size] of + multiclass integer labels. + y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should + be l2 normalized. + margin: Float, margin term in the loss definition. + soft: Boolean, if set, use the soft margin version. + """ + labels, embeddings = y_true, y_pred + # Reshape label tensor to [batch_size, 1]. + lshape = tf.shape(labels) + labels = tf.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pdist_matrix = metric_learning.pairwise_distance(embeddings, squared=True) + # Build pairwise binary adjacency matrix. + adjacency = tf.math.equal(labels, tf.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = tf.math.logical_not(adjacency) + + adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32) + # hard negatives: smallest D_an. + hard_negatives = _masked_minimum(pdist_matrix, adjacency_not) + + batch_size = tf.size(labels) + + adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32) + + mask_positives = tf.cast(adjacency, dtype=tf.dtypes.float32) - tf.linalg.diag( + tf.ones([batch_size]) + ) + + # hard positives: largest D_ap. + hard_positives = _masked_maximum(pdist_matrix, mask_positives) + + if soft: + triplet_loss = tf.math.log1p(tf.math.exp(hard_positives - hard_negatives)) + else: + triplet_loss = tf.maximum(hard_positives - hard_negatives + margin, 0.0) + + # Get final mean triplet loss + triplet_loss = tf.reduce_mean(triplet_loss) + + return triplet_loss + + @tf.keras.utils.register_keras_serializable(package="Addons") class TripletSemiHardLoss(tf.keras.losses.Loss): """Computes the triplet loss with semi-hard negative mining. @@ -175,3 +227,44 @@ def get_config(self): } base_config = super().get_config() return {**base_config, **config} + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class TripletHardLoss(tf.keras.losses.Loss): + """Computes the triplet loss with hard negative and hard positive mining. + + The loss encourages the maximum positive distance (between a pair of embeddings + with the same labels) to be smaller than the minimum negative distance plus the + margin constant in the mini-batch. + The loss selects the hardest positive and the hardest negative samples + within the batch when forming the triplets for computing the loss. + See: https://arxiv.org/pdf/1703.07737. + + We expect labels `y_true` to be provided as 1-D integer `Tensor` with shape + [batch_size] of multi-class integer labels. And embeddings `y_pred` must be + 2-D float `Tensor` of l2 normalized embedding vectors. + + Args: + margin: Float, margin term in the loss definition. Default value is 1.0. + soft: Boolean, if set, use the soft margin version. Default value is False. + name: Optional name for the op. + """ + + @typechecked + def __init__( + self, margin: float = 1.0, soft: bool = False, name: str = None, **kwargs + ): + super().__init__(name=name, reduction=tf.keras.losses.Reduction.NONE) + self.margin = margin + self.soft = soft + + def call(self, y_true, y_pred): + return triplet_hard_loss(y_true, y_pred, self.margin, self.soft) + + def get_config(self): + config = { + "margin": self.margin, + "soft": self.soft, + } + base_config = super().get_config() + return {**base_config, **config} diff --git a/tensorflow_addons/losses/triplet_test.py b/tensorflow_addons/losses/triplet_test.py index 6760397a68..d6cab02c03 100644 --- a/tensorflow_addons/losses/triplet_test.py +++ b/tensorflow_addons/losses/triplet_test.py @@ -50,6 +50,44 @@ def pairwise_distance_np(feature, squared=False): return pairwise_distances +def triplet_hard_loss_np(labels, embedding, margin, soft=False): + + num_data = embedding.shape[0] + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels.astype(np.float32), (labels.shape[0], 1)) + # Compute the loss in NP. + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + + pdist_matrix = pairwise_distance_np(embedding, squared=True) + loss_np = 0.0 + for i in range(num_data): + pos_distances = [] + neg_distances = [] + for j in range(num_data): + if adjacency[i][j] == 0: + neg_distances.append(pdist_matrix[i][j]) + if adjacency[i][j] > 0.0 and i != j: + pos_distances.append(pdist_matrix[i][j]) + + # if there are no positive pairs, distance is 0 + if len(pos_distances) == 0: + pos_distances.append(0) + + # Sort by distance. + neg_distances.sort() + min_neg_distance = neg_distances[0] + pos_distances.sort(reverse=True) + max_pos_distance = pos_distances[0] + + if soft: + loss_np += np.log1p(np.exp(max_pos_distance - min_neg_distance)) + else: + loss_np += np.maximum(0.0, max_pos_distance - min_neg_distance + margin) + + loss_np /= num_data + return loss_np + + @test_utils.run_all_in_graph_and_eager_modes class TripletSemiHardLossTest(tf.test.TestCase): def test_unweighted(self): @@ -114,5 +152,54 @@ def test_serialization(self): new_loss = tf.keras.losses.deserialize(tf.keras.losses.serialize(loss)) +@test_utils.run_all_in_graph_and_eager_modes +class TripletHardLossTest(tf.test.TestCase): + def test_unweighted(self): + num_data = 20 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint(0, num_classes, size=(num_data)) + + loss_np = triplet_hard_loss_np(labels, embedding, margin) + + # Compute the loss in TF. + y_true = tf.constant(labels) + y_pred = tf.constant(embedding) + cce_obj = triplet.TripletHardLoss() + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), loss_np, 3) + + def test_unweighted_soft(self): + num_data = 20 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint(0, num_classes, size=(num_data)) + + loss_np = triplet_hard_loss_np(labels, embedding, margin, soft=True) + + # Compute the loss in TF. + y_true = tf.constant(labels) + y_pred = tf.constant(embedding) + cce_obj = triplet.TripletHardLoss(soft=True) + loss = cce_obj(y_true, y_pred) + self.assertAlmostEqual(self.evaluate(loss), loss_np, 3) + + def test_keras_model_compile(self): + model = tf.keras.models.Sequential( + [tf.keras.layers.Input(shape=(784,)), tf.keras.layers.Dense(10),] + ) + model.compile(loss="Addons>triplet_hard_loss", optimizer="adam") + + def test_serialization(self): + loss = triplet.TripletHardLoss() + tf.keras.losses.deserialize(tf.keras.losses.serialize(loss)) + + if __name__ == "__main__": tf.test.main()