From e3c2849fadf0cb482e6d4dd9f12ed0b1699f661c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 23 Sep 2024 11:55:14 -0600 Subject: [PATCH 01/16] feat(models): Add models.json blocks for Granite Code 3b and 8b Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/model_config/models.json | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json index 2d3dfcbeb..dfa29b6bd 100644 --- a/torchchat/model_config/models.json +++ b/torchchat/model_config/models.json @@ -164,5 +164,17 @@ "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model" ], "checkpoint_file": "stories110M.pt" + }, + "ibm-granite/granite-3b-code-instruct-128k": { + "aliases": ["granite-code", "granite-code-3b"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "ibm-granite/granite-3b-code-instruct-128k", + "transformer_params_key": "3B" + }, + "ibm-granite/granite-8b-code-instruct-128k": { + "aliases": ["granite-code-8b"], + "distribution_channel": "HuggingFaceSnapshot", + "distribution_path": "ibm-granite/granite-8b-code-instruct-128k", + "transformer_params_key": "8B" } } From 9d80e523868a83b4cce909efb3b7ed0592d6ec65 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 25 Sep 2024 14:02:06 -0600 Subject: [PATCH 02/16] feat: Initial model params for granite code 3b Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/model_params/Granite-3B-Code.json | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 torchchat/model_params/Granite-3B-Code.json diff --git a/torchchat/model_params/Granite-3B-Code.json b/torchchat/model_params/Granite-3B-Code.json new file mode 100644 index 000000000..4af6ef42c --- /dev/null +++ b/torchchat/model_params/Granite-3B-Code.json @@ -0,0 +1,16 @@ +{ + "block_size": 128000, + "dim": 2560, + "hidden_dim": 10240, + "n_heads": 32, + "n_local_heads": 32, + "n_layers": 32, + "rope_base": 10000000, + "vocab_size": 49152, + "use_tokenizers": true, + "norm_eps": 0.00001, + "rope_scaling": null, + "attention_bias": true, + "feed_forward_bias": true, + "tie_word_embeddings": true +} \ No newline at end of file From 85057bc05d08e8bb101a567d55323588b63fc8eb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 31 Oct 2024 11:05:55 -0600 Subject: [PATCH 03/16] fix(model config): Fix model configs for Granite Code * Use the right tokenizer_file name * Use the right transformer_params_key based on the file name in model_params * Use the updated name to indicate HF tokenizers Signed-off-by: Gabe Goodhart --- torchchat/model_config/models.json | 6 ++++-- torchchat/model_params/Granite-3B-Code.json | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json index dfa29b6bd..8791601fb 100644 --- a/torchchat/model_config/models.json +++ b/torchchat/model_config/models.json @@ -169,12 +169,14 @@ "aliases": ["granite-code", "granite-code-3b"], "distribution_channel": "HuggingFaceSnapshot", "distribution_path": "ibm-granite/granite-3b-code-instruct-128k", - "transformer_params_key": "3B" + "transformer_params_key": "Granite-3B-Code", + "tokenizer_file": "tokenizer.json" }, "ibm-granite/granite-8b-code-instruct-128k": { "aliases": ["granite-code-8b"], "distribution_channel": "HuggingFaceSnapshot", "distribution_path": "ibm-granite/granite-8b-code-instruct-128k", - "transformer_params_key": "8B" + "transformer_params_key": "Granite-8B-Code", + "tokenizer_file": "tokenizer.json" } } diff --git a/torchchat/model_params/Granite-3B-Code.json b/torchchat/model_params/Granite-3B-Code.json index 4af6ef42c..3ec1c615a 100644 --- a/torchchat/model_params/Granite-3B-Code.json +++ b/torchchat/model_params/Granite-3B-Code.json @@ -7,7 +7,7 @@ "n_layers": 32, "rope_base": 10000000, "vocab_size": 49152, - "use_tokenizers": true, + "use_hf_tokenizer": true, "norm_eps": 0.00001, "rope_scaling": null, "attention_bias": true, From 1e3addc73c437d8e7ee2634b1f90df2281164cc8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 31 Oct 2024 12:41:28 -0600 Subject: [PATCH 04/16] feat(granite): Add model params for granite-code-8b Something isn't quite working with this model yet, but the config should be accurate at this point. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/model_params/Granite-8B-Code.json | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) create mode 100644 torchchat/model_params/Granite-8B-Code.json diff --git a/torchchat/model_params/Granite-8B-Code.json b/torchchat/model_params/Granite-8B-Code.json new file mode 100644 index 000000000..71a7d9201 --- /dev/null +++ b/torchchat/model_params/Granite-8B-Code.json @@ -0,0 +1,16 @@ +{ + "block_size": 128000, + "dim": 4096, + "hidden_dim": 14336, + "n_heads": 32, + "n_local_heads": 8, + "n_layers": 36, + "rope_base": 10000000, + "vocab_size": 49152, + "use_hf_tokenizer": true, + "norm_eps": 0.00001, + "rope_scaling": null, + "attention_bias": true, + "feed_forward_bias": true, + "tie_word_embeddings": true +} \ No newline at end of file From 5d342fa2ae45d875c7dbf9c646f0ec70680c66c8 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 6 Nov 2024 17:19:39 -0700 Subject: [PATCH 05/16] fix(deps): Add tokenizers to the deps explicitly It was implicitly being pulled in via lm_eval -> transformers, but it's better to have it explicit since we use it directly Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- install/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/install/requirements.txt b/install/requirements.txt index 8fb1832ba..e1bb02f8f 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -9,6 +9,9 @@ gguf # Tiktoken tokenizer for Llama 3 and other advanced models tiktoken +# Tokenizers for other non-llama models that use HF tokenizers +tokenizers + # Miscellaneous snakeviz sentencepiece From dadff616d993230266a460ae40b288e025e65dfe Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 6 Nov 2024 17:21:01 -0700 Subject: [PATCH 06/16] feat(tokenizer): Add basic support for jinja2 template rendering for HF tokenizers This is a much simplified version of the corresponding logic in transformers. I opted for this so that the full transformers dependency is not added here. CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1522 Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- tokenizer/hf_tokenizer.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/tokenizer/hf_tokenizer.py b/tokenizer/hf_tokenizer.py index 7ad5807d1..d10ecb076 100644 --- a/tokenizer/hf_tokenizer.py +++ b/tokenizer/hf_tokenizer.py @@ -5,11 +5,12 @@ # LICENSE file in the root directory of this source tree. # Standard -from typing import List, Optional +from typing import Dict, List, Optional import json import os # Third Party +import jinja2 from tokenizers import Tokenizer # Local @@ -37,6 +38,9 @@ def __init__(self, file_path: str): # Load the tokenizer itself self._tokenizer = Tokenizer.from_file(tokenizer_path) + # Load the chat template if we have a config path + self._chat_template: Optional[jinja2.Template] = None + # If available, parse bos/eos tokens from the tokenizer config self._bos_id, self._eos_id = None, None if tokenizer_config_path is not None: @@ -48,6 +52,8 @@ def __init__(self, file_path: str): self._bos_id = self._tokenizer.token_to_id(bos_token) if eos_token is not None: self._eos_id = self._tokenizer.token_to_id(eos_token) + if chat_template_str := tok_config.get("chat_template"): + self._chat_template = jinja2.Template(chat_template_str) # If no eos/bos tokens found, go looking for them! if None in [self._bos_id, self._eos_id]: @@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio if len(candidate_toks) == 1: return candidate_toks[0]["id"] + ## Interface ## + def encode( self, s: str, @@ -90,3 +98,21 @@ def bos_id(self) -> int: def eos_id(self) -> int: return self._eos_id + + ## Additional Public Methods ## + + def has_chat_template(self) -> bool: + return bool(self._chat_template) + + def apply_chat_template( + self, + dialog: List[Dict[str, str]], + add_generation_prompt: bool = False, + ) -> str: + """If configured with a chat template, apply it to the list of messages + """ + if not self._chat_template: + raise ValueError("No chat template configured!") + return self._chat_template.render( + messages=dialog, add_generation_prompt=add_generation_prompt + ) From 2017e5d3a64d79da87e35a22bfde9b250bfe2a0e Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 6 Nov 2024 17:22:30 -0700 Subject: [PATCH 07/16] fix(chat): Add HFTokenizerChatFormatter and use it for HF tokenizers This will allow the jinja2 templates for HF tokenizers to be applied without needing to hard-code the formatter logic. This will likely need to be duplicated in the embedded code version of chat. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/torchchat/generate.py b/torchchat/generate.py index 9b4c6430a..807633b36 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -125,6 +125,15 @@ def encode_dialog_prompt(self, dialog) -> List[int]: return tokens +class HFTokenizerChatFormatter(_ChatFormatter): + """Chat formatter that uses the built-in formatting capabilities of an HF + tokenizer instance + """ + def encode_dialog_prompt(self, dialog) -> List[int]: + rendered = self.tokenizer.apply_chat_template(dialog, add_generation_prompt=True) + return self.tokenizer.encode(rendered) + + @dataclass class GeneratorArgs: prompt: Optional[str] = ( @@ -286,6 +295,10 @@ def __init__( logging.debug( "Llama3 model detected in chat mode. Using updated sentence schemas" ) + elif self.tokenizer_args.is_hf_tokenizer: + if not self.tokenizer.has_chat_template(): + raise ValueError("Tokenizer must have a chat template") + self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer) else: self.chat_formatter = Llama2ChatFormatter(self.tokenizer) From 38a649af3d361028df554cc612149a30284d8e00 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 6 Nov 2024 17:31:15 -0700 Subject: [PATCH 08/16] fix(deps): Add jinja2 as an explicit dep It was getting pulled in implicitly via flask and lm_eval -> transformers, but better to have it explicit. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- install/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/install/requirements.txt b/install/requirements.txt index e1bb02f8f..457131275 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -9,8 +9,9 @@ gguf # Tiktoken tokenizer for Llama 3 and other advanced models tiktoken -# Tokenizers for other non-llama models that use HF tokenizers +# Tokenizers and jinja2 for other non-llama models that use HF tokenizers tokenizers +jinja2 # Miscellaneous snakeviz From 0b4f159e78e187d751026d5b15319a90b6be7631 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 20 Nov 2024 13:19:26 -0700 Subject: [PATCH 09/16] feat(log): Add env-based LOG_LEVEL config to CLI Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/cli/cli.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py index a7f7bbba2..91bdcaf26 100644 --- a/torchchat/cli/cli.py +++ b/torchchat/cli/cli.py @@ -17,7 +17,15 @@ allowable_params_table, ) -logging.basicConfig(level=logging.INFO, format="%(message)s") +_log_level_env = os.getenv("LOG_LEVEL", "INFO") +try: + _log_level = getattr(logging, _log_level_env.upper()) +except AttributeError: + print(f"Invalid log level: {_log_level_env}", file=sys.stderr) + _log_level = logging.INFO + + +logging.basicConfig(level=_log_level, format="%(message)s") logger = logging.getLogger(__name__) default_device = os.getenv("TORCHCHAT_DEVICE", "fast") From 526ce15416c16033827198fbf24f477bb4d75452 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 20 Nov 2024 13:21:01 -0700 Subject: [PATCH 10/16] feat(log): Add better logging in model and generate In generate, there were a number of commented-out log lines. These are safe to leave in as long as lazy string interpolation is used. Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 22 ++++++++++++++-------- torchchat/model.py | 7 ++++++- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 807633b36..6f2e7b062 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -45,6 +45,9 @@ from torchchat.utils.device_info import get_device_info +logger = logging.getLogger(__name__) + + class _ChatFormatter(ABC): def __init__(self, tokenizer): self.tokenizer = tokenizer @@ -292,7 +295,7 @@ def __init__( if self.is_llama3_model: self.chat_formatter = Llama3ChatFormatter(self.tokenizer) if generator_args.chat_mode: - logging.debug( + logger.debug( "Llama3 model detected in chat mode. Using updated sentence schemas" ) elif self.tokenizer_args.is_hf_tokenizer: @@ -354,10 +357,12 @@ def sample( temperature: float = 0, top_k: Optional[int] = None, ): + logits = logits[0, -1] + logger.debug("Logits: %s", logits) if temperature == 0 and not need_probs: - _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1) + _, idx_next = torch.topk(logits, k=1, dim=-1) return (idx_next, None) - probs = self.logits_to_probs(logits[0, -1], temperature, top_k) + probs = self.logits_to_probs(logits, temperature, top_k) idx_next = self.multinomial_sample_one_no_sync(probs) return idx_next, probs @@ -371,7 +376,7 @@ def prefill( sequential_prefill=True, **sampling_kwargs, ) -> torch.Tensor: - # logging.debug(f"x: {x}, input_pos: {input_pos}") + logger.debug("x: %s, input_pos: %s", x, input_pos) width = x.size(1) assert input_pos.size(0) == width @@ -407,7 +412,7 @@ def prefill( elif sequential_prefill: for i in range(width): x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1) - # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}") + logger.debug(" x: %s, input_pos: %s", x_sliced, ip_sliced) logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])da else: # input_pos: [B, S] @@ -740,7 +745,7 @@ def encode_tokens(self, string, bos=True, device="cpu"): tokens = self.tokenizer.encode(string) if bos: tokens = [self.tokenizer.bos_id()] + tokens - logging.debug(f"Size after encode_tokens: {len(tokens)}") + logger.debug("Size after encode_tokens: %d", len(tokens)) return torch.tensor(tokens, dtype=torch.int, device=device) def _callback(self, x, *, buffer, done_generating): @@ -798,7 +803,7 @@ def _gen_model_input( tokens, dtype=torch.int, device=self.builder_args.device ) - logging.debug(encoded) + logger.debug(encoded) return encoded, None # Llama 3.2 11B @@ -913,7 +918,7 @@ def _gen_model_input( value=0, ) - logging.debug(encoded) + logger.debug(encoded) return encoded, batch def chat( @@ -1244,6 +1249,7 @@ def main(args): speculative_builder_args = BuilderArgs.from_speculative_args(args) tokenizer_args = TokenizerArgs.from_args(args) generator_args = GeneratorArgs.from_args(args) + logger.debug("GeneratorArgs: %s", generator_args) if not builder_args.distributed: gen = Generator( builder_args, diff --git a/torchchat/model.py b/torchchat/model.py index 2a3b9f12f..479128c61 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -4,6 +4,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import json +import logging import os import warnings from abc import ABC, abstractmethod @@ -48,6 +49,8 @@ config_path = Path(f"{str(Path(__file__).parent)}/model_params") +logger = logging.getLogger(__name__) + class QuickGELUActivation(nn.Module): """ @@ -477,7 +480,9 @@ def build_model(self) -> nn.Module: for name, module_class in recipe.modules.items(): config_args = self.config.transformer_args[name] if module_class == Transformer: - modules[name] = module_class(TransformerArgs.from_params(config_args)) + transformer_args = TransformerArgs.from_params(config_args) + logger.debug("Transformer Args: %s", transformer_args) + modules[name] = module_class(transformer_args) else: modules[name] = module_class(**config_args) From c9f8a7143a6353186296110372b9a36ed54ac629 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 21 Nov 2024 10:49:20 -0700 Subject: [PATCH 11/16] feat(generate): Make prepending BOS model-conigurable And disable it for Granite Code models Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 5 +++-- torchchat/model.py | 13 ++++++++++++- torchchat/model_params/Granite-3B-Code.json | 1 + torchchat/model_params/Granite-8B-Code.json | 1 + 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 6f2e7b062..a00391ec0 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -746,6 +746,7 @@ def encode_tokens(self, string, bos=True, device="cpu"): if bos: tokens = [self.tokenizer.bos_id()] + tokens logger.debug("Size after encode_tokens: %d", len(tokens)) + logger.debug("Token IDs: %s", tokens) return torch.tensor(tokens, dtype=torch.int, device=device) def _callback(self, x, *, buffer, done_generating): @@ -794,7 +795,7 @@ def _gen_model_input( # Single String prompt if isinstance(prompt, str): encoded = self.encode_tokens( - prompt, bos=True, device=self.builder_args.device + prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device ) # List of dialog else: @@ -1048,7 +1049,7 @@ def chat( else: prompt = f"{B_INST} {prompt.strip()} {E_INST}" encoded = self.encode_tokens( - prompt, bos=True, device=self.builder_args.device + prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device ) else: if self.system_prompt: diff --git a/torchchat/model.py b/torchchat/model.py index 479128c61..1c78d4c63 100644 --- a/torchchat/model.py +++ b/torchchat/model.py @@ -276,6 +276,7 @@ class TransformerArgs: # Select the desired tokenizer. Defaults to sentencepiece use_tiktoken: bool = False use_hf_tokenizer: bool = False + tokenizer_prepend_bos: bool = True max_seq_length: int = 8192 rope_scaling: Optional[Dict[str, Any]] = None # For pipeline parallel @@ -333,6 +334,7 @@ class ModelArgs: transformer_args: Dict[str, Dict[str, Any]] use_tiktoken: bool use_hf_tokenizer: bool + tokenizer_prepend_bos: bool def __init__( self, @@ -340,6 +342,7 @@ def __init__( model_type: ModelType = ModelType.TextOnly, use_tiktoken: bool = False, use_hf_tokenizer: bool = False, + tokenizer_prepend_bos: bool = True, ) -> None: self._sanity_check(transformer_args, model_type) @@ -349,6 +352,7 @@ def __init__( # Model-level attributes self.use_tiktoken = use_tiktoken self.use_hf_tokenizer = use_hf_tokenizer + self.tokenizer_prepend_bos = tokenizer_prepend_bos def _sanity_check( self, @@ -376,7 +380,14 @@ def from_params(cls, params_path): use_tiktoken = loaded_params.get("use_tiktoken", False) use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False) - return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer) + tokenizer_prepend_bos = loaded_params.get("tokenizer_prepend_bos", True) + return cls( + transformer_args=transformer_args, + model_type=model_type, + use_tiktoken=use_tiktoken, + use_hf_tokenizer=use_hf_tokenizer, + tokenizer_prepend_bos=tokenizer_prepend_bos, + ) @classmethod def from_table(cls, name: str): diff --git a/torchchat/model_params/Granite-3B-Code.json b/torchchat/model_params/Granite-3B-Code.json index 3ec1c615a..0654a8f2c 100644 --- a/torchchat/model_params/Granite-3B-Code.json +++ b/torchchat/model_params/Granite-3B-Code.json @@ -8,6 +8,7 @@ "rope_base": 10000000, "vocab_size": 49152, "use_hf_tokenizer": true, + "tokenizer_prepend_bos": false, "norm_eps": 0.00001, "rope_scaling": null, "attention_bias": true, diff --git a/torchchat/model_params/Granite-8B-Code.json b/torchchat/model_params/Granite-8B-Code.json index 71a7d9201..079a32070 100644 --- a/torchchat/model_params/Granite-8B-Code.json +++ b/torchchat/model_params/Granite-8B-Code.json @@ -8,6 +8,7 @@ "rope_base": 10000000, "vocab_size": 49152, "use_hf_tokenizer": true, + "tokenizer_prepend_bos": false, "norm_eps": 0.00001, "rope_scaling": null, "attention_bias": true, From ef132cb55edd5e560542852954bf9364013fab20 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 21 Nov 2024 13:10:34 -0700 Subject: [PATCH 12/16] fix(chat): Refactor chat template logic to encapsulate all formatting in classes The formatted strings may not be perfectly 1:1 with the previous impl, but they should be in line with the official model guidelines: * https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3 * https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-2 Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 161 +++++++++++++++++++++++++----------------- 1 file changed, 97 insertions(+), 64 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index a00391ec0..8a7e3725c 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -47,14 +47,39 @@ logger = logging.getLogger(__name__) +## Chat Formatters ############################################################# class _ChatFormatter(ABC): + + # Messages can arrive as a standard dict with "role" and "content" as + # strings, or where "content" is a list of objects with "text" fields. + MESSAGE_TYPE = Dict[str, Union[str, List[Dict[str, str]]]] + + # A dialog is a sequence of messages + DIALOG_TYPE = List[MESSAGE_TYPE] + def __init__(self, tokenizer): self.tokenizer = tokenizer @abstractmethod - def encode_dialog_prompt(self, dialog) -> List[int]: - raise NotImplementedError() + def encode_dialog_prompt( + self, + dialog: DIALOG_TYPE, + add_generation_prompt: bool, + ) -> List[int]: + """Encode a sequence of messages into a sequence of token IDs, including + the chat template + + Args: + dialog (DIALOG_TYPE): The sequence of dialog messages to encode. + This will be the additional messages on top of those that have + already been processed. + add_generation_prompt (bool): Whether to include a generation prompt + at the end of the encoded sequence. + + Returns: + List[int]: A list of token IDs representing the encoded prompt. + """ class Llama3ChatFormatter(_ChatFormatter): @@ -64,7 +89,7 @@ class Llama3ChatFormatter(_ChatFormatter): """ - def encode_header(self, role) -> List[int]: + def _encode_header(self, role) -> List[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) tokens.extend(self.tokenizer.encode(role, bos=False, eos=False)) @@ -72,8 +97,8 @@ def encode_header(self, role) -> List[int]: tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) return tokens - def encode_message(self, message) -> List[int]: - tokens = self.encode_header(message["role"]) + def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]: + tokens = self._encode_header(message["role"]) if isinstance(message["content"], str): tokens.extend( self.tokenizer.encode(message["content"], bos=False, eos=False) @@ -88,54 +113,79 @@ def encode_message(self, message) -> List[int]: tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) return tokens - def encode_dialog_prompt(self, dialog) -> List[int]: + def encode_dialog_prompt( + self, + dialog: _ChatFormatter.DIALOG_TYPE, + add_generation_prompt: bool, + ) -> List[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) for message in dialog: - tokens.extend(self.encode_message(message)) + tokens.extend(self._encode_message(message)) # Add the start of an assistant message for the model to complete. - tokens.extend(self.encode_header("assistant")) # Pass role directly as a string + if add_generation_prompt: + tokens.extend(self._encode_header("assistant")) # Pass role directly as a string return tokens -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>", "<>" +class Llama2ChatFormatter(_ChatFormatter): + """ + Chat formatting for Llama2 + CITE: https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/ + """ + + B_INST, E_INST = "[INST] ", " [/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + @staticmethod + def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str: + if isinstance(message["content"], list): + return message["content"][0]["text"] + return message["content"] -class Llama2ChatFormatter(_ChatFormatter): - def encode_dialog_prompt(self, dialog) -> List[int]: - tokens = self.tokenizer.encode(f"{B_INST} ") - first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it. + def encode_dialog_prompt( + self, + dialog: _ChatFormatter.DIALOG_TYPE, + add_generation_prompt: bool, # UNUSED + ) -> List[int]: + new_turn = True + tokens = [] for message in dialog: - if isinstance(message["content"], list): - content = message["content"][0]["text"] + if new_turn: + tokens += self.tokenizer.encode(f"{self.tokenizer.bos}{self.B_INST}") + content = self._get_content_str(message).strip() + role = message["role"] + if role == "system": + tokens += self.tokenizer.encode(f"{self.B_SYS}{content}{self.E_SYS}") + new_turn = False + elif role == "user": + tokens += self.tokenizer.encode(f"{content}{self.E_INST}") + new_turn = False + elif role == "assistant": + tokens += self.tokenizer.encode(f" {content} {self.tokenizer.eos}\n") + new_turn = True else: - content = message["content"] - content = content.strip() - if message["role"] == "system": - encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}") - first_message = False - elif message["role"] == "user": - encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode( - f"{B_INST if first_message else ''} {content} {E_INST} " - ) - first_message = True - elif message["role"] == "assistant": - encoded = self.tokenizer.encode(f"{content}\n\n") + [ - self.tokenizer.eos_id() - ] - tokens += encoded + raise ValueError("Invalid role in dialog.") return tokens + class HFTokenizerChatFormatter(_ChatFormatter): """Chat formatter that uses the built-in formatting capabilities of an HF tokenizer instance """ - def encode_dialog_prompt(self, dialog) -> List[int]: - rendered = self.tokenizer.apply_chat_template(dialog, add_generation_prompt=True) + def encode_dialog_prompt( + self, + dialog: _ChatFormatter.DIALOG_TYPE, + add_generation_prompt: bool, + ) -> List[int]: + rendered = self.tokenizer.apply_chat_template( + dialog, add_generation_prompt=add_generation_prompt + ) + logger.debug("Formatted chat prompt:\n%s", rendered) return self.tokenizer.encode(rendered) +## Generation ################################################################## @dataclass class GeneratorArgs: @@ -1040,38 +1090,21 @@ def chat( if prompt == "/bye": print("Exiting Chat.\n") break - if not self.is_llama3_model: - if self.system_prompt: - prompt = f"{B_INST} {B_SYS}\n{self.system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}" - self.system_prompt = ( - None # can only provide system prompt on first interaction - ) - else: - prompt = f"{B_INST} {prompt.strip()} {E_INST}" - encoded = self.encode_tokens( - prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device - ) - else: - if self.system_prompt: - encoded = self.chat_formatter.encode_dialog_prompt( - [ - {"role": "system", "content": self.system_prompt}, - {"role": "user", "content": prompt}, - ] - ) - self.system_prompt = None - elif is_first_sample: - encoded = self.chat_formatter.encode_dialog_prompt( - [{"role": "user", "content": prompt}] - ) - else: - encoded = self.chat_formatter.encode_message( - {"role": "user", "content": prompt} - ) - encoded.extend(self.chat_formatter.encode_header("assistant")) - encoded = torch.tensor( - encoded, dtype=torch.int, device=self.builder_args.device + + # Encode the additional messages added in this dialog turn. If + # this is the first turn, that includes any system prompt. + messages_to_encode = [] + if is_first_sample and self.system_prompt: + messages_to_encode.append( + {"role": "system", "content": self.system_prompt} ) + messages_to_encode.append({"role": "system", "content": prompt}) + encoded = self.chat_formatter.encode_dialog_prompt( + messages_to_encode, add_generation_prompt=True, + ) + encoded = torch.tensor( + encoded, dtype=torch.int, device=self.builder_args.device + ) if encoded.size(0) + start_pos > max_seq_length: print( "This prompt would take us past the max_seq_length. Ending Conversation." From 8d26923665afc09a3d37a2cad81f6ced3a1d8d3c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 21 Nov 2024 15:53:57 -0700 Subject: [PATCH 13/16] fix(chat): Fix small formatting bugs in llama3 chat formatter Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 8a7e3725c..57938e4fd 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -110,7 +110,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]: self.tokenizer.encode(content["text"], bos=False, eos=False) ) - tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) + tokens.append(self.tokenizer.special_tokens["<|eot_id|>\n"]) return tokens def encode_dialog_prompt( @@ -123,8 +123,8 @@ def encode_dialog_prompt( for message in dialog: tokens.extend(self._encode_message(message)) # Add the start of an assistant message for the model to complete. - if add_generation_prompt: - tokens.extend(self._encode_header("assistant")) # Pass role directly as a string + if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant": + tokens.extend(self._encode_header("assistant")) # Pass role directly as a string return tokens From 0390655f46bb6f6d11d1dff20811e5c06568e646 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 21 Nov 2024 15:54:56 -0700 Subject: [PATCH 14/16] test: Add initial unit tests for chat formatters There's no formal execution framework for pytest yet, but these were helpful in ensuring that the formatting was working correctly! To run them, install pytest and run `pytest tests/` Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- tests/conftest.py | 12 ++ tests/test_chat_formatters.py | 227 ++++++++++++++++++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 tests/conftest.py create mode 100644 tests/test_chat_formatters.py diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..c1580e27b --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +""" +Global pytest config, fixtures, and helpers go here! +""" + +# Standard +import os +import sys + +# Make sure tests can import torchchat +sys.path.append( + os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) +) diff --git a/tests/test_chat_formatters.py b/tests/test_chat_formatters.py new file mode 100644 index 000000000..feae1c138 --- /dev/null +++ b/tests/test_chat_formatters.py @@ -0,0 +1,227 @@ +""" +Unit tests for chat formatters +""" + +# Third Party +import pytest + +# Local +from torchchat.generate import ( + HFTokenizerChatFormatter, + Llama2ChatFormatter, + Llama3ChatFormatter, +) + +## Helpers ##################################################################### + +class DummyTokenizer: + """Dummy tokenizer that encodes as strings so it's easy to check formatting""" + def encode(self, text, *_, **__): + return text + + +class DummySPTokenizer(DummyTokenizer): + """Emulated Sentencepiece tokenizer with bos/eos""" + bos = "" + eos = "" + + +class DummyLlama3Tokenizer(DummyTokenizer): + class _IdentityDict: + def __getitem__(self, key): + return key + special_tokens = _IdentityDict() + + +class DummyHFTokenizer(DummyTokenizer): + """Dummy made up chat template scheme""" + # Sequence + bos = "" + # Turn + bot = "" + eot = "" + # Role + bor = "" + eor = "" + def apply_chat_template(self, messages, add_generation_prompt): + out = [self.bos] + role = None + for msg in messages: + role = msg["role"] + content = msg["content"] + out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}") + if add_generation_prompt and role != "assistant": + out.append(f"{self.bot}{self.bor}assistant{self.eor}") + return "\n".join(out) + + +def check_rendering(fmt, messages, expected, add_generation_prompt): + """Render messages and compare to expected output""" + assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected + + +def make_message(role, text): + return {"role": role, "content": text} + + +SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything." +USER1 = "Hello world!" +ASSISTANT1 = "Greetings! How can I help you?" +USER2 = "Why is the sky blue?" +ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering." + + +# Stock sets of messages to test +MSGS_NO_SYS= [ + make_message("user", USER1), +] +MSGS_SYS_USR = [ + make_message("system", SYSTEM_PROMPT), + make_message("user", USER1), +] +MSGS_SYS_USR_ASST = [ + make_message("system", SYSTEM_PROMPT), + make_message("user", USER1), + make_message("assistant", ASSISTANT1), +] +MSGS_MULTI_TURN = [ + make_message("system", SYSTEM_PROMPT), + make_message("user", USER1), + make_message("assistant", ASSISTANT1), + make_message("user", USER2), + make_message("assistant", ASSISTANT2), +] + +## Llama2ChatFormatter ######################################################### + +@pytest.mark.parametrize( + ["messages", "expected"], + [ + # single user message (no system prompt) + (MSGS_NO_SYS, f"[INST] {USER1} [/INST]"), + # sys, usr + (MSGS_SYS_USR, f"""[INST] <> +{SYSTEM_PROMPT} +<> + +{USER1} [/INST]"""), + # sys, usr, asst + (MSGS_SYS_USR_ASST, f"""[INST] <> +{SYSTEM_PROMPT} +<> + +{USER1} [/INST] {ASSISTANT1} +"""), + # sys, usr, asst, usr, asst + (MSGS_MULTI_TURN, f"""[INST] <> +{SYSTEM_PROMPT} +<> + +{USER1} [/INST] {ASSISTANT1} +[INST] {USER2} [/INST] {ASSISTANT2} +"""), + ] +) +def test_llama2_chat_formatter(messages, expected): + """Tests for Llama2 following the official guide + https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/ + """ + tok = DummySPTokenizer() + fmt = Llama2ChatFormatter(tok) + # NOTE: add_generation_prompt not used by Llama2 + check_rendering(fmt, messages, expected, True) + +## Llama3ChatFormatter ######################################################### + +@pytest.mark.parametrize( + ["messages", "expected"], + [ + # single user message (no system prompt) + (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|> + +{USER1}<|eot_id|> +"""), + # sys, usr + (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +{SYSTEM_PROMPT}<|eot_id|> +<|start_header_id|>user<|end_header_id|> + +{USER1}<|eot_id|> +"""), + # sys, usr, asst + (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +{SYSTEM_PROMPT}<|eot_id|> +<|start_header_id|>user<|end_header_id|> + +{USER1}<|eot_id|> +<|start_header_id|>assistant<|end_header_id|> + +{ASSISTANT1}<|eot_id|> +"""), + # sys, usr, asst, usr, asst + (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> + +{SYSTEM_PROMPT}<|eot_id|> +<|start_header_id|>user<|end_header_id|> + +{USER1}<|eot_id|> +<|start_header_id|>assistant<|end_header_id|> + +{ASSISTANT1}<|eot_id|> +<|start_header_id|>user<|end_header_id|> + +{USER2}<|eot_id|> +<|start_header_id|>assistant<|end_header_id|> + +{ASSISTANT2}<|eot_id|> +"""), + ] +) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_llama3_chat_formatter(messages, expected, add_generation_prompt): + """Tests for Llama3 following the official guide + https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/ + """ + tok = DummyLlama3Tokenizer() + fmt = Llama3ChatFormatter(tok) + # No assistant prompt added if the last message is from the assistant + if add_generation_prompt and messages[-1]["role"] != "assistant": + expected += "<|start_header_id|>assistant<|end_header_id|>\n\n" + check_rendering(fmt, messages, expected, add_generation_prompt) + +## HFTokenizerChatFormatter #################################################### + +@pytest.mark.parametrize( + ["messages", "expected"], + [ + # single user message (no system prompt) + (MSGS_NO_SYS, f""" +user{USER1}"""), + # sys, usr + (MSGS_SYS_USR, f""" +system{SYSTEM_PROMPT} +user{USER1}"""), + # sys, usr, asst + (MSGS_SYS_USR_ASST, f""" +system{SYSTEM_PROMPT} +user{USER1} +assistant{ASSISTANT1}"""), + # sys, usr, asst, usr, asst + (MSGS_MULTI_TURN, f""" +system{SYSTEM_PROMPT} +user{USER1} +assistant{ASSISTANT1} +user{USER2} +assistant{ASSISTANT2}"""), + ] +) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +def test_hf_chat_formatter(messages, expected, add_generation_prompt): + tok = DummyHFTokenizer() + fmt = HFTokenizerChatFormatter(tok) + # No assistant prompt added if the last message is from the assistant + if add_generation_prompt and messages[-1]["role"] != "assistant": + expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}" + check_rendering(fmt, messages, expected, add_generation_prompt) From 2d7e546d0c8b2c2c78abc7ddfa6fa62f45a683e3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 25 Nov 2024 11:41:22 -0700 Subject: [PATCH 15/16] fix(logging): Disable logging in generate unless set in the env There is an incompatibility with logging and torch._dynamo, so this disables it unless the developer asks for it explicitly. NOTE: The TC team has stated that they have holistic logging on the roadmap so this is a short-term solution pending a more robust approach. REF: https://github.com/pytorch/torchchat/actions/runs/11963066986/job/33493237302#step:14:3599 Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- torchchat/generate.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/torchchat/generate.py b/torchchat/generate.py index 57938e4fd..274a0cec8 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -45,7 +45,18 @@ from torchchat.utils.device_info import get_device_info -logger = logging.getLogger(__name__) +# NOTE: Logging disabled by default here due to conflicts with torch._dynamo +class NoOpLogger: + def __no_op(self, *_, **__): + pass + def __getattr__(self, name): + return self.__no_op + + +logger = ( + NoOpLogger() if os.getenv("LOG_LEVEL") is None + else logging.getLogger(__name__) +) ## Chat Formatters ############################################################# From 78a3637ec857ff2369aa1d020f089c1e6979a98d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Dec 2024 09:05:30 -0700 Subject: [PATCH 16/16] fix: Remove trailing n from llama3 <|eot_id|> There's inconsistency in the documentation on whether or not there should be a n after <|eot_id|>, but this maintains consistency with previous formatting Branch: GraniteCodeSupport Signed-off-by: Gabe Goodhart --- tests/test_chat_formatters.py | 33 +++++++++++---------------------- torchchat/generate.py | 2 +- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/tests/test_chat_formatters.py b/tests/test_chat_formatters.py index feae1c138..2f7f7a955 100644 --- a/tests/test_chat_formatters.py +++ b/tests/test_chat_formatters.py @@ -139,44 +139,33 @@ def test_llama2_chat_formatter(messages, expected): # single user message (no system prompt) (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|> -{USER1}<|eot_id|> -"""), +{USER1}<|eot_id|>"""), # sys, usr (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> -{SYSTEM_PROMPT}<|eot_id|> -<|start_header_id|>user<|end_header_id|> +{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> -{USER1}<|eot_id|> -"""), +{USER1}<|eot_id|>"""), # sys, usr, asst (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> -{SYSTEM_PROMPT}<|eot_id|> -<|start_header_id|>user<|end_header_id|> +{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> -{USER1}<|eot_id|> -<|start_header_id|>assistant<|end_header_id|> +{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|> -{ASSISTANT1}<|eot_id|> -"""), +{ASSISTANT1}<|eot_id|>"""), # sys, usr, asst, usr, asst (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> -{SYSTEM_PROMPT}<|eot_id|> -<|start_header_id|>user<|end_header_id|> +{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> -{USER1}<|eot_id|> -<|start_header_id|>assistant<|end_header_id|> +{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|> -{ASSISTANT1}<|eot_id|> -<|start_header_id|>user<|end_header_id|> +{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|> -{USER2}<|eot_id|> -<|start_header_id|>assistant<|end_header_id|> +{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|> -{ASSISTANT2}<|eot_id|> -"""), +{ASSISTANT2}<|eot_id|>"""), ] ) @pytest.mark.parametrize("add_generation_prompt", [True, False]) diff --git a/torchchat/generate.py b/torchchat/generate.py index 274a0cec8..4d2439d2f 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -121,7 +121,7 @@ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]: self.tokenizer.encode(content["text"], bos=False, eos=False) ) - tokens.append(self.tokenizer.special_tokens["<|eot_id|>\n"]) + tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) return tokens def encode_dialog_prompt(