4 changes: 4 additions & 0 deletions install/requirements.txt
@@ -9,6 +9,10 @@ gguf
# Tiktoken tokenizer for Llama 3 and other advanced models
tiktoken

# Tokenizers and jinja2 for other non-llama models that use HF tokenizers
Contributor Author:

I added these here, but did not add pytest (yet). I think there's a pending conversation about introducing optional dependency sets, so it would make sense to add a test or dev set at that point, but I didn't want to accidentally carry pytest along as a runtime dependency.
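
As a rough sketch, an optional test/dev set could eventually look something like this in a pyproject.toml (hypothetical; the group name and contents are illustrative only, not part of this change):

# Hypothetical optional dependency group; not part of this change
[project.optional-dependencies]
dev = [
    "pytest",
]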

tokenizers
jinja2

# Miscellaneous
snakeviz
sentencepiece
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
"""
Global pytest config, fixtures, and helpers go here!
"""

# Standard
import os
import sys

# Make sure tests can import torchchat
Contributor Author:


This would be a lot cleaner if we move to having a pyproject.toml or setup.py to bundle torchchat as a package that could be installed with pip install -e.
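
A minimal sketch of that idea, with placeholder metadata (nothing like this exists in the repo yet):

# Hypothetical pyproject.toml; name and version are placeholders
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "torchchat"
version = "0.0.0"

With something along those lines, pip install -e . would let tests import torchchat without the sys.path manipulation below.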

Contributor:

on the list

sys.path.append(
    os.path.realpath(os.path.join(os.path.dirname(__file__), ".."))
)
216 changes: 216 additions & 0 deletions tests/test_chat_formatters.py
@@ -0,0 +1,216 @@
"""
Unit tests for chat formatters
"""

# Third Party
import pytest

# Local
from torchchat.generate import (
    HFTokenizerChatFormatter,
    Llama2ChatFormatter,
    Llama3ChatFormatter,
)

## Helpers #####################################################################

class DummyTokenizer:
    """Dummy tokenizer that encodes as strings so it's easy to check formatting"""
    def encode(self, text, *_, **__):
        return text


class DummySPTokenizer(DummyTokenizer):
    """Emulated Sentencepiece tokenizer with bos/eos"""
    bos = "<s>"
    eos = "</s>"


class DummyLlama3Tokenizer(DummyTokenizer):
    class _IdentityDict:
        def __getitem__(self, key):
            return key
    special_tokens = _IdentityDict()


class DummyHFTokenizer(DummyTokenizer):
    """Dummy made up chat template scheme"""
    # Sequence
    bos = "<bos>"
    # Turn
    bot = "<bot>"
    eot = "<eot>"
    # Role
    bor = "<bor>"
    eor = "<eor>"
    def apply_chat_template(self, messages, add_generation_prompt):
        out = [self.bos]
        role = None
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}")
        if add_generation_prompt and role != "assistant":
            out.append(f"{self.bot}{self.bor}assistant{self.eor}")
        return "\n".join(out)
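    # For reference, a single user turn rendered with add_generation_prompt=True
    # looks like:
    #   <bos>
    #   <bot><bor>user<eor>{user message}<eot>
    #   <bot><bor>assistant<eor>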


def check_rendering(fmt, messages, expected, add_generation_prompt):
    """Render messages and compare to expected output"""
    assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected


def make_message(role, text):
    return {"role": role, "content": text}


SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything."
USER1 = "Hello world!"
ASSISTANT1 = "Greetings! How can I help you?"
USER2 = "Why is the sky blue?"
ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering."


# Stock sets of messages to test
MSGS_NO_SYS = [
    make_message("user", USER1),
]
MSGS_SYS_USR = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
]
MSGS_SYS_USR_ASST = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
]
MSGS_MULTI_TURN = [
    make_message("system", SYSTEM_PROMPT),
    make_message("user", USER1),
    make_message("assistant", ASSISTANT1),
    make_message("user", USER2),
    make_message("assistant", ASSISTANT2),
]

## Llama2ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"),
        # sys, usr
        (MSGS_SYS_USR, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST]"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>>
{SYSTEM_PROMPT}
<</SYS>>

{USER1} [/INST] {ASSISTANT1} </s>
<s>[INST] {USER2} [/INST] {ASSISTANT2} </s>
"""),
    ]
)
def test_llama2_chat_formatter(messages, expected):
    """Tests for Llama2 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/
    """
    tok = DummySPTokenizer()
    fmt = Llama2ChatFormatter(tok)
    # NOTE: add_generation_prompt not used by Llama2
    check_rendering(fmt, messages, expected, True)

## Llama3ChatFormatter #########################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|>

{USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

{ASSISTANT2}<|eot_id|>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_llama3_chat_formatter(messages, expected, add_generation_prompt):
    """Tests for Llama3 following the official guide
    https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
    """
    tok = DummyLlama3Tokenizer()
    fmt = Llama3ChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    check_rendering(fmt, messages, expected, add_generation_prompt)

## HFTokenizerChatFormatter ####################################################

@pytest.mark.parametrize(
    ["messages", "expected"],
    [
        # single user message (no system prompt)
        (MSGS_NO_SYS, f"""<bos>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr
        (MSGS_SYS_USR, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>"""),
        # sys, usr, asst
        (MSGS_SYS_USR_ASST, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>"""),
        # sys, usr, asst, usr, asst
        (MSGS_MULTI_TURN, f"""<bos>
<bot><bor>system<eor>{SYSTEM_PROMPT}<eot>
<bot><bor>user<eor>{USER1}<eot>
<bot><bor>assistant<eor>{ASSISTANT1}<eot>
<bot><bor>user<eor>{USER2}<eot>
<bot><bor>assistant<eor>{ASSISTANT2}<eot>"""),
    ]
)
@pytest.mark.parametrize("add_generation_prompt", [True, False])
def test_hf_chat_formatter(messages, expected, add_generation_prompt):
    tok = DummyHFTokenizer()
    fmt = HFTokenizerChatFormatter(tok)
    # No assistant prompt added if the last message is from the assistant
    if add_generation_prompt and messages[-1]["role"] != "assistant":
        expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}"
    check_rendering(fmt, messages, expected, add_generation_prompt)
28 changes: 27 additions & 1 deletion tokenizer/hf_tokenizer.py
@@ -5,11 +5,12 @@
# LICENSE file in the root directory of this source tree.

# Standard
from typing import List, Optional
from typing import Dict, List, Optional
import json
import os

# Third Party
import jinja2
from tokenizers import Tokenizer

# Local
@@ -37,6 +38,9 @@ def __init__(self, file_path: str):
        # Load the tokenizer itself
        self._tokenizer = Tokenizer.from_file(tokenizer_path)

        # Load the chat template if we have a config path
        self._chat_template: Optional[jinja2.Template] = None

        # If available, parse bos/eos tokens from the tokenizer config
        self._bos_id, self._eos_id = None, None
        if tokenizer_config_path is not None:
@@ -48,6 +52,8 @@ def __init__(self, file_path: str):
                self._bos_id = self._tokenizer.token_to_id(bos_token)
            if eos_token is not None:
                self._eos_id = self._tokenizer.token_to_id(eos_token)
            if chat_template_str := tok_config.get("chat_template"):
                self._chat_template = jinja2.Template(chat_template_str)

        # If no eos/bos tokens found, go looking for them!
        if None in [self._bos_id, self._eos_id]:
@@ -70,6 +76,8 @@ def _look_for_special_token(added_tokens: dict, search_strs: List[str]) -> Optional[int]:
        if len(candidate_toks) == 1:
            return candidate_toks[0]["id"]

    ## Interface ##

    def encode(
        self,
        s: str,
@@ -90,3 +98,21 @@ def bos_id(self) -> int:

    def eos_id(self) -> int:
        return self._eos_id

    ## Additional Public Methods ##

    def has_chat_template(self) -> bool:
        return bool(self._chat_template)

    def apply_chat_template(
        self,
        dialog: List[Dict[str, str]],
        add_generation_prompt: bool = False,
    ) -> str:
        """If configured with a chat template, apply it to the list of messages
        """
        if not self._chat_template:
            raise ValueError("No chat template configured!")
        return self._chat_template.render(
            messages=dialog, add_generation_prompt=add_generation_prompt
        )
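
For illustration, a minimal usage sketch of the new chat-template path. The class name (HFTokenizer), import path, and checkpoint location are assumptions for the sketch, not something this diff establishes; it also assumes the constructor takes a directory containing tokenizer.json and tokenizer_config.json.

# Hypothetical usage; names and paths are assumptions, not part of this PR
from tokenizer.hf_tokenizer import HFTokenizer

tok = HFTokenizer("/path/to/checkpoint_dir")
dialog = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello world!"},
]
if tok.has_chat_template():
    # Render the dialog with the template parsed from tokenizer_config.json,
    # leaving the prompt open for the assistant's reply
    prompt = tok.apply_chat_template(dialog, add_generation_prompt=True)
else:
    # No template configured; fall back to whatever formatting the caller prefers
    prompt = "\n".join(msg["content"] for msg in dialog)
tokens = tok.encode(prompt)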
10 changes: 9 additions & 1 deletion torchchat/cli/cli.py
@@ -17,7 +17,15 @@
    allowable_params_table,
)

logging.basicConfig(level=logging.INFO, format="%(message)s")
_log_level_env = os.getenv("LOG_LEVEL", "INFO")
try:
    _log_level = getattr(logging, _log_level_env.upper())
except AttributeError:
    print(f"Invalid log level: {_log_level_env}", file=sys.stderr)
    _log_level = logging.INFO


logging.basicConfig(level=_log_level, format="%(message)s")
logger = logging.getLogger(__name__)

default_device = os.getenv("TORCHCHAT_DEVICE", "fast")