From ef2b9ffd7fd9f63419b9b54cb0e277a0ded8a9a0 Mon Sep 17 00:00:00 2001
From: AmirHosein KazemNejad
Date: Wed, 11 Sep 2019 21:36:15 +0430
Subject: [PATCH 1/3] Fix SequenceLoss incompatibility with Keras built-in
 loops

---
 tensorflow_addons/seq2seq/loss.py      | 30 ++++++++++++---
 tensorflow_addons/seq2seq/loss_test.py | 51 ++++++++++++++++++++++++++
 2 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/tensorflow_addons/seq2seq/loss.py b/tensorflow_addons/seq2seq/loss.py
index 9e9b8ef506..8e06da09eb 100644
--- a/tensorflow_addons/seq2seq/loss.py
+++ b/tensorflow_addons/seq2seq/loss.py
@@ -89,12 +89,21 @@ def sequence_loss(logits,
     if len(logits.get_shape()) != 3:
         raise ValueError("Logits must be a "
                          "[batch_size x sequence_length x logits] tensor")
-    if len(targets.get_shape()) != 2:
+
+    targets_rank = len(targets.get_shape())
+    if targets_rank != 2 and targets_rank != 3:
+        print(targets_rank)
         raise ValueError(
-            "Targets must be a [batch_size x sequence_length] tensor")
+            "Targets must be either a [batch_size x sequence_length] tensor "
+            "where each element contains the label's index "
+            "or a [batch_size x sequence_length x num_classes] tensor "
+            "where the third axis is a one-hot representation of the labels"
+        )
+
     if len(weights.get_shape()) != 2:
         raise ValueError(
             "Weights must be a [batch_size x sequence_length] tensor")
+
     if average_across_timesteps and sum_over_timesteps:
         raise ValueError(
             "average_across_timesteps and sum_over_timesteps cannot "
@@ -114,11 +123,17 @@ def sequence_loss(logits,
     with tf.name_scope(name or "sequence_loss"):
         num_classes = tf.shape(input=logits)[2]
         logits_flat = tf.reshape(logits, [-1, num_classes])
-        targets = tf.reshape(targets, [-1])
         if softmax_loss_function is None:
-            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
-                labels=targets, logits=logits_flat)
+            if targets_rank == 2:
+                targets = tf.reshape(targets, [-1])
+                crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+                    labels=targets, logits=logits_flat)
+            else:
+                targets = tf.reshape(targets, [-1, num_classes])
+                crossent = tf.nn.softmax_cross_entropy_with_logits(
+                    labels=targets, logits=logits_flat)
         else:
+            targets = tf.reshape(targets, [-1])
             crossent = softmax_loss_function(
                 labels=targets, logits=logits_flat)
         crossent *= tf.reshape(weights, [-1])
@@ -168,6 +183,11 @@ def __init__(self,
         self.sum_over_batch = sum_over_batch
         self.softmax_loss_function = softmax_loss_function
 
+        # Delete the reduction attribute to inform Keras that it
+        # should call this class through its __call__(...) method.
+        if 'reduction' in dir(self):
+            delattr(self, 'reduction')
+
     def __call__(self, y_true, y_pred, sample_weight=None):
         """Override the parent __call__ to have a customized reduce
         behavior."""
diff --git a/tensorflow_addons/seq2seq/loss_test.py b/tensorflow_addons/seq2seq/loss_test.py
index 042bb3cac3..8ae600e25d 100644
--- a/tensorflow_addons/seq2seq/loss_test.py
+++ b/tensorflow_addons/seq2seq/loss_test.py
@@ -310,5 +310,56 @@ def testAmbiguousOrder(self):
             seq_loss(self.targets, self.logits, self.weights))
 
 
+@test_utils.run_all_in_graph_and_eager_modes
+class DenseTargetLossTest(LossTest):
+    def setup(self):
+        super(DenseTargetLossTest, self).setup()
+        self.targets = tf.one_hot(self.targets, depth=self.number_of_classes)
+
+    def testKerasCompatibility(self):
+        """To test the compatibility of SequenceLoss with Keras's
+        built-in training loops, we create a fake model that always
+        outputs a pre-defined set of logits.
+
+        Then we check that the calculated loss equals the expected
+        loss. Note that since the fake model has no trainable
+        parameters, it outputs the same loss value no matter how many
+        training steps are run.
+        """
+        with self.cached_session(use_gpu=True):
+            self.setup()
+
+            def return_logits(x):
+                batch_size = tf.shape(x)[0]
+                logits_single_row = self.logits[0, :, :]
+                logits_batch = tf.tile(
+                    tf.expand_dims(logits_single_row, 0), [batch_size, 1, 1])
+                return logits_batch
+
+            inp = tf.keras.layers.Input(shape=(self.sequence_length,))
+            out = tf.keras.layers.Lambda(
+                return_logits,
+                output_shape=(self.sequence_length,
+                              self.number_of_classes))(inp)
+            model = tf.keras.models.Model(inp, out)
+
+            loss_obj = loss.SequenceLoss()
+            model.compile(
+                optimizer='adam', loss=loss_obj, sample_weight_mode="temporal")
+
+            # This is a fake input; only its batch dimension matters.
+            x = tf.ones(shape=(self.batch_size, self.sequence_length))
+
+            h = model.fit(
+                x,
+                self.targets,
+                sample_weight=self.weights,
+                batch_size=self.batch_size,
+                steps_per_epoch=1)
+
+            calculated_loss = h.history['loss'][0]
+            self.assertAllClose(calculated_loss, self.expected_loss)
+
+
 if __name__ == '__main__':
     tf.test.main()

From bd32090838cb5d718439415a427558144f405470 Mon Sep 17 00:00:00 2001
From: AmirHosein KazemNejad
Date: Wed, 11 Sep 2019 21:45:47 +0430
Subject: [PATCH 2/3] Remove debugging prints

---
 tensorflow_addons/seq2seq/loss.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tensorflow_addons/seq2seq/loss.py b/tensorflow_addons/seq2seq/loss.py
index 8e06da09eb..aaa2aa1763 100644
--- a/tensorflow_addons/seq2seq/loss.py
+++ b/tensorflow_addons/seq2seq/loss.py
@@ -92,7 +92,6 @@ def sequence_loss(logits,
 
     targets_rank = len(targets.get_shape())
     if targets_rank != 2 and targets_rank != 3:
-        print(targets_rank)
         raise ValueError(
             "Targets must be either a [batch_size x sequence_length] tensor "
             "where each element contains the label's index "

From a4917d720a0bb3f1d79391207484bcedc164a927 Mon Sep 17 00:00:00 2001
From: AmirHosein KazemNejad
Date: Wed, 11 Sep 2019 22:10:43 +0430
Subject: [PATCH 3/3] Use a more Pythonic attribute existence check

---
 tensorflow_addons/seq2seq/loss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow_addons/seq2seq/loss.py b/tensorflow_addons/seq2seq/loss.py
index aaa2aa1763..a3e5283941 100644
--- a/tensorflow_addons/seq2seq/loss.py
+++ b/tensorflow_addons/seq2seq/loss.py
@@ -184,7 +184,7 @@ def __init__(self,
 
         # Delete the reduction attribute to inform Keras that it
         # should call this class through its __call__(...) method.
-        if 'reduction' in dir(self):
+        if hasattr(self, 'reduction'):
             delattr(self, 'reduction')
 
     def __call__(self, y_true, y_pred, sample_weight=None):
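
For illustration, below is a minimal usage sketch of what this patch series
enables: passing SequenceLoss to model.compile() and training through Keras's
built-in fit() loop, with targets given either as rank-2 label indices or as a
rank-3 one-hot representation. The toy model, dimensions, and random data are
hypothetical and not part of the patches; the sketch assumes TF 2.x eager
execution, and the import path mirrors the test module above.

    import numpy as np
    import tensorflow as tf
    from tensorflow_addons.seq2seq.loss import SequenceLoss

    batch_size, seq_len, num_classes = 4, 7, 10

    # A toy model that produces [batch_size, seq_len, num_classes] logits.
    inputs = tf.keras.layers.Input(shape=(seq_len,), dtype='int32')
    embedded = tf.keras.layers.Embedding(num_classes, 16)(inputs)
    logits = tf.keras.layers.Dense(num_classes)(embedded)
    model = tf.keras.models.Model(inputs, logits)

    seq_loss = SequenceLoss()
    # sample_weight_mode="temporal" lets fit() accept per-timestep weights.
    model.compile(optimizer='adam', loss=seq_loss,
                  sample_weight_mode="temporal")

    x = np.random.randint(0, num_classes, size=(batch_size, seq_len))
    weights = tf.ones((batch_size, seq_len))

    # Rank-2 targets, where each element is a label index, can be fed to
    # the loss object directly.
    sparse_targets = tf.constant(
        np.random.randint(0, num_classes, size=(batch_size, seq_len)))
    direct_loss = seq_loss(sparse_targets, model(x), sample_weight=weights)

    # Rank-3 one-hot targets also work, including through the built-in
    # fit() loop, which is what the new test exercises.
    dense_targets = tf.one_hot(sparse_targets, depth=num_classes)
    model.fit(x, dense_targets.numpy(), sample_weight=weights.numpy(),
              batch_size=batch_size)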