@@ -26,8 +26,8 @@ def _masked_maximum(data, mask, dim=1):
2626 """Computes the axis wise maximum over chosen elements.
2727
2828 Args:
29- data: 2-D float `Tensor` of size [n, m].
30- mask: 2-D Boolean `Tensor` of size [n, m].
29+      data: 2-D float `Tensor` of shape `[n, m]`.
30+      mask: 2-D Boolean `Tensor` of shape `[n, m]`.
3131 dim: The dimension over which to compute the maximum.
3232
3333 Returns:
@@ -48,8 +48,8 @@ def _masked_minimum(data, mask, dim=1):
4848 """Computes the axis wise minimum over chosen elements.
4949
5050 Args:
51- data: 2-D float `Tensor` of size [n, m].
52- mask: 2-D Boolean `Tensor` of size [n, m].
51+      data: 2-D float `Tensor` of shape `[n, m]`.
52+      mask: 2-D Boolean `Tensor` of shape `[n, m]`.
5353 dim: The dimension over which to compute the minimum.
5454
5555 Returns:
@@ -74,33 +74,36 @@ def triplet_semihard_loss(
7474     margin: FloatTensorLike = 1.0,
7575     distance_metric: Union[str, Callable] = "L2",
7676 ) -> tf.Tensor:
77- """Computes the triplet loss with semi-hard negative mining.
77+ r"""Computes the triplet loss with semi-hard negative mining.
78+
79+ Usage:
80+
81+ >>> y_true = tf.convert_to_tensor([0, 0])
82+ >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
83+ >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric="L2")
84+ <tf.Tensor: shape=(), dtype=float32, numpy=2.4142137>
85+
86+ >>> # Calling with callable `distance_metric`
87+ >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
88+ >>> tfa.losses.triplet_semihard_loss(y_true, y_pred, distance_metric=distance_metric)
89+ <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
7890
7991 Args:
80- y_true: 1-D integer `Tensor` with shape [batch_size] of
92+      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
8193 multiclass integer labels.
8294 y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
8395 be l2 normalized.
8496 margin: Float, margin term in the loss definition.
85- distance_metric: str or function, determines distance metric:
86- "L2" for l2-norm distance
87- "squared-L2" for squared l2-norm distance
88- "angular" for cosine similarity
89- A custom function returning a 2d adjacency
90- matrix of a chosen distance metric can
91- also be passed here. e.g.
92-
93- def custom_distance(batch):
94- batch = 1 - batch @ batch.T
95- return batch
96-
97- triplet_semihard_loss(batch, labels,
98- distance_metric=custom_distance
99- )
97+ distance_metric: `str` or a `Callable` that determines distance metric.
98+ Valid strings are "L2" for l2-norm distance,
99+ "squared-L2" for squared l2-norm distance,
100+ and "angular" for cosine similarity.
100101
102+ A `Callable` should take a batch of embeddings as input and
103+ return the pairwise distance matrix.
101104
102105 Returns:
103- triplet_loss: float scalar with dtype of y_pred.
106+      triplet_loss: float scalar with dtype of `y_pred`.
104107 """
105108
106109     labels, embeddings = y_true, y_pred
@@ -207,33 +210,37 @@ def triplet_hard_loss(
207210     soft: bool = False,
208211     distance_metric: Union[str, Callable] = "L2",
209212 ) -> tf.Tensor:
210- """Computes the triplet loss with hard negative and hard positive mining.
213+ r"""Computes the triplet loss with hard negative and hard positive mining.
214+
215+ Usage:
216+
217+ >>> y_true = tf.convert_to_tensor([0, 0])
218+ >>> y_pred = tf.convert_to_tensor([[0.0, 1.0], [1.0, 0.0]])
219+ >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric="L2")
220+ <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
221+
222+ >>> # Calling with callable `distance_metric`
223+ >>> distance_metric = lambda x: tf.linalg.matmul(x, x, transpose_b=True)
224+ >>> tfa.losses.triplet_hard_loss(y_true, y_pred, distance_metric=distance_metric)
225+ <tf.Tensor: shape=(), dtype=float32, numpy=0.0>
211226
212227 Args:
213- y_true: 1-D integer `Tensor` with shape [batch_size] of
228+      y_true: 1-D integer `Tensor` with shape `[batch_size]` of
214229 multiclass integer labels.
215230 y_pred: 2-D float `Tensor` of embedding vectors. Embeddings should
216231 be l2 normalized.
217232 margin: Float, margin term in the loss definition.
218233 soft: Boolean, if set, use the soft margin version.
219- distance_metric: str or function, determines distance metric:
220- "L2" for l2-norm distance
221- "squared-L2" for squared l2-norm distance
222- "angular" for cosine similarity
223- A custom function returning a 2d adjacency
224- matrix of a chosen distance metric can
225- also be passed here. e.g.
226-
227- def custom_distance(batch):
228- batch = 1 - batch @ batch.T
229- return batch
230-
231- triplet_semihard_loss(batch, labels,
232- distance_metric=custom_distance
233- )
234+ distance_metric: `str` or a `Callable` that determines distance metric.
235+ Valid strings are "L2" for l2-norm distance,
236+ "squared-L2" for squared l2-norm distance,
237+ and "angular" for cosine similarity.
238+
239+ A `Callable` should take a batch of embeddings as input and
240+ return the pairwise distance matrix.
234241
235242 Returns:
236- triplet_loss: float scalar with dtype of y_pred.
243+      triplet_loss: float scalar with dtype of `y_pred`.
237244 """
238245     labels, embeddings = y_true, y_pred
239246
@@ -311,7 +318,7 @@ class TripletSemiHardLoss(LossFunctionWrapper):
311318 See: https://arxiv.org/abs/1503.03832.
312319
313320 We expect labels `y_true` to be provided as 1-D integer `Tensor` with shape
314- [batch_size] of multi-class integer labels. And embeddings `y_pred` must be
321+    `[batch_size]` of multi-class integer labels. And embeddings `y_pred` must be
315322 2-D float `Tensor` of l2 normalized embedding vectors.
316323
317324 Args:
@@ -348,7 +355,7 @@ class TripletHardLoss(LossFunctionWrapper):
348355 See: https://arxiv.org/pdf/1703.07737.
349356
350357 We expect labels `y_true` to be provided as 1-D integer `Tensor` with shape
351- [batch_size] of multi-class integer labels. And embeddings `y_pred` must be
358+    `[batch_size]` of multi-class integer labels. And embeddings `y_pred` must be
352359 2-D float `Tensor` of l2 normalized embedding vectors.
353360
354361 Args:
0 commit comments