From b5b0fb6ad54841a90f9ec3900146091441c0d784 Mon Sep 17 00:00:00 2001 From: Tzu-Wei Sung Date: Tue, 14 Jan 2020 22:16:15 -0800 Subject: [PATCH] FIX: sigmoid focal cross entropy model compile --- tensorflow_addons/losses/focal_loss.py | 4 ---- tensorflow_addons/losses/focal_loss_test.py | 6 ++++++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/tensorflow_addons/losses/focal_loss.py b/tensorflow_addons/losses/focal_loss.py index beaacd4c02..6708afd732 100644 --- a/tensorflow_addons/losses/focal_loss.py +++ b/tensorflow_addons/losses/focal_loss.py @@ -117,10 +117,6 @@ def sigmoid_focal_crossentropy(y_true, y_pred = tf.convert_to_tensor(y_pred) y_true = tf.convert_to_tensor(y_true, dtype=y_pred.dtype) - if y_true.shape != y_pred.shape: - raise ValueError("Shape mismatch for y_true: {} and y_pred: {}".format( - tf.shape(y_true), tf.shape(y_pred))) - # Get the cross_entropy for each entry ce = K.binary_crossentropy(y_true, y_pred, from_logits=from_logits) diff --git a/tensorflow_addons/losses/focal_loss_test.py b/tensorflow_addons/losses/focal_loss_test.py index 8641ca9e93..e98d825bc7 100644 --- a/tensorflow_addons/losses/focal_loss_test.py +++ b/tensorflow_addons/losses/focal_loss_test.py @@ -106,6 +106,12 @@ def test_without_logits(self): pow_values = tf.constant([1000, 100, 10, 10, 100, 1000]) self.assertAllClose(order_of_ratio, pow_values) + def test_keras_model_compile(self): + model = tf.keras.models.Sequential([ + tf.keras.layers.Input(shape=(100,)), + tf.keras.layers.Dense(5, activation="softmax") + ]) + model.compile(loss="Addons>sigmoid_focal_crossentropy") if __name__ == '__main__': tf.test.main()