diff --git a/install/requirements.txt b/install/requirements.txt
index 8fb1832ba..457131275 100644
--- a/install/requirements.txt
+++ b/install/requirements.txt
@@ -9,6 +9,10 @@ gguf
 # Tiktoken tokenizer for Llama 3 and other advanced models
 tiktoken
 
+# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
+tokenizers
+jinja2
+
 # Miscellaneous
 snakeviz
 sentencepiece
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..c1580e27b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,12 @@
+"""
+Global pytest config, fixtures, and helpers go here!
+"""
+
+# Standard
+import os
+import sys
+
+# Make sure tests can import torchchat
+sys.path.append(
+    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+)
diff --git a/tests/test_chat_formatters.py b/tests/test_chat_formatters.py
new file mode 100644
index 000000000..2f7f7a955
--- /dev/null
+++ b/tests/test_chat_formatters.py
@@ -0,0 +1,216 @@
+"""
+Unit tests for chat formatters
+"""
+
+# Third Party
+import pytest
+
+# Local
+from torchchat.generate import (
+    HFTokenizerChatFormatter,
+    Llama2ChatFormatter,
+    Llama3ChatFormatter,
+)
+
+## Helpers #####################################################################
+
+class DummyTokenizer:
+    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
+    def encode(self, text, *_, **__):
+        return text
+
+
+class DummySPTokenizer(DummyTokenizer):
+    """Emulated Sentencepiece tokenizer with bos/eos"""
+    bos = "<s>"
+    eos = "</s>"
+
+
+class DummyLlama3Tokenizer(DummyTokenizer):
+    class _IdentityDict:
+        def __getitem__(self, key):
+            return key
+    special_tokens = _IdentityDict()
+
+
+class DummyHFTokenizer(DummyTokenizer):
+    """Dummy made up chat template scheme"""
+    # Sequence
+    bos = "<bos>"
+    # Turn
+    bot = "<bot>"
+    eot = "<eot>"
+    # Role
+    bor = "<bor>"
+    eor = "<eor>"
+    def apply_chat_template(self, messages, add_generation_prompt):
+        out = [self.bos]
+        role = None
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
+        if add_generation_prompt and role != "assistant":
+            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
+        return "\n".join(out)
+
+
+def check_rendering(fmt, messages, expected, add_generation_prompt):
+    """Render messages and compare to expected output"""
+    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected
+
+
+def make_message(role, text):
+    return {"role": role, "content": text}
+
+
+SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
+USER1 = "Hello world!"
+ASSISTANT1 = "Greetings! How can I help you?"
+USER2 = "Why is the sky blue?"
+ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."
+
+
+# Stock sets of messages to test
+MSGS_NO_SYS= [
+    make_message("user", USER1),
+]
+MSGS_SYS_USR = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+]
+MSGS_SYS_USR_ASST = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+]
+MSGS_MULTI_TURN = [
+    make_message("system", SYSTEM_PROMPT),
+    make_message("user", USER1),
+    make_message("assistant", ASSISTANT1),
+    make_message("user", USER2),
+    make_message("assistant", ASSISTANT2),
+]
+
+## Llama2ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST]"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
+"""),
+    ]
+)
+def test_llama2_chat_formatter(messages, expected):
+    """Tests for Llama2 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+    """
+    tok = DummySPTokenizer()
+    fmt = Llama2ChatFormatter(tok)
+    # NOTE: add_generation_prompt not used by Llama2
+    check_rendering(fmt, messages, expected, True)
+
+## Llama3ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT2}<|eot_id|>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
+    """Tests for Llama3 following the official guide
+    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+    """
+    tok = DummyLlama3Tokenizer()
+    fmt = Llama3ChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
+
+## HFTokenizerChatFormatter ####################################################
+
+@pytest.mark.parametrize(
+    ["messages", "expected"],
+    [
+        # single user message (no system prompt)
+        (MSGS_NO_SYS, f"""<bos>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr
+        (MSGS_SYS_USR, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>"""),
+        # sys, usr, asst
+        (MSGS_SYS_USR_ASST, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
+        # sys, usr, asst, usr, asst
+        (MSGS_MULTI_TURN, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>
+<bot><bor>user<eor>{USER2}<eot>
+<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
+    ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_chat_formatter(messages, expected, add_generation_prompt):
+    tok = DummyHFTokenizer()
+    fmt = HFTokenizerChatFormatter(tok)
+    # No assistant prompt added if the last message is from the assistant
+    if add_generation_prompt and messages[-1]["role"] != "assistant":
+        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
+    check_rendering(fmt, messages, expected, add_generation_prompt)
diff --git a/tokenizer/hf_tokenizer.py b/tokenizer/hf_tokenizer.py
index 7ad5807d1..d10ecb076 100644
--- a/tokenizer/hf_tokenizer.py
+++ b/tokenizer/hf_tokenizer.py
@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 # Standard
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 import os
 
 # Third Party
+import jinja2
 from tokenizers import Tokenizer
 
 # Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
         # Load the tokenizer itself
         self._tokenizer = Tokenizer.from_file(tokenizer_path)
 
+        # Load the chat template if we have a config path
+        self._chat_template: Optional[jinja2.Template] = None
+
         # If available, parse bos/eos tokens from the tokenizer config
         self._bos_id, self._eos_id = None, None
         if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
                 self._bos_id = self._tokenizer.token_to_id(bos_token)
             if eos_token is not None:
                 self._eos_id = self._tokenizer.token_to_id(eos_token)
+            if chat_template_str := tok_config.get("chat_template"):
+                self._chat_template = jinja2.Template(chat_template_str)
 
         # If no eos/bos tokens found, go looking for them!
         if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
             if len(candidate_toks) == 1:
                 return candidate_toks[0]["id"]
 
+    ## Interface ##
+
     def encode(
         self,
         s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
 
     def eos_id(self) -> int:
         return self._eos_id
+
+    ## Additional Public Methods ##
+
+    def has_chat_template(self) -> bool:
+        return bool(self._chat_template)
+
+    def apply_chat_template(
+        self,
+        dialog: List[Dict[str, str]],
+        add_generation_prompt: bool = False,
+    ) -> str:
+        """If configured with a chat template, apply it to the list of messages
+        """
+        if not self._chat_template:
+            raise ValueError("No chat template configured!")
+        return self._chat_template.render(
+            messages=dialog, add_generation_prompt=add_generation_prompt
+        )
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index a7f7bbba2..91bdcaf26 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -17,7 +17,15 @@
     allowable_params_table,
 )
 
-logging.basicConfig(level=logging.INFO, format="%(message)s")
+_log_level_env = os.getenv("LOG_LEVEL", "INFO")
+try:
+    _log_level = getattr(logging, _log_level_env.upper())
+except AttributeError:
+    print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
+    _log_level = logging.INFO
+
+
+logging.basicConfig(level=_log_level, format="%(message)s")
 logger = logging.getLogger(__name__)
 
 default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 9b4c6430a..4d2439d2f 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -45,13 +45,52 @@
 from torchchat.utils.device_info import get_device_info
 
 
+# NOTE: Logging disabled by default here due to conflicts with torch._dynamo
+class NoOpLogger:
+    def __no_op(self, *_, **__):
+        pass
+    def __getattr__(self, name):
+        return self.__no_op
+
+
+logger = (
+    NoOpLogger() if os.getenv("LOG_LEVEL") is None
+    else logging.getLogger(__name__)
+)
+
+## Chat Formatters #############################################################
+
 class _ChatFormatter(ABC):
+
+    # Messages can arrive as a standard dict with "role" and "content" as
+    # strings, or where "content" is a list of objects with "text" fields.
+    MESSAGE_TYPE = Dict[str, Union[str, List[Dict[str, str]]]]
+
+    # A dialog is a sequence of messages
+    DIALOG_TYPE = List[MESSAGE_TYPE]
+
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
 
     @abstractmethod
-    def encode_dialog_prompt(self, dialog) -> List[int]:
-        raise NotImplementedError()
+    def encode_dialog_prompt(
+        self,
+        dialog: DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        """Encode a sequence of messages into a sequence of token IDs, including
+        the chat template
+
+        Args:
+            dialog (DIALOG_TYPE): The sequence of dialog messages to encode.
+                This will be the additional messages on top of those that have
+                already been processed.
+            add_generation_prompt (bool): Whether to include a generation prompt
+                at the end of the encoded sequence.
+
+        Returns:
+            List[int]: A list of token IDs representing the encoded prompt.
+ """ class Llama3ChatFormatter(_ChatFormatter): @@ -61,7 +100,7 @@ class Llama3ChatFormatter(_ChatFormatter): """ - def encode_header(self, role) -> List[int]: + def _encode_header(self, role) -> List[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) tokens.extend(self.tokenizer.encode(role, bos=False, eos=False)) @@ -69,8 +108,8 @@ def encode_header(self, role) -> List[int]: tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) return tokens - def encode_message(self, message) -> List[int]: - tokens = self.encode_header(message["role"]) + def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]: + tokens = self._encode_header(message["role"]) if isinstance(message["content"], str): tokens.extend( self.tokenizer.encode(message["content"], bos=False, eos=False) @@ -85,46 +124,80 @@ def encode_message(self, message) -> List[int]: tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) return tokens - def encode_dialog_prompt(self, dialog) -> List[int]: + def encode_dialog_prompt( + self, + dialog: _ChatFormatter.DIALOG_TYPE, + add_generation_prompt: bool, + ) -> List[int]: tokens = [] tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) for message in dialog: - tokens.extend(self.encode_message(message)) + tokens.extend(self._encode_message(message)) # Add the start of an assistant message for the model to complete. - tokens.extend(self.encode_header("assistant")) # Pass role directly as a string + if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant": + tokens.extend(self._encode_header("assistant")) # Pass role directly as a string return tokens -B_INST, E_INST = "[INST]", "[/INST]" -B_SYS, E_SYS = "<>", "<>" +class Llama2ChatFormatter(_ChatFormatter): + """ + Chat formatting for Llama2 + CITE: https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/ + """ + + B_INST, E_INST = "[INST] ", " [/INST]" + B_SYS, E_SYS = "<>\n", "\n<>\n\n" + @staticmethod + def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str: + if isinstance(message["content"], list): + return message["content"][0]["text"] + return message["content"] -class Llama2ChatFormatter(_ChatFormatter): - def encode_dialog_prompt(self, dialog) -> List[int]: - tokens = self.tokenizer.encode(f"{B_INST} ") - first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it. 
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,  # UNUSED
+    ) -> List[int]:
+        new_turn = True
+        tokens = []
         for message in dialog:
