diff --git a/tensorflow_addons/seq2seq/loss.py b/tensorflow_addons/seq2seq/loss.py
index 9e9b8ef506..a3e5283941 100644
--- a/tensorflow_addons/seq2seq/loss.py
+++ b/tensorflow_addons/seq2seq/loss.py
@@ -89,12 +89,20 @@ def sequence_loss(logits,
     if len(logits.get_shape()) != 3:
         raise ValueError("Logits must be a "
                          "[batch_size x sequence_length x logits] tensor")
-    if len(targets.get_shape()) != 2:
+
+    targets_rank = len(targets.get_shape())
+    if targets_rank != 2 and targets_rank != 3:
         raise ValueError(
-            "Targets must be a [batch_size x sequence_length] tensor")
+            "Targets must be either a [batch_size x sequence_length] tensor "
+            "where each element contains the labels' index "
+            "or a [batch_size x sequence_length x num_classes] tensor "
+            "where the third axis is a one-hot representation of the labels"
+        )
+
     if len(weights.get_shape()) != 2:
         raise ValueError(
             "Weights must be a [batch_size x sequence_length] tensor")
+
     if average_across_timesteps and sum_over_timesteps:
         raise ValueError(
             "average_across_timesteps and sum_over_timesteps cannot "
@@ -114,11 +122,17 @@ def sequence_loss(logits,
     with tf.name_scope(name or "sequence_loss"):
         num_classes = tf.shape(input=logits)[2]
         logits_flat = tf.reshape(logits, [-1, num_classes])
-        targets = tf.reshape(targets, [-1])
         if softmax_loss_function is None:
-            crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
-                labels=targets, logits=logits_flat)
+            if targets_rank == 2:
+                targets = tf.reshape(targets, [-1])
+                crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(
+                    labels=targets, logits=logits_flat)
+            else:
+                targets = tf.reshape(targets, [-1, num_classes])
+                crossent = tf.nn.softmax_cross_entropy_with_logits(
+                    labels=targets, logits=logits_flat)
         else:
+            targets = tf.reshape(targets, [-1])
             crossent = softmax_loss_function(
                 labels=targets, logits=logits_flat)
         crossent *= tf.reshape(weights, [-1])
@@ -168,6 +182,11 @@ def __init__(self,
         self.sum_over_batch = sum_over_batch
         self.softmax_loss_function = softmax_loss_function
 
+        # Delete the reduction attribute to inform Keras that it
+        # should call this class by the __call__(...) method.
+        if hasattr(self, 'reduction'):
+            delattr(self, 'reduction')
+
     def __call__(self, y_true, y_pred, sample_weight=None):
         """Override the parent __call__ to have a customized reduce
         behavior."""
diff --git a/tensorflow_addons/seq2seq/loss_test.py b/tensorflow_addons/seq2seq/loss_test.py
index 042bb3cac3..8ae600e25d 100644
--- a/tensorflow_addons/seq2seq/loss_test.py
+++ b/tensorflow_addons/seq2seq/loss_test.py
@@ -310,5 +310,56 @@ def testAmbiguousOrder(self):
                     seq_loss(self.targets, self.logits, self.weights))
 
 
+@test_utils.run_all_in_graph_and_eager_modes
+class DenseTargetLossTest(LossTest):
+    def setup(self):
+        super(DenseTargetLossTest, self).setup()
+        self.targets = tf.one_hot(self.targets, depth=self.number_of_classes)
+
+    def testKerasCompatibility(self):
+        """To test the compatibility of SequenceLoss with Keras's built-in
+        training loops, we create a fake model which always outputs a
+        pre-defined set of logits.
+
+        Then we check that the calculated loss equals the expected
+        loss. Note that since the fake model doesn't have any trainable
+        parameters, no matter how many steps we train it, it always
+        outputs the same loss value.
+        """
+        with self.cached_session(use_gpu=True):
+            self.setup()
+
+            def return_logits(x):
+                batch_size = tf.shape(x)[0]
+                logits_single_row = self.logits[0, :, :]
+                logits_batch = tf.tile(
+                    tf.expand_dims(logits_single_row, 0), [batch_size, 1, 1])
+                return logits_batch
+
+            inp = tf.keras.layers.Input(shape=(self.sequence_length,))
+            out = tf.keras.layers.Lambda(
+                return_logits,
+                output_shape=(self.sequence_length,
+                              self.number_of_classes))(inp)
+            model = tf.keras.models.Model(inp, out)
+
+            loss_obj = loss.SequenceLoss()
+            model.compile(
+                optimizer='adam', loss=loss_obj, sample_weight_mode="temporal")
+
+            # This is a fake input.
+            x = tf.ones(shape=(self.batch_size, self.sequence_length))
+
+            h = model.fit(
+                x,
+                self.targets,
+                sample_weight=self.weights,
+                batch_size=self.batch_size,
+                steps_per_epoch=1)
+
+            calculated_loss = h.history['loss'][0]
+            self.assertAllClose(calculated_loss, self.expected_loss)
+
+
 if __name__ == '__main__':
     tf.test.main()