From fbf18ae3a387617b47957ef1b3c390ec783abf74 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Tue, 28 Nov 2023 18:28:25 +0000 Subject: [PATCH 01/17] MVP --- src/transformers/generation/candidates.py | 319 ++++++++++++++++++++++ src/transformers/generation/utils.py | 15 +- 2 files changed, 326 insertions(+), 8 deletions(-) create mode 100644 src/transformers/generation/candidates.py diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py new file mode 100644 index 000000000000..9cccc81fed06 --- /dev/null +++ b/src/transformers/generation/candidates.py @@ -0,0 +1,319 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import inspect +import warnings +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, List + +import torch + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + from .logits_process import LogitsProcessorList + + +class CandidateGenerator: + """Abstract base class for all candidate generators that can be applied during assisted generation.""" + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by + the model. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." + ) + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can call " + "`update_candidate_strategy`." + ) + + +class AssistedCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of + a smaller model. 
Read the following blog post for more information: https://huggingface.co/blog/assisted-generation + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + assistant_model (`PreTrainedModel`): + The model to be used for generating candidates. This model should be smaller than the main model. + logits_processor (`LogitsProcessorList`): + An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] + used to modify the prediction scores of the language modeling head applied at each generation step. + model_kwargs (`Dict`): + The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant + model as well. + inputs_tensor (`torch.Tensor`, *optional*): + The model input tensor. In encoder-decoder models, this is the encoder input. + eos_token_id (`Union[int, List[int]]`, *optional*): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + """ + + def __init__( + self, + input_ids: torch.LongTensor, + assistant_model: "PreTrainedModel", + logits_processor: "LogitsProcessorList", + model_kwargs: Dict, + inputs_tensor: Optional[torch.Tensor] = None, + eos_token_id: Optional[Union[int, List[int]]] = None + ): + + self.assistant_model = assistant_model + + # Prepare the number of candidate tokens + if hasattr(assistant_model, "num_assistant_tokens"): + warnings.warn( + "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be " + "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", + FutureWarning, + ) + self.num_assistant_tokens = assistant_model.num_assistant_tokens + else: + self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens + + # Prepare the kwargs for the assistant model + assistant_kwargs = copy.deepcopy(model_kwargs) + if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: + inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( + inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs + ) + assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( + inputs_tensor, assistant_kwargs, model_input_name + ) + self.assistant_kwargs = assistant_kwargs + + # Prepare assistant model's keys of inputs + if assistant_model.config.is_encoder_decoder: + # both are encoder-decoder + self.input_ids_key = "decoder_input_ids" + self.attention_key = "decoder_attention_mask" + elif "encoder_outputs" in assistant_kwargs: + # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get( + "decoder_attention_mask", + torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), + ) + else: + # both are decoder-only + self.input_ids_key = "input_ids" + self.attention_key = "attention_mask" + + # Prepare other attributes + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + self.logits_processor = logits_processor + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches 
the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. + """ + # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length + # (which implicitly contains the number of accepted candidates from the previous round) + has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None + if has_past_key_values: + new_cur_len = input_ids.shape[-1] + + new_cache_size = new_cur_len - 1 + self.assistant_kwargs["past_key_values"] = _crop_past_key_values( + self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 + ) # the assistant does not have the token after the last match, hence the -1 + + self.assistant_kwargs = _prepare_attention_mask( + self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder + ) + self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) + + # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` + # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the + # assistant cache to secure strong speedups. + candidate_input_ids = input_ids + for _ in range(int(self.num_assistant_tokens)): + # 2.1 prepare assistant model inputs + assistant_inputs = self.assistant_model.prepare_inputs_for_generation( + candidate_input_ids, + **self.assistant_kwargs, + ) + + # 2.2. check if the input ids length is correct + has_past_key_values = assistant_inputs.get("past_key_values", None) is not None + if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2): + raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") + + # 2.3. use the assistant model to obtain the next candidate logits + assistant_model_outputs = self.assistant_model(**assistant_inputs) + + # 2.4. greedily select the next candidate token + if len(self.logits_processor) > 0: + assistant_model_outputs.logits[:, -1, :] = self.logits_processor( + candidate_input_ids, assistant_model_outputs.logits[:, -1, :] + ) + new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) + candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) + + # 2.5. update assistant model inputs + if self.assistant_kwargs.get(self.attention_key, None) is not None: + mask = self.assistant_kwargs[self.attention_key] + self.assistant_kwargs[self.attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) + self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values + + # 2.6. stop assistant generation on EOS + if self.eos_token_id_tensor is not None: + last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1) + last_assistant_token_is_eos = ( + ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() + ) + if last_assistant_token_is_eos: + break + + return candidate_input_ids + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. 
+ + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, + # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the + # cost of forecasting incorrect assistant tokens. + if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": + if num_matches == int(self.num_assistant_tokens): + self.num_assistant_tokens += 2.0 + else: + self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) + + +def _crop_past_key_values(model, past_key_values, maximum_length): + """Crops the past key values up to a certain maximum length.""" + new_past = [] + if model.config.is_encoder_decoder: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + past_key_values[idx][2], + past_key_values[idx][3], + ) + ) + past_key_values = tuple(new_past) + # bloom is special + elif "bloom" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() + ): + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length], + past_key_values[idx][1][:, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + # gptbigcode is too + elif "gptbigcode" in model.__class__.__name__.lower() or ( + model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() + ): + if model.config.multi_query: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] + else: + for idx in range(len(past_key_values)): + new_past.append( + ( + past_key_values[idx][0][:, :, :maximum_length, :], + past_key_values[idx][1][:, :, :maximum_length, :], + ) + ) + past_key_values = tuple(new_past) + return past_key_values + + +def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: + """Expands or crops the model's mask for decoding purposes, to the defined length""" + + mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" + if mask_key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[mask_key] + mask_length_diff = new_length - mask.shape[1] + + if mask_length_diff < 0: + model_kwargs[mask_key] = mask[:, :mask_length_diff] + elif mask_length_diff > 0: + model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) + return model_kwargs + + +def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" + if "token_type_ids" not in model_kwargs or 
model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + type_length_diff = new_length - token_type_ids.shape[1] + + if type_length_diff < 0: + token_type_ids = token_type_ids[:, :type_length_diff] + elif type_length_diff > 0: + token_type_copies = final_token_type.repeat(1, type_length_diff) + model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) + return model_kwargs diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c7ae4aee7f8d..451de1c33c49 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,7 +37,7 @@ from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from .candidate_generator import ( +from .candidates import ( AssistedCandidateGenerator, CandidateGenerator, _crop_past_key_values, @@ -904,6 +904,7 @@ def _get_candidate_generator( assistant_model: "PreTrainedModel", logits_processor: LogitsProcessorList, model_kwargs: Dict, + eos_token_id: Union[int, List[int]], ) -> CandidateGenerator: """ Returns the candidate generator to be used in `assisted_generation` @@ -915,6 +916,7 @@ def _get_candidate_generator( logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, + eos_token_id=eos_token_id, ) return candidate_generator @@ -1708,6 +1710,7 @@ def generate( assistant_model=assistant_model, logits_processor=logits_processor, model_kwargs=model_kwargs, + eos_token_id=generation_config.eos_token_id, ) # 12. run assisted generate @@ -4426,7 +4429,7 @@ def assisted_decoding( The sequence used as a prompt for the generation. candidate_generator (`CandidateGenerator`, *optional*): A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. + more information, the documentation of [`CandidateGenerator`] should be read. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer. The acceleration is achieved when forecasting candidate tokens with the assistent model @@ -4515,11 +4518,7 @@ def assisted_decoding( if assistant_model is not None: candidate_generator = AssistedCandidateGenerator( - input_ids=input_ids, - assistant_model=assistant_model, - logits_processor=logits_processor, - model_kwargs=model_kwargs, - eos_token_id=eos_token_id, + input_ids, assistant_model, logits_processor, model_kwargs, eos_token_id ) warnings.warn( "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " @@ -4585,7 +4584,7 @@ def assisted_decoding( cur_len = input_ids.shape[-1] # 1. 
Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) + candidate_input_ids = candidate_generator.get_candidates(input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] last_assistant_token_is_eos = ( ~candidate_input_ids[:, -1] From 73e60c0942897073dcc378e5acec62e5de746470 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 09:53:07 +0000 Subject: [PATCH 02/17] fix ci --- src/transformers/generation/candidates.py | 25 ++++++++++++++++------- src/transformers/generation/utils.py | 6 +++++- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py index 9cccc81fed06..573cac3b6a2e 100644 --- a/src/transformers/generation/candidates.py +++ b/src/transformers/generation/candidates.py @@ -14,12 +14,12 @@ # limitations under the License. import copy -import inspect import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Union, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch + if TYPE_CHECKING: from ..modeling_utils import PreTrainedModel from .logits_process import LogitsProcessorList @@ -92,9 +92,8 @@ def __init__( logits_processor: "LogitsProcessorList", model_kwargs: Dict, inputs_tensor: Optional[torch.Tensor] = None, - eos_token_id: Optional[Union[int, List[int]]] = None + eos_token_id: Optional[Union[int, List[int]]] = None, ): - self.assistant_model = assistant_model # Prepare the number of candidate tokens @@ -109,7 +108,15 @@ def __init__( self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens # Prepare the kwargs for the assistant model - assistant_kwargs = copy.deepcopy(model_kwargs) + assistant_kwargs = {} + for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads + if key != "encoder_outputs": + assistant_kwargs[key] = ( + value.clone().detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) + ) + if "encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] + if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs @@ -140,7 +147,9 @@ def __init__( # Prepare other attributes if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - self.eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + self.eos_token_id_tensor = ( + torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None + ) self.logits_processor = logits_processor def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: @@ -200,7 +209,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: # 2.5. update assistant model inputs if self.assistant_kwargs.get(self.attention_key, None) is not None: mask = self.assistant_kwargs[self.attention_key] - self.assistant_kwargs[self.attention_key] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1) + self.assistant_kwargs[self.attention_key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], 1))], dim=-1 + ) self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values # 2.6. 
stop assistant generation on EOS diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 451de1c33c49..a8e9942f55c8 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4518,7 +4518,11 @@ def assisted_decoding( if assistant_model is not None: candidate_generator = AssistedCandidateGenerator( - input_ids, assistant_model, logits_processor, model_kwargs, eos_token_id + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + eos_token_id=eos_token_id, ) warnings.warn( "Passing `assistant_model` to `assisted_decoding` is deprecated and will be removed in v4.38. " From baed703da6a8fec983b799b224776b26de6a6b25 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 10:16:49 +0000 Subject: [PATCH 03/17] more ci --- src/transformers/generation/candidates.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py index 573cac3b6a2e..7cceac3364af 100644 --- a/src/transformers/generation/candidates.py +++ b/src/transformers/generation/candidates.py @@ -110,20 +110,20 @@ def __init__( # Prepare the kwargs for the assistant model assistant_kwargs = {} for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads - if key != "encoder_outputs": - assistant_kwargs[key] = ( - value.clone().detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) - ) - if "encoder_outputs" in model_kwargs: - assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] + if key not in ("encoder_outputs", "assistant_encoder_outputs"): + assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) - if assistant_model.config.is_encoder_decoder and "assistant_encoder_outputs" not in model_kwargs: + if "assistant_encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] + elif assistant_model.config.is_encoder_decoder: inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs ) assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( inputs_tensor, assistant_kwargs, model_input_name ) + elif "encoder_outputs" in model_kwargs: + assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] self.assistant_kwargs = assistant_kwargs # Prepare assistant model's keys of inputs From 474d434589a0042494df2849d008bd60cb6c80c9 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 29 Nov 2023 10:35:13 +0000 Subject: [PATCH 04/17] remove redundant kwarg --- src/transformers/generation/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index a8e9942f55c8..1b1f5559a361 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -904,7 +904,6 @@ def _get_candidate_generator( assistant_model: "PreTrainedModel", logits_processor: LogitsProcessorList, model_kwargs: Dict, - eos_token_id: Union[int, List[int]], ) -> CandidateGenerator: """ Returns the candidate generator to be used in `assisted_generation` @@ -916,7 +915,7 @@ def _get_candidate_generator( logits_processor=logits_processor, model_kwargs=model_kwargs, 
inputs_tensor=inputs_tensor, - eos_token_id=eos_token_id, + eos_token_id=generation_config.eos_token_id, ) return candidate_generator @@ -1710,7 +1709,6 @@ def generate( assistant_model=assistant_model, logits_processor=logits_processor, model_kwargs=model_kwargs, - eos_token_id=generation_config.eos_token_id, ) # 12. run assisted generate From be93529447a5cc0a88b3d949c910205eab703e5e Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Thu, 30 Nov 2023 00:48:37 +0530 Subject: [PATCH 05/17] added and wired up PromptLookupCandidateGenerator --- src/transformers/generation/candidates.py | 91 +++++++++++++++++++ .../generation/configuration_utils.py | 4 + src/transformers/generation/utils.py | 30 ++++-- 3 files changed, 116 insertions(+), 9 deletions(-) diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py index 7cceac3364af..6257359bbaef 100644 --- a/src/transformers/generation/candidates.py +++ b/src/transformers/generation/candidates.py @@ -62,6 +62,97 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F "`update_candidate_strategy`." ) +class PromptLookupCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up + likely continuations in the provided prompt (input_ids) itself. + Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding + + Args: + max_matching_ngram_size (`int`): + The maximum ngram size to be considered for matching in the prompt + num_output_tokens (`int`): + The number of tokens to be output as candidate tokens. + """ + + def __init__( + self, + num_output_tokens: int = 10, + max_matching_ngram_size = 3, + ): + self.num_output_tokens = num_output_tokens + self.max_matching_ngram_size = max_matching_ngram_size + + + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. 
+ """ + input_length = input_ids.size(1) + if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0 or self.max_matching_ngram_size > input_length: + raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") + + chosen_ids = None + match_found = False + for ngram_size in range(self.max_matching_ngram_size, 0, -1): + # Create sliding windows of size ngram_size + windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) + + # Convert ngram to a tensor for comparison + ngram_tensor = input_ids[0, -ngram_size:] + + # Find where the windows match the ngram + matches = (windows == ngram_tensor).all(dim=2) + + # Get the indices of matches + match_indices = matches.nonzero(as_tuple=True)[1] + + # Iterate through match indices to find a valid continuation + for idx in match_indices: + start_idx = idx + ngram_size + end_idx = start_idx + self.num_output_tokens + end_idx = min(end_idx, input_length) + + if start_idx < end_idx: + chosen_ids = input_ids[0, start_idx:end_idx] + match_found = True + break + if match_found: + break + + if chosen_ids == None or len(chosen_ids) == 0: + # Need to make a dummy tensor to avoid errors + chosen_ids = torch.tensor([0], dtype=torch.long, device=input_ids.device) + + # Now need extend input_ids with chosen_ids + chosen_ids = chosen_ids.unsqueeze(0) + candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1) + return candidate_input_ids + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. 
+ """ + # Currently does nothing + return + class AssistedCandidateGenerator(CandidateGenerator): """ diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 4818ca8d97b7..7cb9f2c150db 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -320,6 +320,10 @@ def __init__(self, **kwargs): self.num_assistant_tokens = kwargs.pop("num_assistant_tokens", 5) self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") + # Prompt lookup decoding + self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", 10) + self.prompt_lookup_max_matching_ngram = kwargs.pop("prompt_lookup_max_matching_ngram", 3) + # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1b1f5559a361..537bd86449cd 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -39,6 +39,7 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidates import ( AssistedCandidateGenerator, + PromptLookupCandidateGenerator, CandidateGenerator, _crop_past_key_values, _prepare_attention_mask, @@ -908,15 +909,26 @@ def _get_candidate_generator( """ Returns the candidate generator to be used in `assisted_generation` """ - candidate_generator = AssistedCandidateGenerator( - input_ids=input_ids, - assistant_model=assistant_model, - generation_config=generation_config, - logits_processor=logits_processor, - model_kwargs=model_kwargs, - inputs_tensor=inputs_tensor, - eos_token_id=generation_config.eos_token_id, - ) + # Check if assistant_model is a string + if isinstance(assistant_model, str): + if assistant_model == "prompt_lookup": + candidate_generator = PromptLookupCandidateGenerator( + num_output_tokens=generation_config.prompt_lookup_num_tokens, + max_matching_ngram_size=generation_config.prompt_lookup_max_matching_ngram, + ) + else: + raise NotImplementedError( + f"{assistant_model} is not implemented." 
+ ) + else: + candidate_generator = AssistedCandidateGenerator( + input_ids=input_ids, + assistant_model=assistant_model, + logits_processor=logits_processor, + model_kwargs=model_kwargs, + inputs_tensor=inputs_tensor, + eos_token_id=generation_config.eos_token_id, + ) return candidate_generator def _get_logits_warper( From 2d5a67cae8e5a3ef30e0b8c1c1ba10e9650eb66f Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Mon, 18 Dec 2023 02:48:02 +0530 Subject: [PATCH 06/17] rebased with main, working --- .../generation/candidate_generator.py | 96 ++++ src/transformers/generation/candidates.py | 421 ------------------ .../generation/configuration_utils.py | 3 +- src/transformers/generation/utils.py | 26 +- 4 files changed, 107 insertions(+), 439 deletions(-) delete mode 100644 src/transformers/generation/candidates.py diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index bb82b852f001..ac9534726abb 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -219,6 +219,102 @@ def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.F self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) +class PromptLookupCandidateGenerator(CandidateGenerator): + """ + `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up + likely continuations in the provided prompt (input_ids) itself. + Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding + + Args: + max_matching_ngram_size (`int`): + The maximum ngram size to be considered for matching in the prompt + num_output_tokens (`int`): + The number of tokens to be output as candidate tokens. + """ + + def __init__( + self, + num_output_tokens: int = 10, + max_matching_ngram_size: int = 2, + ): + self.num_output_tokens = num_output_tokens + self.max_matching_ngram_size = max_matching_ngram_size + + if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: + raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") + + + + def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + """ + Fetches the candidates to be tried for the current input. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + + Return: + `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. 
+ """ + input_length = input_ids.size(1) + if input_length < self.max_matching_ngram_size: + raise ValueError("Input length is smaller than max_matching_ngram_size for Prompt Lookup Decoding") + + chosen_ids = None + match_found = False + for ngram_size in range(self.max_matching_ngram_size, 0, -1): + # Create sliding windows of size ngram_size + windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) + + # Convert ngram to a tensor for comparison + ngram_tensor = input_ids[0, -ngram_size:] + + # Find where the windows match the ngram + matches = (windows == ngram_tensor).all(dim=2) + + # Get the indices of matches + match_indices = matches.nonzero(as_tuple=True)[1] + + # Iterate through match indices to find a valid continuation + for idx in match_indices: + start_idx = idx + ngram_size + end_idx = start_idx + self.num_output_tokens + end_idx = min(end_idx, input_length) + + if start_idx < end_idx: + chosen_ids = input_ids[0, start_idx:end_idx] + match_found = True + break + if match_found: + break + + if chosen_ids == None or len(chosen_ids) == 0: + # Need to make a dummy tensor to avoid errors + chosen_ids = torch.tensor([0], dtype=torch.long, device=input_ids.device) + + # Now need extend input_ids with chosen_ids + chosen_ids = chosen_ids.unsqueeze(0) + candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1) + # assisted_generation expects logits as well, but we don't have those here, so returning empty list + return candidate_input_ids, [] + + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): + """ + Updates the candidate generation strategy based on the outcomes. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) + scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): + Prediction scores of a language modeling head. These can be logits for each vocabulary when not using + beam search or log softmax for each vocabulary token when using beam search + num_matches (`int`): + The number of matches between the candidate sequences and the model predictions. + """ + # Currently does nothing + return + + def _crop_past_key_values(model, past_key_values, maximum_length): """Crops the past key values up to a certain maximum length.""" new_past = [] diff --git a/src/transformers/generation/candidates.py b/src/transformers/generation/candidates.py deleted file mode 100644 index 6257359bbaef..000000000000 --- a/src/transformers/generation/candidates.py +++ /dev/null @@ -1,421 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import copy -import warnings -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union - -import torch - - -if TYPE_CHECKING: - from ..modeling_utils import PreTrainedModel - from .logits_process import LogitsProcessorList - - -class CandidateGenerator: - """Abstract base class for all candidate generators that can be applied during assisted generation.""" - - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: - """ - Fetches the candidates to be tried for the current input. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - - Return: - `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be assessed by - the model. - """ - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`." - ) - - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): - """ - Updates the candidate generation strategy based on the outcomes. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): - Prediction scores of a language modeling head. These can be logits for each vocabulary when not using - beam search or log softmax for each vocabulary token when using beam search - num_matches (`int`): - The number of matches between the candidate sequences and the model predictions. - """ - raise NotImplementedError( - f"{self.__class__} is an abstract class. Only classes inheriting this class can call " - "`update_candidate_strategy`." - ) - -class PromptLookupCandidateGenerator(CandidateGenerator): - """ - `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up - likely continuations in the provided prompt (input_ids) itself. - Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding - - Args: - max_matching_ngram_size (`int`): - The maximum ngram size to be considered for matching in the prompt - num_output_tokens (`int`): - The number of tokens to be output as candidate tokens. - """ - - def __init__( - self, - num_output_tokens: int = 10, - max_matching_ngram_size = 3, - ): - self.num_output_tokens = num_output_tokens - self.max_matching_ngram_size = max_matching_ngram_size - - - - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: - """ - Fetches the candidates to be tried for the current input. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - - Return: - `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. 
- """ - input_length = input_ids.size(1) - if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0 or self.max_matching_ngram_size > input_length: - raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") - - chosen_ids = None - match_found = False - for ngram_size in range(self.max_matching_ngram_size, 0, -1): - # Create sliding windows of size ngram_size - windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) - - # Convert ngram to a tensor for comparison - ngram_tensor = input_ids[0, -ngram_size:] - - # Find where the windows match the ngram - matches = (windows == ngram_tensor).all(dim=2) - - # Get the indices of matches - match_indices = matches.nonzero(as_tuple=True)[1] - - # Iterate through match indices to find a valid continuation - for idx in match_indices: - start_idx = idx + ngram_size - end_idx = start_idx + self.num_output_tokens - end_idx = min(end_idx, input_length) - - if start_idx < end_idx: - chosen_ids = input_ids[0, start_idx:end_idx] - match_found = True - break - if match_found: - break - - if chosen_ids == None or len(chosen_ids) == 0: - # Need to make a dummy tensor to avoid errors - chosen_ids = torch.tensor([0], dtype=torch.long, device=input_ids.device) - - # Now need extend input_ids with chosen_ids - chosen_ids = chosen_ids.unsqueeze(0) - candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1) - return candidate_input_ids - - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): - """ - Updates the candidate generation strategy based on the outcomes. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): - Prediction scores of a language modeling head. These can be logits for each vocabulary when not using - beam search or log softmax for each vocabulary token when using beam search - num_matches (`int`): - The number of matches between the candidate sequences and the model predictions. - """ - # Currently does nothing - return - - -class AssistedCandidateGenerator(CandidateGenerator): - """ - `CandidateGenerator` class to be used for assisted generation. This class generates candidates through the use of - a smaller model. Read the following blog post for more information: https://huggingface.co/blog/assisted-generation - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - assistant_model (`PreTrainedModel`): - The model to be used for generating candidates. This model should be smaller than the main model. - logits_processor (`LogitsProcessorList`): - An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`] - used to modify the prediction scores of the language modeling head applied at each generation step. - model_kwargs (`Dict`): - The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant - model as well. - inputs_tensor (`torch.Tensor`, *optional*): - The model input tensor. In encoder-decoder models, this is the encoder input. - eos_token_id (`Union[int, List[int]]`, *optional*): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. 
- """ - - def __init__( - self, - input_ids: torch.LongTensor, - assistant_model: "PreTrainedModel", - logits_processor: "LogitsProcessorList", - model_kwargs: Dict, - inputs_tensor: Optional[torch.Tensor] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - ): - self.assistant_model = assistant_model - - # Prepare the number of candidate tokens - if hasattr(assistant_model, "num_assistant_tokens"): - warnings.warn( - "Setting `num_assistant_tokens` via `assistant_model.num_assistant_tokens` is deprecated and will be " - "removed in v4.37. Make sure to set `num_assistant_tokens` via the generation_config instead.", - FutureWarning, - ) - self.num_assistant_tokens = assistant_model.num_assistant_tokens - else: - self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens - - # Prepare the kwargs for the assistant model - assistant_kwargs = {} - for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads - if key not in ("encoder_outputs", "assistant_encoder_outputs"): - assistant_kwargs[key] = value.detach() if isinstance(value, torch.Tensor) else copy.deepcopy(value) - - if "assistant_encoder_outputs" in model_kwargs: - assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"] - elif assistant_model.config.is_encoder_decoder: - inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs( - inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs - ) - assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation( - inputs_tensor, assistant_kwargs, model_input_name - ) - elif "encoder_outputs" in model_kwargs: - assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"] - self.assistant_kwargs = assistant_kwargs - - # Prepare assistant model's keys of inputs - if assistant_model.config.is_encoder_decoder: - # both are encoder-decoder - self.input_ids_key = "decoder_input_ids" - self.attention_key = "decoder_attention_mask" - elif "encoder_outputs" in assistant_kwargs: - # special case for encoder-decoder with decoder-only assistant (like DistilWhisper) - self.input_ids_key = "input_ids" - self.attention_key = "attention_mask" - self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get( - "decoder_attention_mask", - torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long), - ) - else: - # both are decoder-only - self.input_ids_key = "input_ids" - self.attention_key = "attention_mask" - - # Prepare other attributes - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None - ) - self.logits_processor = logits_processor - - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: - """ - Fetches the candidates to be tried for the current input. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - - Return: - `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. - """ - # 1. 
If it is not the first round of candidate generation, prepare the inputs based on the input_ids length - # (which implicitly contains the number of accepted candidates from the previous round) - has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None - if has_past_key_values: - new_cur_len = input_ids.shape[-1] - - new_cache_size = new_cur_len - 1 - self.assistant_kwargs["past_key_values"] = _crop_past_key_values( - self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1 - ) # the assistant does not have the token after the last match, hence the -1 - - self.assistant_kwargs = _prepare_attention_mask( - self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder - ) - self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len) - - # 2. Forecast next N tokens using the assistant model. This `for` block can be replaced with a `.generate()` - # call if we decide to add `past_key_values` as a possible output of generate, as we need access to the - # assistant cache to secure strong speedups. - candidate_input_ids = input_ids - for _ in range(int(self.num_assistant_tokens)): - # 2.1 prepare assistant model inputs - assistant_inputs = self.assistant_model.prepare_inputs_for_generation( - candidate_input_ids, - **self.assistant_kwargs, - ) - - # 2.2. check if the input ids length is correct - has_past_key_values = assistant_inputs.get("past_key_values", None) is not None - if has_past_key_values and assistant_inputs[self.input_ids_key].shape[-1] not in (1, 2): - raise ValueError("The length of the input ids in assistant inputs should be 1 or 2") - - # 2.3. use the assistant model to obtain the next candidate logits - assistant_model_outputs = self.assistant_model(**assistant_inputs) - - # 2.4. greedily select the next candidate token - if len(self.logits_processor) > 0: - assistant_model_outputs.logits[:, -1, :] = self.logits_processor( - candidate_input_ids, assistant_model_outputs.logits[:, -1, :] - ) - new_token = assistant_model_outputs.logits[:, -1, :].argmax(dim=-1) - candidate_input_ids = torch.cat((candidate_input_ids, new_token[:, None]), dim=-1) - - # 2.5. update assistant model inputs - if self.assistant_kwargs.get(self.attention_key, None) is not None: - mask = self.assistant_kwargs[self.attention_key] - self.assistant_kwargs[self.attention_key] = torch.cat( - [mask, mask.new_ones((mask.shape[0], 1))], dim=-1 - ) - self.assistant_kwargs["past_key_values"] = assistant_model_outputs.past_key_values - - # 2.6. stop assistant generation on EOS - if self.eos_token_id_tensor is not None: - last_assistant_token_is_eos = new_token.tile(self.eos_token_id_tensor.shape[0], 1) - last_assistant_token_is_eos = ( - ~last_assistant_token_is_eos.ne(self.eos_token_id_tensor.unsqueeze(1)).prod(dim=0).bool() - ) - if last_assistant_token_is_eos: - break - - return candidate_input_ids - - def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): - """ - Updates the candidate generation strategy based on the outcomes. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): - Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`): - Prediction scores of a language modeling head. 
These can be logits for each vocabulary when not using - beam search or log softmax for each vocabulary token when using beam search - num_matches (`int`): - The number of matches between the candidate sequences and the model predictions. - """ - # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic, - # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the - # cost of forecasting incorrect assistant tokens. - if self.assistant_model.generation_config.num_assistant_tokens_schedule == "heuristic": - if num_matches == int(self.num_assistant_tokens): - self.num_assistant_tokens += 2.0 - else: - self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0) - - -def _crop_past_key_values(model, past_key_values, maximum_length): - """Crops the past key values up to a certain maximum length.""" - new_past = [] - if model.config.is_encoder_decoder: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - past_key_values[idx][2], - past_key_values[idx][3], - ) - ) - past_key_values = tuple(new_past) - # bloom is special - elif "bloom" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "bloom" in model.config.architectures[0].lower() - ): - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length], - past_key_values[idx][1][:, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - # gptbigcode is too - elif "gptbigcode" in model.__class__.__name__.lower() or ( - model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower() - ): - if model.config.multi_query: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :] - else: - for idx in range(len(past_key_values)): - new_past.append( - ( - past_key_values[idx][0][:, :, :maximum_length, :], - past_key_values[idx][1][:, :, :maximum_length, :], - ) - ) - past_key_values = tuple(new_past) - return past_key_values - - -def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]: - """Expands or crops the model's mask for decoding purposes, to the defined length""" - - mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask" - if mask_key not in model_kwargs: - return model_kwargs - - mask = model_kwargs[mask_key] - mask_length_diff = new_length - mask.shape[1] - - if mask_length_diff < 0: - model_kwargs[mask_key] = mask[:, :mask_length_diff] - elif mask_length_diff > 0: - model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1) - return model_kwargs - - -def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: - """Expands or crops the model's token_type_ids for decoding purposes, to the defined length""" - if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: - return model_kwargs - - token_type_ids = model_kwargs["token_type_ids"] - final_token_type = token_type_ids[:, -1].unsqueeze(-1) - type_length_diff = new_length - token_type_ids.shape[1] - - if type_length_diff < 0: - token_type_ids = token_type_ids[:, :type_length_diff] - elif 
type_length_diff > 0: - token_type_copies = final_token_type.repeat(1, type_length_diff) - model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1) - return model_kwargs diff --git a/src/transformers/generation/configuration_utils.py b/src/transformers/generation/configuration_utils.py index 7cb9f2c150db..71ee0e60fcd1 100644 --- a/src/transformers/generation/configuration_utils.py +++ b/src/transformers/generation/configuration_utils.py @@ -321,8 +321,7 @@ def __init__(self, **kwargs): self.num_assistant_tokens_schedule = kwargs.pop("num_assistant_tokens_schedule", "heuristic") # Prompt lookup decoding - self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", 10) - self.prompt_lookup_max_matching_ngram = kwargs.pop("prompt_lookup_max_matching_ngram", 3) + self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None) # Wild card self.generation_kwargs = kwargs.pop("generation_kwargs", {}) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 537bd86449cd..c8fb7e0994ea 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -37,7 +37,7 @@ from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer -from .candidates import ( +from .candidate_generator import ( AssistedCandidateGenerator, PromptLookupCandidateGenerator, CandidateGenerator, @@ -909,17 +909,11 @@ def _get_candidate_generator( """ Returns the candidate generator to be used in `assisted_generation` """ - # Check if assistant_model is a string - if isinstance(assistant_model, str): - if assistant_model == "prompt_lookup": - candidate_generator = PromptLookupCandidateGenerator( - num_output_tokens=generation_config.prompt_lookup_num_tokens, - max_matching_ngram_size=generation_config.prompt_lookup_max_matching_ngram, - ) - else: - raise NotImplementedError( - f"{assistant_model} is not implemented." - ) + if generation_config.prompt_lookup_num_tokens is not None: + print("Using PromptLookupCandidateGenerator") + candidate_generator = PromptLookupCandidateGenerator( + num_output_tokens=generation_config.prompt_lookup_num_tokens, + ) else: candidate_generator = AssistedCandidateGenerator( input_ids=input_ids, @@ -1008,7 +1002,7 @@ def _get_generation_mode( generation_mode = GenerationMode.BEAM_SEARCH # Assisted generation may extend some generation modes - if assistant_model is not None: + if assistant_model is not None or generation_config.prompt_lookup_num_tokens is not None: if generation_mode in ("greedy_search", "sample"): generation_mode = GenerationMode.ASSISTED_GENERATION else: @@ -4439,7 +4433,7 @@ def assisted_decoding( The sequence used as a prompt for the generation. candidate_generator (`CandidateGenerator`, *optional*): A derived instance of [`CandidateGenerator`] that defines how candidate sequences are generated. For - more information, the documentation of [`CandidateGenerator`] should be read. + more information, the documentation of [`CandidateGenerator`] should be read. Only one of `assistant_model` or `candidate_generator` should be passed as input to this function. assistant_model (`PreTrainedModel`, *optional*): An assistant model that can be used to accelerate generation. The assistant model must have the exact same tokenizer. 
The acceleration is achieved when forecasting candidate tokens with the assistent model @@ -4598,7 +4592,7 @@ def assisted_decoding( cur_len = input_ids.shape[-1] # 1. Fetch candidate sequences from a `CandidateGenerator` - candidate_input_ids = candidate_generator.get_candidates(input_ids) + candidate_input_ids, candidate_logits = candidate_generator.get_candidates(input_ids) candidate_length = candidate_input_ids.shape[1] - input_ids.shape[1] last_assistant_token_is_eos = ( ~candidate_input_ids[:, -1] @@ -4917,4 +4911,4 @@ def _ranking_fast( contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] - return selected_idx + return selected_idx \ No newline at end of file From beb95ba256883db15704612e0622b8514042c1fb Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Mon, 18 Dec 2023 02:51:59 +0530 Subject: [PATCH 07/17] removed print --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c8fb7e0994ea..97f17f7356b7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -910,7 +910,6 @@ def _get_candidate_generator( Returns the candidate generator to be used in `assisted_generation` """ if generation_config.prompt_lookup_num_tokens is not None: - print("Using PromptLookupCandidateGenerator") candidate_generator = PromptLookupCandidateGenerator( num_output_tokens=generation_config.prompt_lookup_num_tokens, ) From 22ef5b25bd308ab0c97d2eddd3f7c1c629bcf363 Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Mon, 1 Jan 2024 20:44:16 +0530 Subject: [PATCH 08/17] style fixes --- .../generation/candidate_generator.py | 16 +++++++--------- src/transformers/generation/utils.py | 4 ++-- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index ac9534726abb..41f55bf871aa 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -242,24 +242,22 @@ def __init__( if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") - - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: """ Fetches the candidates to be tried for the current input. - + Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids) - + Return: `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. 
""" input_length = input_ids.size(1) if input_length < self.max_matching_ngram_size: raise ValueError("Input length is smaller than max_matching_ngram_size for Prompt Lookup Decoding") - + chosen_ids = None match_found = False for ngram_size in range(self.max_matching_ngram_size, 0, -1): @@ -288,16 +286,16 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: if match_found: break - if chosen_ids == None or len(chosen_ids) == 0: + if chosen_ids is None or len(chosen_ids) == 0: # Need to make a dummy tensor to avoid errors chosen_ids = torch.tensor([0], dtype=torch.long, device=input_ids.device) - + # Now need extend input_ids with chosen_ids chosen_ids = chosen_ids.unsqueeze(0) candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1) # assisted_generation expects logits as well, but we don't have those here, so returning empty list - return candidate_input_ids, [] - + return candidate_input_ids, [] + def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): """ Updates the candidate generation strategy based on the outcomes. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 97f17f7356b7..5babc5aec0d4 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -39,8 +39,8 @@ from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer from .candidate_generator import ( AssistedCandidateGenerator, - PromptLookupCandidateGenerator, CandidateGenerator, + PromptLookupCandidateGenerator, _crop_past_key_values, _prepare_attention_mask, _prepare_token_type_ids, @@ -4910,4 +4910,4 @@ def _ranking_fast( contrastive_score = (1.0 - alpha) * next_top_k_probs - alpha * degeneration_penalty contrastive_score = torch.stack(torch.split(contrastive_score, beam_width)) # [B, K] _, selected_idx = contrastive_score.max(dim=-1) # [B] - return selected_idx \ No newline at end of file + return selected_idx From 8147955278e5f3e380acc7edb426db515722afbb Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Mon, 1 Jan 2024 20:58:06 +0530 Subject: [PATCH 09/17] fix test --- src/transformers/generation/utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5babc5aec0d4..c58ff9533baa 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -920,7 +920,6 @@ def _get_candidate_generator( logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, - eos_token_id=generation_config.eos_token_id, ) return candidate_generator From 8b16de071d01a03682e0e3ccedf3011967cedc37 Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Mon, 1 Jan 2024 21:23:11 +0530 Subject: [PATCH 10/17] fixed tests --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c58ff9533baa..facd30f1c469 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -917,6 +917,7 @@ def _get_candidate_generator( candidate_generator = AssistedCandidateGenerator( input_ids=input_ids, assistant_model=assistant_model, + generation_config=generation_config, logits_processor=logits_processor, model_kwargs=model_kwargs, inputs_tensor=inputs_tensor, From d0ab6d08d20531f79a8dea1d1807f4bfd1df77ad Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Tue, 2 Jan 2024 00:04:15 +0530 Subject: [PATCH 
11/17] added test for prompt lookup decoding --- tests/generation/test_utils.py | 60 ++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 973f54f00397..d6ef9ddca85d 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1569,6 +1569,66 @@ def test_assisted_decoding_matches_greedy_search(self): for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) + @is_flaky() + def test_prompt_lookup_decoding_matches_greedy_search(self): + # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. + # This test is mostly a copy of test_assisted_decoding_matches_greedy_search + + for model_class in self.all_generative_model_classes: + if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]): + self.skipTest("Won't fix: old model with different cache format") + if any( + model_name in model_class.__name__.lower() + for model_name in [ + "bigbirdpegasus", + "led", + "mega", + "speech2text", + "git", + "prophetnet", + "seamlessm4t", + "clvp", + ] + ): + self.skipTest("May fix in the future: need model-specific fixes") + + # enable cache + config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1) + + # NOTE: assisted generation only works with cache on at the moment. + if not hasattr(config, "use_cache"): + self.skipTest("This model doesn't support caching") + + config.use_cache = True + config.is_decoder = True + model = model_class(config).to(torch_device).eval() + # Sets assisted generation arguments such that: + # a) no EOS is generated, to ensure generation doesn't break early + # b) the prompt lookup tries to give the model 2 tokens, to ensure the input preparation of + # prompt lookup is correct + # c) there are at least two forward passes in the main model, to ensure the input preparation of + # the main model is correct + generation_kwargs = { + "eos_token_id": -1, # see a) + "max_new_tokens": 4, # see c) + "num_beams": 1, + "do_sample": False, + "output_scores": True, + "output_hidden_states": True, + "output_attentions": True, + "return_dict_in_generate": True, + } + + output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b) + output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) + + # The two outputs must match and their shape must be as expected + self.assertListEqual(output_greedy.sequences.tolist(), output_prompt_lookup.sequences.tolist()) + for output in (output_greedy, output_prompt_lookup): + self._check_outputs(output, input_ids, model.config, use_cache=True) + def test_assisted_decoding_sample(self): # In this test we don't check assisted vs non-assisted output -- seeded assisted decoding with sample will not # match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with From d9264866c1dc89d378432ba9ec53f1f982394f67 Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Tue, 2 Jan 2024 00:09:41 +0530 Subject: [PATCH 12/17] fixed circleci --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d6ef9ddca85d..bfae5e882778 100644 --- a/tests/generation/test_utils.py +++ 
b/tests/generation/test_utils.py @@ -1569,7 +1569,7 @@ def test_assisted_decoding_matches_greedy_search(self): for output in (output_greedy, output_assisted): self._check_outputs(output, input_ids, model.config, use_cache=True) - @is_flaky() + @is_flaky() def test_prompt_lookup_decoding_matches_greedy_search(self): # This test ensures that the prompt lookup generation does not introduce output changes over greedy search. # This test is mostly a copy of test_assisted_decoding_matches_greedy_search @@ -1621,7 +1621,7 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) - generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b) + generation_kwargs.update({"prompt_lookup_num_tokens": 2}) # see b) output_prompt_lookup = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs) # The two outputs must match and their shape must be as expected From 371fce1fe04f52db84c4ca3f6bc054b1ae1092f7 Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Tue, 2 Jan 2024 00:37:51 +0530 Subject: [PATCH 13/17] fixed test issue --- src/transformers/generation/candidate_generator.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 41f55bf871aa..11059167f716 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -255,12 +255,10 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried. """ input_length = input_ids.size(1) - if input_length < self.max_matching_ngram_size: - raise ValueError("Input length is smaller than max_matching_ngram_size for Prompt Lookup Decoding") chosen_ids = None match_found = False - for ngram_size in range(self.max_matching_ngram_size, 0, -1): + for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1): # Create sliding windows of size ngram_size windows = input_ids.unfold(dimension=1, size=ngram_size, step=1) From 942141237fc9d58134b1ddd14ca490ab2e1feacb Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Wed, 10 Jan 2024 11:34:35 +0530 Subject: [PATCH 14/17] Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante --- src/transformers/generation/candidate_generator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 11059167f716..1e300071a4d6 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -291,8 +291,8 @@ def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: # Now need extend input_ids with chosen_ids chosen_ids = chosen_ids.unsqueeze(0) candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1) - # assisted_generation expects logits as well, but we don't have those here, so returning empty list - return candidate_input_ids, [] + # assisted_generation expects logits as well, but we don't have those here, so returning None + return candidate_input_ids, None def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int): """ From deaf783b86bccb5b9b7b47ede7f18c3ef2348952 Mon Sep 17 00:00:00 2001 From: Apoorv Saxena 
Date: Wed, 10 Jan 2024 11:34:44 +0530 Subject: [PATCH 15/17] Update src/transformers/generation/candidate_generator.py Co-authored-by: Joao Gante --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 1e300071a4d6..93456df5006a 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -243,7 +243,7 @@ def __init__( if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") - def get_candidates(self, input_ids: torch.LongTensor) -> torch.LongTensor: + def get_candidates(self, input_ids: torch.LongTensor) -> -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ Fetches the candidates to be tried for the current input. From 3980b3a5b3077090ac08109e58392fb0712270a1 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Wed, 10 Jan 2024 09:17:31 +0000 Subject: [PATCH 16/17] Update src/transformers/generation/candidate_generator.py --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 93456df5006a..10de749d82c9 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -243,7 +243,7 @@ def __init__( if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0: raise ValueError("Invalid max_matching_ngram_size or num_output_tokens") - def get_candidates(self, input_ids: torch.LongTensor) -> -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: + def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]: """ Fetches the candidates to be tried for the current input. From fd3d6c41f9d68afb801a8fad6a0b22d10106357b Mon Sep 17 00:00:00 2001 From: Apoorv Saxena Date: Sat, 13 Jan 2024 00:01:11 +0530 Subject: [PATCH 17/17] Update src/transformers/generation/candidate_generator.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- src/transformers/generation/candidate_generator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 10de749d82c9..ca83f460a5f8 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -286,7 +286,7 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, if chosen_ids is None or len(chosen_ids) == 0: # Need to make a dummy tensor to avoid errors - chosen_ids = torch.tensor([0], dtype=torch.long, device=input_ids.device) + chosen_ids = torch.zeros((1), dtype=torch.long, device=input_ids.device) # Now need extend input_ids with chosen_ids chosen_ids = chosen_ids.unsqueeze(0)
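
Note on the mechanism: the PromptLookupCandidateGenerator introduced in this series implements the core idea of prompt lookup decoding — find the most recent n-gram of the input that also occurs earlier in the prompt, and propose the tokens that followed that earlier occurrence as draft candidates for assisted decoding. Below is a minimal plain-Python sketch of that matching step, assuming the token ids are given as a list of ints; the actual diff vectorizes the same search with `input_ids.unfold(dimension=1, ...)` on a batch of size 1. The function name and default values here are illustrative only, not part of the library API.

def find_candidate_ids(token_ids, max_ngram_size=3, num_output_tokens=10):
    """Return draft tokens by matching the trailing n-gram of `token_ids`
    against earlier positions in the same sequence."""
    seq_len = len(token_ids)
    # Try the longest n-gram first, falling back to shorter ones.
    for ngram_size in range(min(max_ngram_size, seq_len - 1), 0, -1):
        tail = token_ids[seq_len - ngram_size:]
        # Scan earlier windows of the same size for an exact match
        # (the final window, i.e. the tail itself, is excluded).
        for start in range(0, seq_len - ngram_size):
            if token_ids[start:start + ngram_size] == tail:
                begin = start + ngram_size
                end = min(begin + num_output_tokens, seq_len)
                # Propose the continuation that followed the earlier occurrence.
                return token_ids[begin:end]
    # No earlier occurrence found; the generator above falls back to a single dummy token.
    return []

# Example: find_candidate_ids([5, 6, 7, 8, 5, 6]) returns [7, 8, 5, 6],
# because the trailing bigram "5, 6" also appears at the start of the prompt.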
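
As exercised by the new test above, prompt lookup decoding is enabled end-to-end by passing `prompt_lookup_num_tokens` through the generation kwargs, with no assistant model involved. A minimal usage sketch under that assumption follows; the checkpoint name is a placeholder, and any decoder-only causal LM with caching enabled should behave the same way.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("The quick brown fox jumps over the lazy dog. The quick brown", return_tensors="pt")
# Greedy decoding; candidate tokens are looked up in the prompt itself instead of
# being drafted by a smaller assistant model.
outputs = model.generate(**inputs, max_new_tokens=20, prompt_lookup_num_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Because candidates are copied from the prompt, the speed-up is largest on input-grounded tasks (summarization, code editing, document QA) where the output repeats long spans of the input; on free-form generation the n-gram matches are rarer and the behavior degrades gracefully to ordinary greedy search.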