@@ -404,14 +404,27 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
404404            """ 
405405            (next_outputs , decoder_state , next_inputs ,
406406             decoder_finished ) =  decoder .step (time , inputs , state , training )
407+             decoder_state_sequence_lengths  =  False 
407408            if  decoder .tracks_own_finished :
408409                next_finished  =  decoder_finished 
410+                 lengths  =  getattr (decoder_state , "lengths" , None )
411+                 if  lengths  is  not   None :
412+                     # sequence lengths are provided by decoder_state.lengths; 
413+                     # overwrite our sequence lengths. 
414+                     decoder_state_sequence_lengths  =  True 
415+                     sequence_lengths  =  tf .cast (lengths , tf .int32 )
409416            else :
410417                next_finished  =  tf .logical_or (decoder_finished , finished )
411-             next_sequence_lengths  =  tf .where (
412-                 tf .logical_not (finished ),
413-                 tf .fill (tf .shape (sequence_lengths ), time  +  1 ),
414-                 sequence_lengths )
418+ 
419+             if  decoder_state_sequence_lengths :
420+                 # Just pass something through the loop; at the next iteration 
421+                 # we'll pull the sequence lengths from the decoder_state again. 
422+                 next_sequence_lengths  =  sequence_lengths 
423+             else :
424+                 next_sequence_lengths  =  tf .where (
425+                     tf .logical_not (finished ),
426+                     tf .fill (tf .shape (sequence_lengths ), time  +  1 ),
427+                     sequence_lengths )
415428
416429            tf .nest .assert_same_structure (state , decoder_state )
417430            tf .nest .assert_same_structure (outputs_ta , next_outputs )
0 commit comments