Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions tensorflow_addons/seq2seq/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,14 +404,27 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
"""
(next_outputs, decoder_state, next_inputs,
decoder_finished) = decoder.step(time, inputs, state, training)
decoder_state_sequence_lengths = False
if decoder.tracks_own_finished:
next_finished = decoder_finished
lengths = getattr(decoder_state, "lengths", None)
if lengths is not None:
# sequence lengths are provided by decoder_state.lengths;
# overwrite our sequence lengths.
decoder_state_sequence_lengths = True
sequence_lengths = tf.cast(lengths, tf.int32)
else:
next_finished = tf.logical_or(decoder_finished, finished)
next_sequence_lengths = tf.where(
tf.logical_not(finished),
tf.fill(tf.shape(sequence_lengths), time + 1),
sequence_lengths)

if decoder_state_sequence_lengths:
# Just pass something through the loop; at the next iteration
# we'll pull the sequence lengths from the decoder_state again.
next_sequence_lengths = sequence_lengths
else:
next_sequence_lengths = tf.where(
tf.logical_not(finished),
tf.fill(tf.shape(sequence_lengths), time + 1),
sequence_lengths)

tf.nest.assert_same_structure(state, decoder_state)
tf.nest.assert_same_structure(outputs_ta, next_outputs)
Expand Down