7 changes: 6 additions & 1 deletion tensorflow_addons/losses/__init__.py
@@ -19,7 +19,12 @@
from tensorflow_addons.losses.giou_loss import giou_loss, GIoULoss
from tensorflow_addons.losses.lifted import lifted_struct_loss, LiftedStructLoss
from tensorflow_addons.losses.sparsemax_loss import sparsemax_loss, SparsemaxLoss
from tensorflow_addons.losses.triplet import triplet_semihard_loss, TripletSemiHardLoss
from tensorflow_addons.losses.triplet import (
triplet_semihard_loss,
triplet_hard_loss,
TripletSemiHardLoss,
TripletHardLoss,
)
from tensorflow_addons.losses.quantiles import pinball_loss, PinballLoss

# Temporarily disable for windows
93 changes: 93 additions & 0 deletions tensorflow_addons/losses/triplet.py
@@ -16,6 +16,7 @@

import tensorflow as tf
from tensorflow_addons.losses import metric_learning
from typeguard import typechecked
from typing import Optional


def _masked_maximum(data, mask, dim=1):
@@ -142,6 +143,57 @@ def triplet_semihard_loss(y_true, y_pred, margin=1.0):
return triplet_loss


@tf.keras.utils.register_keras_serializable(package="Addons")
@tf.function
def triplet_hard_loss(y_true, y_pred, margin=1.0, soft=False):
"""Computes the triplet loss with hard negative and hard positive mining.

Args:
y_true: 1-D integer `Tensor` with shape [batch_size] of
multiclass integer labels.
y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
be l2 normalized.
margin: Float, margin term in the loss definition.
      soft: Boolean, if set, use the soft margin version.

    Returns:
      triplet_loss: float scalar `Tensor`, the mean triplet loss over the batch.
    """
labels, embeddings = y_true, y_pred
# Reshape label tensor to [batch_size, 1].
lshape = tf.shape(labels)
labels = tf.reshape(labels, [lshape[0], 1])

# Build pairwise squared distance matrix.
pdist_matrix = metric_learning.pairwise_distance(embeddings, squared=True)
# Build pairwise binary adjacency matrix.
adjacency = tf.math.equal(labels, tf.transpose(labels))
# Invert so we can select negatives only.
adjacency_not = tf.math.logical_not(adjacency)

adjacency_not = tf.cast(adjacency_not, dtype=tf.dtypes.float32)
# hard negatives: smallest D_an.
hard_negatives = _masked_minimum(pdist_matrix, adjacency_not)

batch_size = tf.size(labels)

    adjacency = tf.cast(adjacency, dtype=tf.dtypes.float32)

    # Exclude the diagonal (self-pairs) from the positives mask.
    mask_positives = adjacency - tf.linalg.diag(tf.ones([batch_size]))

# hard positives: largest D_ap.
hard_positives = _masked_maximum(pdist_matrix, mask_positives)

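    # Soft-margin variant (Hermans et al., 2017): replaces the hinge
    # max(0, x + margin) with the smooth softplus ln(1 + e^x),
    # computed below as log1p(exp(x)).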
if soft:
triplet_loss = tf.math.log1p(tf.math.exp(hard_positives - hard_negatives))
else:
triplet_loss = tf.maximum(hard_positives - hard_negatives + margin, 0.0)

# Get final mean triplet loss
triplet_loss = tf.reduce_mean(triplet_loss)

return triplet_loss
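
A minimal usage sketch of the functional form (the toy labels and sizes below are made up for illustration; embeddings should be l2-normalized, as the docstring requires):

import tensorflow as tf
from tensorflow_addons.losses import triplet_hard_loss

# Four samples, two classes; l2-normalize random embeddings.
y_true = tf.constant([0, 0, 1, 1])
y_pred = tf.math.l2_normalize(tf.random.uniform([4, 8]), axis=1)
loss = triplet_hard_loss(y_true, y_pred, margin=1.0)  # scalar mean over the batch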


@tf.keras.utils.register_keras_serializable(package="Addons")
class TripletSemiHardLoss(tf.keras.losses.Loss):
"""Computes the triplet loss with semi-hard negative mining.
@@ -175,3 +227,44 @@ def get_config(self):
}
base_config = super().get_config()
return {**base_config, **config}


@tf.keras.utils.register_keras_serializable(package="Addons")
class TripletHardLoss(tf.keras.losses.Loss):
"""Computes the triplet loss with hard negative and hard positive mining.

The loss encourages the maximum positive distance (between a pair of embeddings
with the same labels) to be smaller than the minimum negative distance plus the
margin constant in the mini-batch.
The loss selects the hardest positive and the hardest negative samples
within the batch when forming the triplets for computing the loss.
See: https://arxiv.org/pdf/1703.07737.

We expect labels `y_true` to be provided as a 1-D integer `Tensor` with shape
[batch_size] of multiclass integer labels, and embeddings `y_pred` to be a 2-D
float `Tensor` of l2-normalized embedding vectors.

Args:
margin: Float, margin term in the loss definition. Default value is 1.0.
soft: Boolean, if set, use the soft margin version. Default value is False.
name: Optional name for the op.
"""

@typechecked
def __init__(
        self, margin: float = 1.0, soft: bool = False, name: Optional[str] = None, **kwargs
):
super().__init__(name=name, reduction=tf.keras.losses.Reduction.NONE)
self.margin = margin
self.soft = soft

def call(self, y_true, y_pred):
return triplet_hard_loss(y_true, y_pred, self.margin, self.soft)

def get_config(self):
config = {
"margin": self.margin,
Member: Should we also serialize self.soft here?

Contributor Author: Yes
"soft": self.soft,
}
base_config = super().get_config()
return {**base_config, **config}
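
As a sketch of how the class form plugs into Keras (the layer sizes are arbitrary; the Lambda layer is one way to produce the l2-normalized embeddings the loss expects):

model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(64),
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)),
])
model.compile(optimizer="adam", loss=TripletHardLoss(margin=1.0, soft=False))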
87 changes: 87 additions & 0 deletions tensorflow_addons/losses/triplet_test.py
@@ -50,6 +50,44 @@ def pairwise_distance_np(feature, squared=False):
return pairwise_distances


def triplet_hard_loss_np(labels, embedding, margin, soft=False):
    """NumPy reference implementation of the hard-mining triplet loss."""
num_data = embedding.shape[0]
# Reshape labels to compute adjacency matrix.
labels_reshaped = np.reshape(labels.astype(np.float32), (labels.shape[0], 1))
# Compute the loss in NP.
adjacency = np.equal(labels_reshaped, labels_reshaped.T)

pdist_matrix = pairwise_distance_np(embedding, squared=True)
loss_np = 0.0
for i in range(num_data):
pos_distances = []
neg_distances = []
for j in range(num_data):
if adjacency[i][j] == 0:
neg_distances.append(pdist_matrix[i][j])
if adjacency[i][j] > 0.0 and i != j:
pos_distances.append(pdist_matrix[i][j])

# if there are no positive pairs, distance is 0
if len(pos_distances) == 0:
pos_distances.append(0)

# Sort by distance.
neg_distances.sort()
min_neg_distance = neg_distances[0]
pos_distances.sort(reverse=True)
max_pos_distance = pos_distances[0]

if soft:
loss_np += np.log1p(np.exp(max_pos_distance - min_neg_distance))
else:
loss_np += np.maximum(0.0, max_pos_distance - min_neg_distance + margin)

loss_np /= num_data
return loss_np
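
As a hand-checkable toy case for this reference (numbers chosen for illustration):

# embeddings (already unit-norm): e0 = (1, 0), e1 = (0, 1), e2 = (1, 0); labels = [0, 0, 1]
# squared distances: d01 = 2, d02 = 0, d12 = 2; margin = 1
# anchor 0: hardest positive 2, hardest negative 0 -> max(0, 2 - 0 + 1) = 3
# anchor 1: hardest positive 2, hardest negative 2 -> max(0, 2 - 2 + 1) = 1
# anchor 2: no positives (treated as 0), hardest negative 0 -> max(0, 0 - 0 + 1) = 1
# mean over the 3 anchors: 5 / 3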


@test_utils.run_all_in_graph_and_eager_modes
class TripletSemiHardLossTest(tf.test.TestCase):
def test_unweighted(self):
@@ -114,5 +152,54 @@ def test_serialization(self):
new_loss = tf.keras.losses.deserialize(tf.keras.losses.serialize(loss))


@test_utils.run_all_in_graph_and_eager_modes
class TripletHardLossTest(tf.test.TestCase):
def test_unweighted(self):
num_data = 20
feat_dim = 6
margin = 1.0
num_classes = 4

embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
        labels = np.random.randint(0, num_classes, size=(num_data,))

loss_np = triplet_hard_loss_np(labels, embedding, margin)

# Compute the loss in TF.
y_true = tf.constant(labels)
y_pred = tf.constant(embedding)
        loss_obj = triplet.TripletHardLoss()
        loss = loss_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)

def test_unweighted_soft(self):
num_data = 20
feat_dim = 6
margin = 1.0
num_classes = 4

embedding = np.random.rand(num_data, feat_dim).astype(np.float32)
        labels = np.random.randint(0, num_classes, size=(num_data,))

loss_np = triplet_hard_loss_np(labels, embedding, margin, soft=True)

# Compute the loss in TF.
y_true = tf.constant(labels)
y_pred = tf.constant(embedding)
        loss_obj = triplet.TripletHardLoss(soft=True)
        loss = loss_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), loss_np, 3)

def test_keras_model_compile(self):
        model = tf.keras.models.Sequential(
            [tf.keras.layers.Input(shape=(784,)), tf.keras.layers.Dense(10)]
        )
model.compile(loss="Addons>triplet_hard_loss", optimizer="adam")

def test_serialization(self):
loss = triplet.TripletHardLoss()
tf.keras.losses.deserialize(tf.keras.losses.serialize(loss))


if __name__ == "__main__":
tf.test.main()