This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Granite code support #1336
Merged
Changes from all commits (17):

- e3c2849 feat(models): Add models.json blocks for Granite Code 3b and 8b (gabe-l-hart)
- 9d80e52 feat: Initial model params for granite code 3b (gabe-l-hart)
- 85057bc fix(model config): Fix model configs for Granite Code (gabe-l-hart)
- 1e3addc feat(granite): Add model params for granite-code-8b (gabe-l-hart)
- 5d342fa fix(deps): Add tokenizers to the deps explicitly (gabe-l-hart)
- dadff61 feat(tokenizer): Add basic support for jinja2 template rendering for … (gabe-l-hart)
- 2017e5d fix(chat): Add HFTokenizerChatFormatter and use it for HF tokenizers (gabe-l-hart)
- 38a649a fix(deps): Add jinja2 as an explicit dep (gabe-l-hart)
- 0b4f159 feat(log): Add env-based LOG_LEVEL config to CLI (gabe-l-hart)
- 526ce15 feat(log): Add better logging in model and generate (gabe-l-hart)
- c9f8a71 feat(generate): Make prepending BOS model-configurable (gabe-l-hart)
- ef132cb fix(chat): Refactor chat template logic to encapsulate all formatting… (gabe-l-hart)
- 8d26923 fix(chat): Fix small formatting bugs in llama3 chat formatter (gabe-l-hart)
- 0390655 test: Add initial unit tests for chat formatters (gabe-l-hart)
- 2d7e546 fix(logging): Disable logging in generate unless set in the env (gabe-l-hart)
- 78a3637 fix: Remove trailing \n from llama3 <|eot_id|> (gabe-l-hart)
- 6fb3b98 Merge branch 'main' into GraniteCodeSupport (Jack-Khuu)
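Commit 0b4f159 adds env-based LOG_LEVEL configuration to the CLI. That implementation isn't visible in this page, but the usual pattern looks roughly like the following sketch (an assumption for illustration, not the PR's actual code):

```python
# Illustrative only: a common env-driven log-level pattern. The PR's actual
# implementation is not shown in this diff.
import logging
import os

# Read LOG_LEVEL from the environment, defaulting to WARNING when unset.
logging.basicConfig(level=os.getenv("LOG_LEVEL", "WARNING").upper())
```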
The PR adds two new files. First, a global pytest config (@@ -0,0 +1,12 @@):
| """ | ||
| Global pytest config, fixtures, and helpers go here! | ||
| """ | ||
|
|
||
| # Standard | ||
| import os | ||
| import sys | ||
|
|
||
| # Make sure tests can import torchchat | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This would be a lot cleaner if we move to having a
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. on the list |
||
| sys.path.append( | ||
| os.path.realpath(os.path.join(os.path.dirname(__file__), "..")) | ||
| ) | ||
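As a usage sketch: with the repo root on `sys.path`, the test file below can import `torchchat.generate` directly, and the suite runs without installing the package. The file path and invocation here are assumptions for illustration, since the page does not show the actual file names:

```python
# Hypothetical runner sketch. The path "tests/test_chat_formatters.py" is an
# assumption; the scraped diff does not show the real file name.
import sys

import pytest

if __name__ == "__main__":
    # Equivalent to `python -m pytest -v tests/test_chat_formatters.py`;
    # pytest.main returns an exit code suitable for sys.exit.
    sys.exit(pytest.main(["-v", "tests/test_chat_formatters.py"]))
```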
Second, the unit tests for the chat formatters (@@ -0,0 +1,216 @@):
| """ | ||
| Unit tests for chat formatters | ||
| """ | ||
|
|
||
| # Third Party | ||
| import pytest | ||
|
|
||
| # Local | ||
| from torchchat.generate import ( | ||
| HFTokenizerChatFormatter, | ||
| Llama2ChatFormatter, | ||
| Llama3ChatFormatter, | ||
| ) | ||
|
|
||
| ## Helpers ##################################################################### | ||
|
|
||
| class DummyTokenizer: | ||
| """Dummy tokenizer that encodes as strings so it's easy to check formatting""" | ||
| def encode(self, text, *_, **__): | ||
| return text | ||
|
|
||
|
|
||
| class DummySPTokenizer(DummyTokenizer): | ||
| """Emulated Sentencepiece tokenizer with bos/eos""" | ||
| bos = "<s>" | ||
| eos = "</s>" | ||
|
|
||
|
|
||
| class DummyLlama3Tokenizer(DummyTokenizer): | ||
| class _IdentityDict: | ||
| def __getitem__(self, key): | ||
| return key | ||
| special_tokens = _IdentityDict() | ||
|
|
||
|
|
||
| class DummyHFTokenizer(DummyTokenizer): | ||
| """Dummy made up chat template scheme""" | ||
| # Sequence | ||
| bos = "<bos>" | ||
| # Turn | ||
| bot = "<bot>" | ||
| eot = "<eot>" | ||
| # Role | ||
| bor = "<bor>" | ||
| eor = "<eor>" | ||
| def apply_chat_template(self, messages, add_generation_prompt): | ||
| out = [self.bos] | ||
| role = None | ||
| for msg in messages: | ||
| role = msg["role"] | ||
| content = msg["content"] | ||
| out.append(f"{self.bot}{self.bor}{role}{self.eor}{content}{self.eot}") | ||
| if add_generation_prompt and role != "assistant": | ||
| out.append(f"{self.bot}{self.bor}assistant{self.eor}") | ||
| return "\n".join(out) | ||
|
|
||
|
|
||
| def check_rendering(fmt, messages, expected, add_generation_prompt): | ||
| """Render messages and compare to expected output""" | ||
| assert "".join(fmt.encode_dialog_prompt(messages, add_generation_prompt)) == expected | ||
|
|
||
|
|
||
| def make_message(role, text): | ||
| return {"role": role, "content": text} | ||
|
|
||
|
|
||
| SYSTEM_PROMPT = "You are a helpful assistant, feel free to ask me anything." | ||
| USER1 = "Hello world!" | ||
| ASSISTANT1 = "Greetings! How can I help you?" | ||
| USER2 = "Why is the sky blue?" | ||
| ASSISTANT2 = "The sky appears blue because of a phenomenon called Rayleigh scattering." | ||
|
|
||
|
|
||
| # Stock sets of messages to test | ||
| MSGS_NO_SYS= [ | ||
| make_message("user", USER1), | ||
| ] | ||
| MSGS_SYS_USR = [ | ||
| make_message("system", SYSTEM_PROMPT), | ||
| make_message("user", USER1), | ||
| ] | ||
| MSGS_SYS_USR_ASST = [ | ||
| make_message("system", SYSTEM_PROMPT), | ||
| make_message("user", USER1), | ||
| make_message("assistant", ASSISTANT1), | ||
| ] | ||
| MSGS_MULTI_TURN = [ | ||
| make_message("system", SYSTEM_PROMPT), | ||
| make_message("user", USER1), | ||
| make_message("assistant", ASSISTANT1), | ||
| make_message("user", USER2), | ||
| make_message("assistant", ASSISTANT2), | ||
| ] | ||
|
|
||
| ## Llama2ChatFormatter ######################################################### | ||
|
|
||
| @pytest.mark.parametrize( | ||
| ["messages", "expected"], | ||
| [ | ||
| # single user message (no system prompt) | ||
| (MSGS_NO_SYS, f"<s>[INST] {USER1} [/INST]"), | ||
| # sys, usr | ||
| (MSGS_SYS_USR, f"""<s>[INST] <<SYS>> | ||
| {SYSTEM_PROMPT} | ||
| <</SYS>> | ||
|
|
||
| {USER1} [/INST]"""), | ||
| # sys, usr, asst | ||
| (MSGS_SYS_USR_ASST, f"""<s>[INST] <<SYS>> | ||
| {SYSTEM_PROMPT} | ||
| <</SYS>> | ||
|
|
||
| {USER1} [/INST] {ASSISTANT1} </s> | ||
| """), | ||
| # sys, usr, asst, usr, asst | ||
| (MSGS_MULTI_TURN, f"""<s>[INST] <<SYS>> | ||
| {SYSTEM_PROMPT} | ||
| <</SYS>> | ||
|
|
||
| {USER1} [/INST] {ASSISTANT1} </s> | ||
| <s>[INST] {USER2} [/INST] {ASSISTANT2} </s> | ||
| """), | ||
| ] | ||
| ) | ||
| def test_llama2_chat_formatter(messages, expected): | ||
| """Tests for Llama2 following the official guide | ||
| https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-2/ | ||
| """ | ||
| tok = DummySPTokenizer() | ||
| fmt = Llama2ChatFormatter(tok) | ||
| # NOTE: add_generation_prompt not used by Llama2 | ||
| check_rendering(fmt, messages, expected, True) | ||
|
|
||
| ## Llama3ChatFormatter ######################################################### | ||
|
|
||
| @pytest.mark.parametrize( | ||
| ["messages", "expected"], | ||
| [ | ||
| # single user message (no system prompt) | ||
| (MSGS_NO_SYS, f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|> | ||
|
|
||
| {USER1}<|eot_id|>"""), | ||
| # sys, usr | ||
| (MSGS_SYS_USR, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> | ||
|
|
||
| {SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> | ||
|
|
||
| {USER1}<|eot_id|>"""), | ||
| # sys, usr, asst | ||
| (MSGS_SYS_USR_ASST, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> | ||
|
|
||
| {SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> | ||
|
|
||
| {USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||
|
|
||
| {ASSISTANT1}<|eot_id|>"""), | ||
| # sys, usr, asst, usr, asst | ||
| (MSGS_MULTI_TURN, f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|> | ||
|
|
||
| {SYSTEM_PROMPT}<|eot_id|><|start_header_id|>user<|end_header_id|> | ||
|
|
||
| {USER1}<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||
|
|
||
| {ASSISTANT1}<|eot_id|><|start_header_id|>user<|end_header_id|> | ||
|
|
||
| {USER2}<|eot_id|><|start_header_id|>assistant<|end_header_id|> | ||
|
|
||
| {ASSISTANT2}<|eot_id|>"""), | ||
| ] | ||
| ) | ||
| @pytest.mark.parametrize("add_generation_prompt", [True, False]) | ||
| def test_llama3_chat_formatter(messages, expected, add_generation_prompt): | ||
| """Tests for Llama3 following the official guide | ||
| https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/ | ||
| """ | ||
| tok = DummyLlama3Tokenizer() | ||
| fmt = Llama3ChatFormatter(tok) | ||
| # No assistant prompt added if the last message is from the assistant | ||
| if add_generation_prompt and messages[-1]["role"] != "assistant": | ||
| expected += "<|start_header_id|>assistant<|end_header_id|>\n\n" | ||
| check_rendering(fmt, messages, expected, add_generation_prompt) | ||
|
|
||
| ## HFTokenizerChatFormatter #################################################### | ||
|
|
||
| @pytest.mark.parametrize( | ||
| ["messages", "expected"], | ||
| [ | ||
| # single user message (no system prompt) | ||
| (MSGS_NO_SYS, f"""<bos> | ||
| <bot><bor>user<eor>{USER1}<eot>"""), | ||
| # sys, usr | ||
| (MSGS_SYS_USR, f"""<bos> | ||
| <bot><bor>system<eor>{SYSTEM_PROMPT}<eot> | ||
| <bot><bor>user<eor>{USER1}<eot>"""), | ||
| # sys, usr, asst | ||
| (MSGS_SYS_USR_ASST, f"""<bos> | ||
| <bot><bor>system<eor>{SYSTEM_PROMPT}<eot> | ||
| <bot><bor>user<eor>{USER1}<eot> | ||
| <bot><bor>assistant<eor>{ASSISTANT1}<eot>"""), | ||
| # sys, usr, asst, usr, asst | ||
| (MSGS_MULTI_TURN, f"""<bos> | ||
| <bot><bor>system<eor>{SYSTEM_PROMPT}<eot> | ||
| <bot><bor>user<eor>{USER1}<eot> | ||
| <bot><bor>assistant<eor>{ASSISTANT1}<eot> | ||
| <bot><bor>user<eor>{USER2}<eot> | ||
| <bot><bor>assistant<eor>{ASSISTANT2}<eot>"""), | ||
| ] | ||
| ) | ||
| @pytest.mark.parametrize("add_generation_prompt", [True, False]) | ||
| def test_hf_chat_formatter(messages, expected, add_generation_prompt): | ||
| tok = DummyHFTokenizer() | ||
| fmt = HFTokenizerChatFormatter(tok) | ||
| # No assistant prompt added if the last message is from the assistant | ||
| if add_generation_prompt and messages[-1]["role"] != "assistant": | ||
| expected += f"\n{tok.bot}{tok.bor}assistant{tok.eor}" | ||
| check_rendering(fmt, messages, expected, add_generation_prompt) |
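For comparison, the `DummyHFTokenizer` above mimics the chat-template interface that real Hugging Face tokenizers expose. A minimal sketch of the real call is below; the checkpoint name is an assumption chosen for illustration, not something this PR pins down:

```python
# Sketch of the Hugging Face API that DummyHFTokenizer stands in for.
# Assumes `transformers` is installed; the model id is illustrative.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("ibm-granite/granite-3b-code-instruct")
rendered = tok.apply_chat_template(
    [{"role": "user", "content": "Hello world!"}],
    add_generation_prompt=True,  # append the assistant header for generation
    tokenize=False,              # return the rendered string, not token ids
)
print(rendered)
```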
Review comment on the new tests:

I added these here, but did not add `pytest` (yet). I think there's a pending conversation about introducing optional dependency sets, so it would make sense to add a `test` or `dev` set at that point, but I didn't want to accidentally carry `pytest` along as a runtime dependency.
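For context, an optional dependency set like the one discussed could look like the following minimal sketch, assuming setuptools packaging (hypothetical; torchchat's actual packaging may differ):

```python
# setup.py sketch (illustrative only, not part of this PR): keeps pytest in an
# optional "dev" extra so it never becomes a runtime dependency.
from setuptools import find_packages, setup

setup(
    name="torchchat",
    packages=find_packages(),
    install_requires=["jinja2", "tokenizers"],  # runtime deps made explicit in this PR
    extras_require={
        "dev": ["pytest"],  # installed via `pip install .[dev]`
    },
)
```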