diff --git a/install/requirements.txt b/install/requirements.txt
index 8fb1832ba..457131275 100644
--- a/install/requirements.txt
+++ b/install/requirements.txt
@@ -9,6 +9,10 @@ gguf
# Tiktoken tokenizer for Llama 3 and other advanced models
tiktoken
+# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
+tokenizers
+jinja2
+
# Miscellaneous
snakeviz
sentencepiece
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 000000000..c1580e27b
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,12 @@
+"""
+Global pytest config, fixtures, and helpers go here!
+"""
+
+# Standard
+import os
+import sys
+
+# Make sure tests can import torchchat
+sys.path.append(
+ os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
+)
diff --git a/tests/test_chat_formatters.py b/tests/test_chat_formatters.py
new file mode 100644
index 000000000..2f7f7a955
--- /dev/null
+++ b/tests/test_chat_formatters.py
@@ -0,0 +1,216 @@
+"""
+Unit tests for chat formatters
+"""
+
+# Third Party
+import pytest
+
+# Local
+from torchchat.generate import (
+ HFTokenizerChatFormatter,
+ Llama2ChatFormatter,
+ Llama3ChatFormatter,
+)
+
+## Helpers #####################################################################
+
+class DummyTokenizer:
+ """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
+ def encode(self, text, *_, **__):
+ return text
+
+
+class DummySPTokenizer(DummyTokenizer):
+ """Emulated Sentencepiece tokenizer with bos/eos"""
+ bos = "<s>"
+ eos = "</s>"
+
+
+class DummyLlama3Tokenizer(DummyTokenizer):
+ class _IdentityDict:
+ def __getitem__(self, key):
+ return key
+ special_tokens = _IdentityDict()
+
+
+class DummyHFTokenizer(DummyTokenizer):
+ """Dummy made-up chat template scheme (sentinel token strings are arbitrary)"""
+ # Sequence
+ bos = "<bos>"
+ # Turn
+ bot = "<bot>"
+ eot = "<eot>"
+ # Role
+ bor = "<bor>"
+ eor = "<eor>"
+ def apply_chat_template(self, messages, add_generation_prompt):
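+ # Made-up rendering scheme: each message becomes
+ # "<bot><bor>{role}<eor>{content}<eot>", joined by newlines after the
+ # leading bos, with an open assistant turn appended when generation is
+ # requested and the last message is not from the assistant.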
+ out = [self.bos]
+ role = None
+ for msg in messages:
+ role = msg["role"]
+ content = msg["content"]
+ out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
+ if add_generation_prompt and role != "assistant":
+ out.append(f"{self.bot}{self.bor}assistant{self.eor}")
+ return "\n".join(out)
+
+
+def check_rendering(fmt, messages, expected, add_generation_prompt):
+ """Render messages and compare to expected output"""
+ assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected
+
+
+def make_message(role, text):
+ return {"role": role, "content": text}
+
+
+SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
+USER1 = "Hello world!"
+ASSISTANT1 = "Greetings! How can I help you?"
+USER2 = "Why is the sky blue?"
+ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."
+
+
+# Stock sets of messages to test
+MSGS_NO_SYS = [
+ make_message("user", USER1),
+]
+MSGS_SYS_USR = [
+ make_message("system", SYSTEM_PROMPT),
+ make_message("user", USER1),
+]
+MSGS_SYS_USR_ASST = [
+ make_message("system", SYSTEM_PROMPT),
+ make_message("user", USER1),
+ make_message("assistant", ASSISTANT1),
+]
+MSGS_MULTI_TURN = [
+ make_message("system", SYSTEM_PROMPT),
+ make_message("user", USER1),
+ make_message("assistant", ASSISTANT1),
+ make_message("user", USER2),
+ make_message("assistant", ASSISTANT2),
+]
+
+## Llama2ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+ ["messages", "expected"],
+ [
+ # single user message (no system prompt)
+ (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
+ # sys, usr
+ (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST]"""),
+ # sys, usr, asst
+ (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+"""),
+ # sys, usr, asst, usr, asst
+ (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
+{SYSTEM_PROMPT}
+<</SYS>>
+
+{USER1} [/INST] {ASSISTANT1} </s>
+<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
+"""),
+ ]
+)
+def test_llama2_chat_formatter(messages, expected):
+ """Tests for Llama2 following the official guide
+ https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+ """
+ tok = DummySPTokenizer()
+ fmt = Llama2ChatFormatter(tok)
+ # NOTE: add_generation_prompt not used by Llama2
+ check_rendering(fmt, messages, expected, True)
+
+## Llama3ChatFormatter #########################################################
+
+@pytest.mark.parametrize(
+ ["messages", "expected"],
+ [
+ # single user message (no system prompt)
+ (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+ # sys, usr
+ (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|>"""),
+ # sys, usr, asst
+ (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|>"""),
+ # sys, usr, asst, usr, asst
+ (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+{ASSISTANT2}<|eot_id|>"""),
+ ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
+ """Tests for Llama3 following the official guide
+ https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
+ """
+ tok = DummyLlama3Tokenizer()
+ fmt = Llama3ChatFormatter(tok)
+ # No assistant prompt added if the last message is from the assistant
+ if add_generation_prompt and messages[-1]["role"] != "assistant":
+ expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
+ check_rendering(fmt, messages, expected, add_generation_prompt)
+
+## HFTokenizerChatFormatter ####################################################
+
+@pytest.mark.parametrize(
+ ["messages", "expected"],
+ [
+ # single user message (no system prompt)
+ (MSGS_NO_SYS, f"""<bos>
+<bot><bor>user<eor>{USER1}<eot>"""),
+ # sys, usr
+ (MSGS_SYS_USR, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>"""),
+ # sys, usr, asst
+ (MSGS_SYS_USR_ASST, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
+ # sys, usr, asst, usr, asst
+ (MSGS_MULTI_TURN, f"""<bos>
+<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
+<bot><bor>user<eor>{USER1}<eot>
+<bot><bor>assistant<eor>{ASSISTANT1}<eot>
+<bot><bor>user<eor>{USER2}<eot>
+<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
+ ]
+)
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+def test_hf_chat_formatter(messages, expected, add_generation_prompt):
+ tok = DummyHFTokenizer()
+ fmt = HFTokenizerChatFormatter(tok)
+ # No assistant prompt added if the last message is from the assistant
+ if add_generation_prompt and messages[-1]["role"] != "assistant":
+ expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
+ check_rendering(fmt, messages, expected, add_generation_prompt)
diff --git a/tokenizer/hf_tokenizer.py b/tokenizer/hf_tokenizer.py
index 7ad5807d1..d10ecb076 100644
--- a/tokenizer/hf_tokenizer.py
+++ b/tokenizer/hf_tokenizer.py
@@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.
# Standard
-from typing import List, Optional
+from typing import Dict, List, Optional
import json
import os
# Third Party
+import jinja2
from tokenizers import Tokenizer
# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
# Load the tokenizer itself
self._tokenizer = Tokenizer.from_file(tokenizer_path)
+ # Chat template (loaded below from the tokenizer config, if present)
+ self._chat_template: Optional[jinja2.Template] = None
+
# If available, parse bos/eos tokens from the tokenizer config
self._bos_id, self._eos_id = None, None
if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
self._bos_id = self._tokenizer.token_to_id(bos_token)
if eos_token is not None:
self._eos_id = self._tokenizer.token_to_id(eos_token)
+ if chat_template_str := tok_config.get("chat_template"):
+ self._chat_template = jinja2.Template(chat_template_str)
# If no eos/bos tokens found, go looking for them!
if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optio
if len(candidate_toks) == 1:
return candidate_toks[0]["id"]
+ ## Interface ##
+
def encode(
self,
s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:
def eos_id(self) -> int:
return self._eos_id
+
+ ## Additional Public Methods ##
+
+ def has_chat_template(self) -> bool:
+ return bool(self._chat_template)
+
+ def apply_chat_template(
+ self,
+ dialog: List[Dict[str, str]],
+ add_generation_prompt: bool = False,
+ ) -> str:
+ """If configured with a chat template, apply it to the list of messages
+ """
+ if not self._chat_template:
+ raise ValueError("No chat template configured!")
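+ # Render with the variable names HF chat templates conventionally use:
+ # `messages` (a list of {"role", "content"} dicts) and
+ # `add_generation_prompt`.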
+ return self._chat_template.render(
+ messages=dialog, add_generation_prompt=add_generation_prompt
+ )
diff --git a/torchchat/cli/cli.py b/torchchat/cli/cli.py
index a7f7bbba2..91bdcaf26 100644
--- a/torchchat/cli/cli.py
+++ b/torchchat/cli/cli.py
@@ -17,7 +17,15 @@
allowable_params_table,
)
-logging.basicConfig(level=logging.INFO, format="%(message)s")
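+# Honor the LOG_LEVEL environment variable (e.g. LOG_LEVEL=DEBUG), falling
+# back to INFO when it is unset or not a valid logging level name.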
+_log_level_env = os.getenv("LOG_LEVEL", "INFO")
+try:
+ _log_level = getattr(logging, _log_level_env.upper())
+except AttributeError:
+ print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
+ _log_level = logging.INFO
+
+
+logging.basicConfig(level=_log_level, format="%(message)s")
logger = logging.getLogger(__name__)
default_device = os.getenv("TORCHCHAT_DEVICE", "fast")
diff --git a/torchchat/generate.py b/torchchat/generate.py
index 9b4c6430a..4d2439d2f 100644
--- a/torchchat/generate.py
+++ b/torchchat/generate.py
@@ -45,13 +45,52 @@
from torchchat.utils.device_info import get_device_info
+# NOTE: Logging disabled by default here due to conflicts with torch._dynamo
+class NoOpLogger:
+ def __no_op(self, *_, **__):
+ pass
+ def __getattr__(self, name):
+ return self.__no_op
+
+
+logger = (
+ NoOpLogger() if os.getenv("LOG_LEVEL") is None
+ else logging.getLogger(__name__)
+)
+
+## Chat Formatters #############################################################
+
class _ChatFormatter(ABC):
+
+ # Messages can arrive as a standard dict with "role" and "content" as
+ # strings, or where "content" is a list of objects with "text" fields.
+ MESSAGE_TYPE = Dict[str, Union[str, List[Dict[str, str]]]]
+
+ # A dialog is a sequence of messages
+ DIALOG_TYPE = List[MESSAGE_TYPE]
+
def __init__(self, tokenizer):
self.tokenizer = tokenizer
@abstractmethod
- def encode_dialog_prompt(self, dialog) -> List[int]:
- raise NotImplementedError()
+ def encode_dialog_prompt(
+ self,
+ dialog: DIALOG_TYPE,
+ add_generation_prompt: bool,
+ ) -> List[int]:
+ """Encode a sequence of messages into a sequence of token IDs, including
+ the chat template
+
+ Args:
+ dialog (DIALOG_TYPE): The sequence of dialog messages to encode.
+ This will be the additional messages on top of those that have
+ already been processed.
+ add_generation_prompt (bool): Whether to include a generation prompt
+ at the end of the encoded sequence.
+
+ Returns:
+ List[int]: A list of token IDs representing the encoded prompt.
+ """
class Llama3ChatFormatter(_ChatFormatter):
@@ -61,7 +100,7 @@ class Llama3ChatFormatter(_ChatFormatter):
"""
- def encode_header(self, role) -> List[int]:
+ def _encode_header(self, role) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode(role, bos=False, eos=False))
@@ -69,8 +108,8 @@ def encode_header(self, role) -> List[int]:
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
- def encode_message(self, message) -> List[int]:
- tokens = self.encode_header(message["role"])
+ def _encode_message(self, message: _ChatFormatter.MESSAGE_TYPE) -> List[int]:
+ tokens = self._encode_header(message["role"])
if isinstance(message["content"], str):
tokens.extend(
self.tokenizer.encode(message["content"], bos=False, eos=False)
@@ -85,46 +124,80 @@ def encode_message(self, message) -> List[int]:
tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
return tokens
- def encode_dialog_prompt(self, dialog) -> List[int]:
+ def encode_dialog_prompt(
+ self,
+ dialog: _ChatFormatter.DIALOG_TYPE,
+ add_generation_prompt: bool,
+ ) -> List[int]:
tokens = []
tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"])
for message in dialog:
- tokens.extend(self.encode_message(message))
+ tokens.extend(self._encode_message(message))
# Add the start of an assistant message for the model to complete.
- tokens.extend(self.encode_header("assistant")) # Pass role directly as a string
+ if add_generation_prompt and dialog and dialog[-1]["role"] != "assistant":
+ tokens.extend(self._encode_header("assistant")) # Pass role directly as a string
return tokens
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
+class Llama2ChatFormatter(_ChatFormatter):
+ """
+ Chat formatting for Llama2
+ CITE: https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
+ """
+
+ B_INST, E_INST = "[INST] ", " [/INST]"
+ B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+ @staticmethod
+ def _get_content_str(message: _ChatFormatter.MESSAGE_TYPE) -> str:
+ if isinstance(message["content"], list):
+ return message["content"][0]["text"]
+ return message["content"]
-class Llama2ChatFormatter(_ChatFormatter):
- def encode_dialog_prompt(self, dialog) -> List[int]:
- tokens = self.tokenizer.encode(f"{B_INST} ")
- first_message = True # Bool to handle placing the B_INST token. Behavior is weird - the system prompt should have the B_INST, but not the first user message. All following user messages *should* have it. Also, if there is no system prompt, then the user message should have it.
+ def encode_dialog_prompt(
+ self,
+ dialog: _ChatFormatter.DIALOG_TYPE,
+ add_generation_prompt: bool, # UNUSED
+ ) -> List[int]:
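+ # Llama2 groups system + user (+ assistant reply) into a single [INST]
+ # turn: each new turn opens with bos and B_INST, and an assistant reply
+ # closes it with eos so the next user message starts a fresh turn.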
+ new_turn = True
+ tokens = []
for message in dialog:
- if isinstance(message["content"], list):
- content = message["content"][0]["text"]
+ if new_turn:
+ tokens += self.tokenizer.encode(f"{self.tokenizer.bos}{self.B_INST}")
+ content = self._get_content_str(message).strip()
+ role = message["role"]
+ if role == "system":
+ tokens += self.tokenizer.encode(f"{self.B_SYS}{content}{self.E_SYS}")
+ new_turn = False
+ elif role == "user":
+ tokens += self.tokenizer.encode(f"{content}{self.E_INST}")
+ new_turn = False
+ elif role == "assistant":
+ tokens += self.tokenizer.encode(f" {content} {self.tokenizer.eos}\n")
+ new_turn = True
else:
- content = message["content"]
- content = content.strip()
- if message["role"] == "system":
- encoded = self.tokenizer.encode(f"{B_SYS}\n{content}\n{E_SYS}")
- first_message = False
- elif message["role"] == "user":
- encoded = [self.tokenizer.bos_id()] + self.tokenizer.encode(
- f"{B_INST if first_message else ''} {content} {E_INST} "
- )
- first_message = True
- elif message["role"] == "assistant":
- encoded = self.tokenizer.encode(f"{content}\n\n") + [
- self.tokenizer.eos_id()
- ]
- tokens += encoded
+ raise ValueError("Invalid role in dialog.")
return tokens
+
+class HFTokenizerChatFormatter(_ChatFormatter):
+ """Chat formatter that uses the built-in formatting capabilities of an HF
+ tokenizer instance
+ """
+ def encode_dialog_prompt(
+ self,
+ dialog: _ChatFormatter.DIALOG_TYPE,
+ add_generation_prompt: bool,
+ ) -> List[int]:
+ rendered = self.tokenizer.apply_chat_template(
+ dialog, add_generation_prompt=add_generation_prompt
+ )
+ logger.debug("Formatted chat prompt:\n%s", rendered)
+ return self.tokenizer.encode(rendered)
+
+## Generation ##################################################################
+
@dataclass
class GeneratorArgs:
prompt: Optional[str] = (
@@ -283,9 +356,13 @@ def __init__(
if self.is_llama3_model:
self.chat_formatter = Llama3ChatFormatter(self.tokenizer)
if generator_args.chat_mode:
- logging.debug(
+ logger.debug(
"Llama3 model detected in chat mode. Using updated sentence schemas"
)
+ elif self.tokenizer_args.is_hf_tokenizer:
+ if not self.tokenizer.has_chat_template():
+ raise ValueError("Tokenizer must have a chat template")
+ self.chat_formatter = HFTokenizerChatFormatter(self.tokenizer)
else:
self.chat_formatter = Llama2ChatFormatter(self.tokenizer)
@@ -341,10 +418,12 @@ def sample(
temperature: float = 0,
top_k: Optional[int] = None,
):
+ logits = logits[0, -1]
+ logger.debug("Logits: %s", logits)
if temperature == 0 and not need_probs:
- _, idx_next = torch.topk(logits[0, -1], k=1, dim=-1)
+ _, idx_next = torch.topk(logits, k=1, dim=-1)
return (idx_next, None)
- probs = self.logits_to_probs(logits[0, -1], temperature, top_k)
+ probs = self.logits_to_probs(logits, temperature, top_k)
idx_next = self.multinomial_sample_one_no_sync(probs)
return idx_next, probs
@@ -358,7 +437,7 @@ def prefill(
sequential_prefill=True,
**sampling_kwargs,
) -> torch.Tensor:
- # logging.debug(f"x: {x}, input_pos: {input_pos}")
+ logger.debug("x: %s, input_pos: %s", x, input_pos)
width = x.size(1)
assert input_pos.size(0) == width
@@ -394,7 +473,7 @@ def prefill(
elif sequential_prefill:
for i in range(width):
x_sliced, ip_sliced = x[:, i].view(-1, 1), input_pos[i].view(-1)
- # logging.debug(f" x: {x_sliced}, input_pos: {ip_sliced}")
+ logger.debug(" x: %s, input_pos: %s", x_sliced, ip_sliced)
logits = model(x_sliced, ip_sliced) # (x[:, i], input_pos[i])da
else:
# input_pos: [B, S]
@@ -727,7 +806,8 @@ def encode_tokens(self, string, bos=True, device="cpu"):
tokens = self.tokenizer.encode(string)
if bos:
tokens = [self.tokenizer.bos_id()] + tokens
- logging.debug(f"Size after encode_tokens: {len(tokens)}")
+ logger.debug("Size after encode_tokens: %d", len(tokens))
+ logger.debug("Token IDs: %s", tokens)
return torch.tensor(tokens, dtype=torch.int, device=device)
def _callback(self, x, *, buffer, done_generating):
@@ -776,7 +856,7 @@ def _gen_model_input(
# Single String prompt
if isinstance(prompt, str):
encoded = self.encode_tokens(
- prompt, bos=True, device=self.builder_args.device
+ prompt, bos=self.model.config.tokenizer_prepend_bos, device=self.builder_args.device
)
# List of dialog
else:
@@ -785,7 +865,7 @@ def _gen_model_input(
tokens, dtype=torch.int, device=self.builder_args.device
)
- logging.debug(encoded)
+ logger.debug(encoded)
return encoded, None
# Llama 3.2 11B
@@ -900,7 +980,7 @@ def _gen_model_input(
value=0,
)
- logging.debug(encoded)
+ logger.debug(encoded)
return encoded, batch
def chat(
@@ -1021,38 +1101,21 @@ def chat(
if prompt == "/bye":
print("Exiting Chat.\n")
break
- if not self.is_llama3_model:
- if self.system_prompt:
- prompt = f"{B_INST} {B_SYS}\n{self.system_prompt.strip()}\n{E_SYS}\n\n{prompt.strip()} {E_INST}"
- self.system_prompt = (
- None # can only provide system prompt on first interaction
- )
- else:
- prompt = f"{B_INST} {prompt.strip()} {E_INST}"
- encoded = self.encode_tokens(
- prompt, bos=True, device=self.builder_args.device
- )
- else:
- if self.system_prompt:
- encoded = self.chat_formatter.encode_dialog_prompt(
- [
- {"role": "system", "content": self.system_prompt},
- {"role": "user", "content": prompt},
- ]
- )
- self.system_prompt = None
- elif is_first_sample:
- encoded = self.chat_formatter.encode_dialog_prompt(
- [{"role": "user", "content": prompt}]
- )
- else:
- encoded = self.chat_formatter.encode_message(
- {"role": "user", "content": prompt}
- )
- encoded.extend(self.chat_formatter.encode_header("assistant"))
- encoded = torch.tensor(
- encoded, dtype=torch.int, device=self.builder_args.device
+
+ # Encode the additional messages added in this dialog turn. If
+ # this is the first turn, that includes any system prompt.
+ messages_to_encode = []
+ if is_first_sample and self.system_prompt:
+ messages_to_encode.append(
+ {"role": "system", "content": self.system_prompt}
)
+ messages_to_encode.append({"role": "user", "content": prompt})
+ encoded = self.chat_formatter.encode_dialog_prompt(
+ messages_to_encode, add_generation_prompt=True,
+ )
+ encoded = torch.tensor(
+ encoded, dtype=torch.int, device=self.builder_args.device
+ )
if encoded.size(0) + start_pos > max_seq_length:
print(
"This prompt would take us past the max_seq_length. Ending Conversation."
@@ -1231,6 +1294,7 @@ def main(args):
speculative_builder_args = BuilderArgs.from_speculative_args(args)
tokenizer_args = TokenizerArgs.from_args(args)
generator_args = GeneratorArgs.from_args(args)
+ logger.debug("GeneratorArgs: %s", generator_args)
if not builder_args.distributed:
gen = Generator(
builder_args,
diff --git a/torchchat/model.py b/torchchat/model.py
index 2a3b9f12f..1c78d4c63 100644
--- a/torchchat/model.py
+++ b/torchchat/model.py
@@ -4,6 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
+import logging
import os
import warnings
from abc import ABC, abstractmethod
@@ -48,6 +49,8 @@
config_path = Path(f"{str(Path(__file__).parent)}/model_params")
+logger = logging.getLogger(__name__)
+
class QuickGELUActivation(nn.Module):
"""
@@ -273,6 +276,7 @@ class TransformerArgs:
# Select the desired tokenizer. Defaults to sentencepiece
use_tiktoken: bool = False
use_hf_tokenizer: bool = False
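+ # Whether a BOS token should be prepended when encoding raw prompt strings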
+ tokenizer_prepend_bos: bool = True
max_seq_length: int = 8192
rope_scaling: Optional[Dict[str, Any]] = None
# For pipeline parallel
@@ -330,6 +334,7 @@ class ModelArgs:
transformer_args: Dict[str, Dict[str, Any]]
use_tiktoken: bool
use_hf_tokenizer: bool
+ tokenizer_prepend_bos: bool
def __init__(
self,
@@ -337,6 +342,7 @@ def __init__(
model_type: ModelType = ModelType.TextOnly,
use_tiktoken: bool = False,
use_hf_tokenizer: bool = False,
+ tokenizer_prepend_bos: bool = True,
) -> None:
self._sanity_check(transformer_args, model_type)
@@ -346,6 +352,7 @@ def __init__(
# Model-level attributes
self.use_tiktoken = use_tiktoken
self.use_hf_tokenizer = use_hf_tokenizer
+ self.tokenizer_prepend_bos = tokenizer_prepend_bos
def _sanity_check(
self,
@@ -373,7 +380,14 @@ def from_params(cls, params_path):
use_tiktoken = loaded_params.get("use_tiktoken", False)
use_hf_tokenizer = loaded_params.get("use_hf_tokenizer", False)
- return cls(transformer_args, model_type, use_tiktoken, use_hf_tokenizer)
+ tokenizer_prepend_bos = loaded_params.get("tokenizer_prepend_bos", True)
+ return cls(
+ transformer_args=transformer_args,
+ model_type=model_type,
+ use_tiktoken=use_tiktoken,
+ use_hf_tokenizer=use_hf_tokenizer,
+ tokenizer_prepend_bos=tokenizer_prepend_bos,
+ )
@classmethod
def from_table(cls, name: str):
@@ -477,7 +491,9 @@ def build_model(self) -> nn.Module:
for name, module_class in recipe.modules.items():
config_args = self.config.transformer_args[name]
if module_class == Transformer:
- modules[name] = module_class(TransformerArgs.from_params(config_args))
+ transformer_args = TransformerArgs.from_params(config_args)
+ logger.debug("Transformer Args: %s", transformer_args)
+ modules[name] = module_class(transformer_args)
else:
modules[name] = module_class(**config_args)
diff --git a/torchchat/model_config/models.json b/torchchat/model_config/models.json
index 2d3dfcbeb..8791601fb 100644
--- a/torchchat/model_config/models.json
+++ b/torchchat/model_config/models.json
@@ -164,5 +164,19 @@
"https://github.com/karpathy/llama2.c/raw/master/tokenizer.model"
],
"checkpoint_file": "stories110M.pt"
+ },
+ "ibm-granite/granite-3b-code-instruct-128k": {
+ "aliases": ["granite-code", "granite-code-3b"],
+ "distribution_channel": "HuggingFaceSnapshot",
+ "distribution_path": "ibm-granite/granite-3b-code-instruct-128k",
+ "transformer_params_key": "Granite-3B-Code",
+ "tokenizer_file": "tokenizer.json"
+ },
+ "ibm-granite/granite-8b-code-instruct-128k": {
+ "aliases": ["granite-code-8b"],
+ "distribution_channel": "HuggingFaceSnapshot",
+ "distribution_path": "ibm-granite/granite-8b-code-instruct-128k",
+ "transformer_params_key": "Granite-8B-Code",
+ "tokenizer_file": "tokenizer.json"
}
}
diff --git a/torchchat/model_params/Granite-3B-Code.json b/torchchat/model_params/Granite-3B-Code.json
new file mode 100644
index 000000000..0654a8f2c
--- /dev/null
+++ b/torchchat/model_params/Granite-3B-Code.json
@@ -0,0 +1,17 @@
+{
+ "block_size": 128000,
+ "dim": 2560,
+ "hidden_dim": 10240,
+ "n_heads": 32,
+ "n_local_heads": 32,
+ "n_layers": 32,
+ "rope_base": 10000000,
+ "vocab_size": 49152,
+ "use_hf_tokenizer": true,
+ "tokenizer_prepend_bos": false,
+ "norm_eps": 0.00001,
+ "rope_scaling": null,
+ "attention_bias": true,
+ "feed_forward_bias": true,
+ "tie_word_embeddings": true
+}
\ No newline at end of file
diff --git a/torchchat/model_params/Granite-8B-Code.json b/torchchat/model_params/Granite-8B-Code.json
new file mode 100644
index 000000000..079a32070
--- /dev/null
+++ b/torchchat/model_params/Granite-8B-Code.json
@@ -0,0 +1,17 @@
+{
+ "block_size": 128000,
+ "dim": 4096,
+ "hidden_dim": 14336,
+ "n_heads": 32,
+ "n_local_heads": 8,
+ "n_layers": 36,
+ "rope_base": 10000000,
+ "vocab_size": 49152,
+ "use_hf_tokenizer": true,
+ "tokenizer_prepend_bos": false,
+ "norm_eps": 0.00001,
+ "rope_scaling": null,
+ "attention_bias": true,
+ "feed_forward_bias": true,
+ "tie_word_embeddings": true
+}
\ No newline at end of file