From 41e5269f720c5056dd6a5ebec9c22bb45a8c6b19 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Thu, 20 Aug 2020 20:16:55 -0700 Subject: [PATCH] Convert inputs to tensor --- tensorflow_addons/losses/triplet.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/losses/triplet.py b/tensorflow_addons/losses/triplet.py index e3b056f8a1..3b3d7b266d 100644 --- a/tensorflow_addons/losses/triplet.py +++ b/tensorflow_addons/losses/triplet.py @@ -105,8 +105,8 @@ def triplet_semihard_loss( Returns: triplet_loss: float scalar with dtype of `y_pred`. """ - - labels, embeddings = y_true, y_pred + labels = tf.convert_to_tensor(y_true, name="labels") + embeddings = tf.convert_to_tensor(y_pred, name="embeddings") convert_to_float32 = ( embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16 @@ -242,7 +242,8 @@ def triplet_hard_loss( Returns: triplet_loss: float scalar with dtype of `y_pred`. """ - labels, embeddings = y_true, y_pred + labels = tf.convert_to_tensor(y_true, name="labels") + embeddings = tf.convert_to_tensor(y_pred, name="embeddings") convert_to_float32 = ( embeddings.dtype == tf.dtypes.float16 or embeddings.dtype == tf.dtypes.bfloat16