-            if isinstance(message["content"], list):
-                content = message["content"][0]["text"]
+            if new_turn:
+                tokens += self.tokenizer.encode(f"{self.tokenizer.bos}{self.B_INST}")
+            content = self._get_content_str(message).strip()
+            role = message["role"]
+            if role == "system":
+                tokens += self.tokenizer.encode(f"{self.B_SYS}{content}{self.E_SYS}")
+                new_turn = False
+            elif role == "user":
+                tokens += self.tokenizer.encode(f"{content}{self.E_INST}")
+                new_turn = False
+            elif role == "assistant":
+                tokens += self.tokenizer.encode(f" {content} {self.tokenizer.eos}\n")
+                new_turn = True
             else:
-                content = message["content"]
-            content = content.strip()
-            if message["role"] == "system":
-                encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
-                first_message = False
-            elif message["role"] == "user":
-                encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
-                    f"{B_INST if first_message else ''} {content} {E_INST} "
-                )
-                first_message = True
-            elif message["role"] == "assistant":
-                encoded = self.tokenizer.encode(f"{content}\n\n") + [
-                    self.tokenizer.eos_id()
-                ]
-            tokens += encoded
+                raise ValueError("Invalid role in dialog.")
         return tokens
 
+
+class HFTokenizerChatFormatter(_ChatFormatter):
+    """Chat formatter that uses the built-in formatting capabilities of an HF
+    tokenizer instance
+    """
+    def encode_dialog_prompt(
+        self,
+        dialog: _ChatFormatter.DIALOG_TYPE,
+        add_generation_prompt: bool,
+    ) -> List[int]:
+        rendered = self.tokenizer.apply_chat_template(
+            dialog, add_generation_prompt=add_generation_prompt
+        )
+        logger.debug("Formatted chat prompt:\n%s", rendered)
+        return self.tokenizer.encode(rendered)
+
+## Generation ##################################################################
+
 @dataclass
 class GeneratorArgs:
     prompt: Optional[str] = (
@@ -283,9 +356,13 @@ def __init__(
         if self.is_llama3_model:
             self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
             if generator_args.chat_mode:
-                logging.debug(
+                logger.debug(
                     "Llama3 model detected in chat mode. Using updated sentence schemas"
                 )
+        elif self.tokenizer_args.is_hf_tokenizer:
+            if not self.tokenizer.has_chat_template():
+                raise ValueError("Tokenizer must have a chat template")
+            self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
         else:
             self.chat_formatter = Llama2ChatFormatter(self.tokenizer)
 
@@ -341,10 +418,12 @@ def sample(
         temperature: float = 0,
         top_k: Optional[int] = None,
     ):
+        logits = logits[0, -1]
+        logger.debug("Logits: %s", logits)
         if temperature == 0 and not need_probs:
-            _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
+            _, idx_next = torch.topk(logits, k=1, dim=-1)
             return (idx_next, None)
-        probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+        probs = self.logits_to_probs(logits, temperature, top_k)
         idx_next = self.multinomial_sample_one_no_sync(probs)
         return idx_next, probs
 
@@ -358,7 +437,7 @@ def prefill(
         sequential_prefill=True,
         **sampling_kwargs,
     ) -> torch.Tensor:
-        # logging.debug(f"x: {x}, input_pos: {input_pos}")
+        logger.debug("x: %s, input_pos: %s", x, input_pos)
         width = x.size(1)
         assert input_pos.size(0) == width
 
@@ -394,7 +473,7 @@ def prefill(
         elif sequential_prefill:
             for i in range(width):
                 x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
-                # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}")
+                logger.debug(" x: %s, input_pos: %s", x_sliced, ip_sliced)
                 logits = model(x_sliced, ip_sliced)  # (x[:, i], input_pos[i])da
         else:
             # input_pos: [B, S]
@@ -727,7 +806,8 @@ def encode_tokens(self, string, bos=True, device="cpu"):
         tokens = self.tokenizer.encode(string)
         if bos:
             tokens = [self.tokenizer.bos_id()] + tokens
-        logging.debug(f"Size after encode_tokens: {len(tokens)}")
+        logger.debug("Size after encode_tokens: %d", len(tokens))
+        logger.debug("Token IDs: %s", tokens)
         return torch.tensor(tokens, dtype=torch.int, device=device)
 
     def _callback(self, x, *, buffer, done_generating):
@@ -776,7 +856,7 @@ def _gen_model_input(
         # Single String prompt
         if isinstance(prompt, str):
             encoded = self.encode_tokens(
-                prompt, bos=True, device=self.builder_args.device
+                prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device
             )
         # List of dialog
         else:
@@ -785,7 +865,7 @@ def _gen_model_input(
                 tokens, dtype=torch.int, device=self.builder_args.device
             )
 
-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, None
 
     # Llama 3.2 11B
@@ -900,7 +980,7 @@ def _gen_model_input(
                 value=0,
             )
 
-        logging.debug(encoded)
+        logger.debug(encoded)
         return encoded, batch
 
     def chat(
@@ -1021,38 +1101,21 @@
                 if prompt == "/bye":
                     print("Exiting Chat.\n")
                     break
-                if not self.is_llama3_model:
-                    if self.system_prompt:
-                        prompt = f"{B_INST} {B_SYS}\n{self.system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
-                        self.system_prompt = (
-                            None  # can only provide system prompt on first interaction
-                        )
-                    else:
-                        prompt = f"{B_INST} {prompt.strip()} {E_INST}"
-                    encoded = self.encode_tokens(
-                        prompt, bos=True, device=self.builder_args.device
-                    )
-                else:
-                    if self.system_prompt:
-                        encoded = self.chat_formatter.encode_dialog_prompt(
-                            [
-                                {"role": "system", "content": self.system_prompt},
-                                {"role": "user", "content": prompt},
-                            ]
-                        )
-                        self.system_prompt = None
-                    elif is_first_sample:
-                        encoded = self.chat_formatter.encode_dialog_prompt(
-                            [{"role": "user", "content": prompt}]
-                        )
-                    else:
-                        encoded = self.chat_formatter.encode_message(
-                            {"role": "user", "content": prompt}
-                        )
-                        encoded.extend(self.chat_formatter.encode_header("assistant"))
-                    encoded = torch.tensor(
-                        encoded, dtype=torch.int, device=self.builder_args.device
+
+                # Encode the additional messages added in this dialog turn. If
+                # this is the first turn, that includes any system prompt.
+                messages_to_encode = []
+                if is_first_sample and self.system_prompt:
+                    messages_to_encode.append(
+                        {"role": "system", "content": self.system_prompt}
                     )
+                messages_to_encode.append({"role": "user", "content": prompt})
+                encoded = self.chat_formatter.encode_dialog_prompt(
+                    messages_to_encode, add_generation_prompt=True,
+                )
+                encoded = torch.tensor(
+                    encoded, dtype=torch.int, device=self.builder_args.device
+                )
                 if encoded.size(0) + start_pos > max_seq_length:
                     print(
                         "This prompt would take us past the max_seq_length. Ending Conversation."
@@ -1231,6 +1294,7 @@ def main(args):
     speculative_builder_args = BuilderArgs.from_speculative_args(args)
     tokenizer_args = TokenizerArgs.from_args(args)
     generator_args = GeneratorArgs.from_args(args)
+    logger.debug("GeneratorArgs: %s", generator_args)
     if not builder_args.distributed:
         gen = Generator(
             builder_args,
diff --git a/torchchat/model.py b/torchchat/model.py
index 2a3b9f12f..1c78d4c63 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 import json
+import logging
 import os
 import warnings
 from abc import ABC, abstractmethod
@@ -48,6 +49,8 @@
 
 config_path = Path(f"{str(Path(__file__).parent)}/model_params")
 
+logger = logging.getLogger(__name__)
+
 
 class QuickGELUActivation(nn.Module):
     """
@@ -273,6 +276,7 @@ class TransformerArgs:
     # Select the desired tokenizer. Defaults to sentencepiece
     use_tiktoken: bool = False
     use_hf_tokenizer: bool = False
+    tokenizer_prepend_bos: bool = True
     max_seq_length: int = 8192
     rope_scaling: Optional[Dict[str, Any]] = None
     # For pipeline parallel
@@ -330,6 +334,7 @@ class ModelArgs:
     transformer_args: Dict[str, Dict[str, Any]]
     use_tiktoken: bool
     use_hf_tokenizer: bool
+    tokenizer_prepend_bos: bool
 
     def __init__(
         self,
@@ -337,6 +342,7 @@ def __init__(
         model_type: ModelType = ModelType.TextOnly,
         use_tiktoken: bool = False,
         use_hf_tokenizer: bool = False,
+        tokenizer_prepend_bos: bool = True,
     ) -> None:
         self._sanity_check(transformer_args, model_type)
 
@@ -346,6 +352,7 @@ def __init__(
         # Model-level attributes
         self.use_tiktoken = use_tiktoken
         self.use_hf_tokenizer = use_hf_tokenizer
+        self.tokenizer_prepend_bos = tokenizer_prepend_bos
 
     def _sanity_check(
         self,
@@ -373,7 +380,14 @@ def from_params(cls, params_path):
 
         use_tiktoken = loaded_params.get("use_tiktoken", False)
         use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
-        return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
+        tokenizer_prepend_bos = loaded_params.get("tokenizer_prepend_bos", True)
+        return cls(
+            transformer_args=transformer_args,
+            model_type=model_type,
+            use_tiktoken=use_tiktoken,
+            use_hf_tokenizer=use_hf_tokenizer,
+            tokenizer_prepend_bos=tokenizer_prepend_bos,
+        )
 
     @classmethod
     def from_table(cls, name: str):
@@ -477,7 +491,9 @@ def build_model(self) -> nn.Module:
         for name, module_class in recipe.modules.items():
             config_args = self.config.transformer_args[name]
             if module_class == Transformer:
-                modules[name] = module_class(TransformerArgs.from_params(config_args))
+                transformer_args = TransformerArgs.from_params(config_args)
+                logger.debug("Transformer Args: %s", transformer_args)
+                modules[name] = module_class(transformer_args)
             else:
                 modules[name] = module_class(**config_args)
 
diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json
index 2d3dfcbeb..8791601fb 100644
--- a/torchchat/model_config/models.json
+++ b/torchchat/model_config/models.json
@@ -164,5 +164,19 @@
       "https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
     ],
     "checkpoint_file": "stories110M.pt"
+  },
+  "ibm-granite/granite-3b-code-instruct-128k": {
+    "aliases": ["granite-code", "granite-code-3b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-3b-code-instruct-128k",
+    "transformer_params_key": "Granite-3B-Code",
+    "tokenizer_file": "tokenizer.json"
+  },
+  "ibm-granite/granite-8b-code-instruct-128k": {
+    "aliases": ["granite-code-8b"],
+    "distribution_channel": "HuggingFaceSnapshot",
+    "distribution_path": "ibm-granite/granite-8b-code-instruct-128k",
+    "transformer_params_key": "Granite-8B-Code",
+    "tokenizer_file": "tokenizer.json"
   }
 }
diff --git a/torchchat/model_params/Granite-3B-Code.json b/torchchat/model_params/Granite-3B-Code.json
new file mode 100644
index 000000000..0654a8f2c
--- /dev/null
+++ b/torchchat/model_params/Granite-3B-Code.json
@@ -0,0 +1,17 @@
+{
+  "block_size": 128000,
+  "dim": 2560,
+  "hidden_dim": 10240,
+  "n_heads": 32,
+  "n_local_heads": 32,
+  "n_layers": 32,
+  "rope_base": 10000000,
+  "vocab_size": 49152,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "norm_eps": 0.00001,
+  "rope_scaling": null,
+  "attention_bias": true,
+  "feed_forward_bias": true,
+  "tie_word_embeddings": true
+}
\ No newline at end of file
diff --git a/torchchat/model_params/Granite-8B-Code.json b/torchchat/model_params/Granite-8B-Code.json
new file mode 100644
index 000000000..079a32070
--- /dev/null
+++ b/torchchat/model_params/Granite-8B-Code.json
@@ -0,0 +1,17 @@
+{
+  "block_size": 128000,
+  "dim": 4096,
+  "hidden_dim": 14336,
+  "n_heads": 32,
+  "n_local_heads": 8,
+  "n_layers": 36,
+  "rope_base": 10000000,
+  "vocab_size": 49152,
+  "use_hf_tokenizer": true,
+  "tokenizer_prepend_bos": false,
+  "norm_eps": 0.00001,
+  "rope_scaling": null,
+  "attention_bias": true,
+  "feed_forward_bias": true,
+  "tie_word_embeddings": true
+}
\ No newline at end of file