diff --git a/tensorflow_addons/losses/triplet.py b/tensorflow_addons/losses/triplet.py index e3b056f8a1..3b3d7b266d 100644 --- a/tensorflow_addons/losses/triplet.py +++ b/tensorflow_addons/losses/triplet.py @@ -105,8 +105,8 @@ def triplet_semihard_loss( Returns: triplet_loss: float scalar with dtype of `y_pred`. """ - - labels, embeddings = y_true, y_pred + labels = tf.convert_to_tensor(y_true, name="labels") + embeddings = tf.convert_to_tensor(y_pred, name="embeddings") convert_to_float32 = ( embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16 @@ -242,7 +242,8 @@ def triplet_hard_loss( Returns: triplet_loss: float scalar with dtype of `y_pred`. """ - labels, embeddings = y_true, y_pred + labels = tf.convert_to_tensor(y_true, name="labels") + embeddings = tf.convert_to_tensor(y_pred, name="embeddings") convert_to_float32 = ( embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16