Skip to content

Commit 5e3b399

Browse files
qlzh727guillaumekln
authored andcommitted
Port bug fix in TF contrib to addons. (#497)
* Port bug fix in TF contrib to addons. Original change at tensorflow/tensorflow@a913689. * Fix lint warning.
1 parent 6633c43 commit 5e3b399

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

tensorflow_addons/seq2seq/decoder.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,27 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
404404
"""
405405
(next_outputs, decoder_state, next_inputs,
406406
decoder_finished) = decoder.step(time, inputs, state, training)
407+
decoder_state_sequence_lengths = False
407408
if decoder.tracks_own_finished:
408409
next_finished = decoder_finished
410+
lengths = getattr(decoder_state, "lengths", None)
411+
if lengths is not None:
412+
# sequence lengths are provided by decoder_state.lengths;
413+
# overwrite our sequence lengths.
414+
decoder_state_sequence_lengths = True
415+
sequence_lengths = tf.cast(lengths, tf.int32)
409416
else:
410417
next_finished = tf.logical_or(decoder_finished, finished)
411-
next_sequence_lengths = tf.where(
412-
tf.logical_not(finished),
413-
tf.fill(tf.shape(sequence_lengths), time + 1),
414-
sequence_lengths)
418+
419+
if decoder_state_sequence_lengths:
420+
# Just pass something through the loop; at the next iteration
421+
# we'll pull the sequence lengths from the decoder_state again.
422+
next_sequence_lengths = sequence_lengths
423+
else:
424+
next_sequence_lengths = tf.where(
425+
tf.logical_not(finished),
426+
tf.fill(tf.shape(sequence_lengths), time + 1),
427+
sequence_lengths)
415428

416429
tf.nest.assert_same_structure(state, decoder_state)
417430
tf.nest.assert_same_structure(outputs_ta, next_outputs)

0 commit comments

Comments
 (0)