@@ -48,7 +48,7 @@ def _prepare_decoder_ids_for_generation(
4848 return torch .ones ((batch_size , 1 ), dtype = torch .long , device = device ) * pad_idx
4949
5050 def greedy_search (
51- self , input_ids : torch .Tensor , max_length : int , eos_idx : int , pad_idx : Optional [ int ] = None , ** model_kwargs
51+ self , input_ids : torch .Tensor , max_length : int , eos_idx : int , pad_idx : int , ** model_kwargs
5252 ) -> torch .Tensor :
5353 """Greedy search decoding for text generation. Takes the most likely next token every time.
5454
@@ -62,10 +62,11 @@ def greedy_search(
6262 Returns:
6363 Batch of sequences decoded by greedy search.
6464 """
65- unfinished_sequences = torch .ones ((input_ids .shape [0 ], 1 ), device = input_ids .device , dtype = torch .long )
65+ unfinished_sequences = torch .ones ((input_ids .shape [0 ]), device = input_ids .device , dtype = torch .long )
6666
6767 while True :
6868 model_inputs = self .model .prepare_inputs_for_generation (input_ids , ** model_kwargs )
69+
6970 if self .is_huggingface_model :
7071 model_inputs ["return_dict" ] = True
7172 model_inputs ["output_hidden_states" ] = True
@@ -77,18 +78,16 @@ def greedy_search(
7778
7879 # Calculate probabilities and take the most likely next token
7980 probs = F .log_softmax (decoder_output [:, - 1 ], dim = - 1 )
80- _ , next_tokens = torch .topk (probs , 1 )
81+ next_tokens = torch .argmax (probs , dim = - 1 )
8182
8283 # For any finished sequences, padding idx should be the last token
83- if eos_idx is not None :
84- if pad_idx is not None :
85- next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences )
84+ next_tokens = next_tokens * unfinished_sequences + pad_idx * (1 - unfinished_sequences )
8685
8786 # Append the next tokens to the previous tokens
88- input_ids = torch .cat ([input_ids , next_tokens ], dim = - 1 )
87+ input_ids = torch .cat ([input_ids , next_tokens [:, None ] ], dim = - 1 )
8988
90- if eos_idx is not None :
91- unfinished_sequences = unfinished_sequences .mul ((next_tokens != eos_idx ).long () )
89+ # Update unfinished sequences count
90+ unfinished_sequences = unfinished_sequences .mul ((next_tokens != eos_idx )) .long ()
9291
9392 # Stop iterating once all sequences are finished or exceed the max_length
9493 if unfinished_sequences .max () == 0 or len (input_ids [0 ]) >= max_length :
@@ -128,8 +127,10 @@ def generate(
128127
129128 if self .is_encoder_decoder :
130129 encoder = self .model .get_encoder ()
131- model_kwargs ["encoder_outputs" ] = encoder (inputs )
130+ encoder_model_kwargs = {"src_key_padding_mask" : inputs .eq (pad_idx )}
131+ model_kwargs ["encoder_outputs" ] = encoder (inputs , ** encoder_model_kwargs )
132132 inputs = self ._prepare_decoder_ids_for_generation (len (inputs ), device = inputs .device , ** model_kwargs )
133+ model_kwargs ["encoder_padding_mask" ] = encoder_model_kwargs .pop ("src_key_padding_mask" )
133134
134135 if max_length is None :
135136 # Too hard to try to figure out the exact max_seq_length for each model
0 commit comments