Skip to content

Commit 56bb724

Browse files
committed
model : add reasoning/tool support for Llama 3.x Nemotron
1 parent 19f68fa commit 56bb724

File tree

8 files changed

+210
-1
lines changed

8 files changed

+210
-1
lines changed

common/chat-parser.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,20 @@ bool common_chat_msg_parser::try_consume_literal(const std::string & literal) {
9898
return true;
9999
}
100100

101+
bool common_chat_msg_parser::try_consume_partial_literal(const std::string & literal) {
102+
if (is_partial_) {
103+
auto idx = string_find_partial_stop(input_, literal);
104+
if (idx != std::string::npos && idx >= pos_) {
105+
auto end = input_.size();
106+
if (end < idx + literal.size()) {
107+
throw common_chat_msg_partial_exception(literal);
108+
}
109+
}
110+
}
111+
112+
return try_consume_literal(literal);
113+
}
114+
101115
std::optional<common_chat_msg_parser::find_regex_result> common_chat_msg_parser::try_find_literal(const std::string & literal) {
102116
auto idx = input_.find(literal, pos_);
103117
if (idx != std::string::npos) {
@@ -145,7 +159,7 @@ bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think
145159
}
146160
};
147161
if (syntax_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
148-
if (syntax_.thinking_forced_open || try_consume_literal(start_think)) {
162+
if (syntax_.thinking_forced_open || try_consume_partial_literal(start_think)) {
149163
if (auto res = try_find_literal(end_think)) {
150164
handle_reasoning(res->prelude, /* closed */ true);
151165
consume_spaces();

common/chat-parser.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class common_chat_msg_parser {
8282
std::optional<find_regex_result> try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true);
8383

8484
bool try_consume_literal(const std::string & literal);
85+
bool try_consume_partial_literal(const std::string & literal);
8586

8687
std::optional<find_regex_result> try_find_literal(const std::string & literal);
8788

common/chat.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,7 @@ const char * common_chat_format_name(common_chat_format format) {
586586
case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
587587
case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
588588
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
589+
case COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON: return "Llama 3.x Nemotron";
589590
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";
590591
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
591592
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
@@ -1698,6 +1699,57 @@ static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
16981699
builder.add_content(builder.consume_rest());
16991700
}
17001701

1702+
// Llama 3.x Nemotron: tool calls are emitted as a JSON array wrapped in
// <TOOLCALL>...</TOOLCALL>. Build a grammar that constrains that array to the
// declared tools, triggered lazily on the <TOOLCALL> marker unless a tool
// call is required.
static common_chat_params common_chat_params_init_llama_3_x_nemotron(const common_chat_template & tmpl, const struct templates_params & inputs) {
    common_chat_params data;
    // Only constrain generation when the request actually declares tools:
    // with no tools, the schema list would be empty and the resulting
    // {"anyOf": []} with minItems:1 would be an unsatisfiable grammar.
    if (!inputs.tools.is_null() && !inputs.tools.empty()) {
        // When a tool call is merely optional, engage the grammar only once
        // the <TOOLCALL> trigger word is generated.
        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
            auto schemas = json::array();
            foreach_function(inputs.tools, [&](const json & tool) {
                const auto & function = tool.at("function");
                // One object schema per tool: name pinned via "const",
                // arguments constrained by the tool's parameter schema.
                schemas.push_back({
                    {"type", "object"},
                    {"properties", {
                        {"name", {
                            {"type", "string"},
                            {"const", function.at("name")},
                        }},
                        {"arguments", function.at("parameters")},
                    }},
                    {"required", json::array({"name", "arguments"})},
                });
            });
            auto schema = json {
                {"type", "array"},
                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
                {"minItems", 1},
            };
            builder.add_rule("root", "\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) + " \"</TOOLCALL>\"");
        });
        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<TOOLCALL>"});
        // Keep the markers as single tokens so the trigger/parsing sees them intact.
        data.preserved_tokens = {
            "<TOOLCALL>",
            "</TOOLCALL>"
        };
    }
    data.prompt = apply(tmpl, inputs);
    data.format = COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON;
    return data;
}
1737+
// Parse a Llama 3.x Nemotron response: optional <think>...</think> reasoning
// followed by either plain content or a <TOOLCALL>[...]</TOOLCALL> JSON array
// of tool calls.
static void common_chat_parse_llama_3_x_nemotron(common_chat_msg_parser & builder) {
    builder.try_parse_reasoning("<think>", "</think>");
    if (builder.syntax().parse_tool_calls) {
        static const common_regex open_tag(regex_escape("<TOOLCALL>"));
        static const common_regex close_tag(regex_escape("</TOOLCALL>"));

        parse_prefixed_json_tool_call_array(builder, open_tag);
        // If the closing tag never shows up, treat the remainder as content.
        if (!builder.try_find_regex(close_tag)) {
            builder.consume_rest();
        }
    } else {
        // Tool-call parsing disabled: everything after the reasoning is content.
        builder.add_content(builder.consume_rest());
    }
}
1752+
17011753
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
17021754
common_chat_params data;
17031755
data.prompt = apply(tmpl, inputs);
@@ -1800,6 +1852,11 @@ static common_chat_params common_chat_templates_apply_jinja(
18001852
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
18011853
}
18021854

1855+
// Llama 3.x Nemotron (w/ tools)
1856+
if (src.find("<TOOLCALL>") != std::string::npos) {
1857+
return common_chat_params_init_llama_3_x_nemotron(tmpl, params);
1858+
}
1859+
18031860
// Plain handler (no tools)
18041861
if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
18051862
return common_chat_params_init_without_tools(tmpl, params);
@@ -1905,6 +1962,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
19051962
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
19061963
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
19071964
break;
1965+
case COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON:
1966+
common_chat_parse_llama_3_x_nemotron(builder);
1967+
break;
19081968
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
19091969
common_chat_parse_deepseek_r1(builder);
19101970
break;

common/chat.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ enum common_chat_format {
103103
COMMON_CHAT_FORMAT_MISTRAL_NEMO,
104104
COMMON_CHAT_FORMAT_LLAMA_3_X,
105105
COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
106+
COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
106107
COMMON_CHAT_FORMAT_DEEPSEEK_R1,
107108
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
108109
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,

models/templates/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ These templates can be updated with the following commands:
1818
./scripts/get_chat_template.py mistralai/Mistral-Nemo-Instruct-2407 > models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja
1919
./scripts/get_chat_template.py NousResearch/Hermes-2-Pro-Llama-3-8B tool_use > models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja
2020
./scripts/get_chat_template.py NousResearch/Hermes-3-Llama-3.1-8B tool_use > models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja
21+
./scripts/get_chat_template.py nvidia/Llama-3_3-Nemotron-Super-49B-v1_5 > models/templates/nvidia-Llama-3_3-Nemotron-Super-49B-v1_5.jinja
2122
./scripts/get_chat_template.py Qwen/Qwen2.5-7B-Instruct > models/templates/Qwen-Qwen2.5-7B-Instruct.jinja
2223
./scripts/get_chat_template.py Qwen/QwQ-32B > models/templates/Qwen-QwQ-32B.jinja
2324
./scripts/get_chat_template.py Qwen/Qwen3-0.6B > models/templates/Qwen-Qwen3-0.6B.jinja
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
{% set bos = "<|begin_of_text|>" %}{%- set enable_thinking = true -%}{% set system_start_header = "<|start_header_id|>" %}{% set system_end_header = "<|end_header_id|>
2+
3+
" %}{% set start_header = "<|start_header_id|>" %}{% set end_header = "<|end_header_id|>
4+
5+
" %}{% set eot = "<|eot_id|>" %}{% set system_token = "system" %}{% set user_token = "user" %}{% set assistant_token = "assistant" %}{% set tool_token = "tool" %}{{- bos ~ system_start_header ~ system_token ~ system_end_header -}}{%- if messages[0].role == 'system' and messages[0].content != '' -%}{%- set system_content = messages[0].content -%}{%- if '/no_think' in system_content -%}{%- set system_content = system_content.replace('/no_think', '')|trim -%}{%- set enable_thinking = false -%}{%- elif '/think' in system_content -%}{%- set system_content = system_content.replace('/think', '')|trim -%}{%- set enable_thinking = true -%}{%- endif -%}{{- system_content + '
6+
7+
' -}}{%- endif -%}{%- if tools -%}{{- 'You can use the following tools to assist the user if required:
8+
<AVAILABLE_TOOLS>[' -}}{%- for tool in tools -%}{{- (tool.function if tool.function is defined else tool) | tojson -}}{{- ', ' if not loop.last else '' -}}{%- endfor -%}{{- ']</AVAILABLE_TOOLS>
9+
10+
If you decide to call any tool(s), use the following format:
11+
<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, {{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>
12+
13+
Response from tool(s) will be returned in this format:
14+
<TOOL_RESPONSE>[{{"response": "tool_response1"}}, {{"response": "tool_response2"}}]</TOOL_RESPONSE>
15+
16+
Based on the results returned by the tool(s), you can call additional tools if needed, correct tool calls if any errors are found, or just respond with the answer to the user.' -}}{%- endif -%}{{- eot -}}{%- for message in messages -%}{%- if message.role == user_token -%}{{- start_header ~ user_token ~ end_header -}}{{ message.content -}}{{ eot -}}{%- elif message.role == assistant_token -%}{%- if '</think>' in message.content -%}{%- set content = message.content.split('</think>')[-1].lstrip() -%}{%- else -%}{%- set content = message.content -%}{%- endif -%}{{- start_header ~ assistant_token ~ end_header -}}{{ content -}}{%- if message.tool_calls -%}{{- '<TOOLCALL>[' -}}{%- for call in message.tool_calls -%}{%- set fn = call.function if call.function is defined else call -%}{{- '{"name": "' + fn.name + '", "arguments": ' -}}{%- if fn.arguments is string -%}{{- fn.arguments -}}{%- else -%}{{- fn.arguments | tojson -}}{%- endif -%}{{- '}' + (', ' if not loop.last else '') -}}{%- endfor -%}{{- ']</TOOLCALL>' -}}{%- endif -%}{{- eot -}}{%- elif message.role == tool_token -%}{%- if loop.first or (messages[loop.index0 - 1].role != tool_token) -%}{{- start_header ~ tool_token ~ end_header -}}{{ '<TOOL_RESPONSE>[' -}}{%- endif -%}{{- message.content -}}{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == tool_token) else '' -}}{%- if loop.last or (messages[loop.index0 + 1].role != tool_token) -%}{{- ']</TOOL_RESPONSE>' -}}{{ eot -}}{%- endif -%}{%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%}{{- start_header ~ assistant_token ~ end_header -}}{%- if not enable_thinking -%}{{- '<think>
17+
18+
</think>
19+
20+
' -}}{%- endif -%}{%- endif -%}

tests/test-chat-parser.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,47 @@ static void test_reasoning() {
9999
assert_equals("<think>Cogito</think>", builder.result().content);
100100
assert_equals("Ergo sum", builder.consume_rest());
101101
}
102+
{
103+
common_chat_msg_parser builder("<tnk>Cogito", /* is_partial= */ true, {
104+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
105+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
106+
/* .reasoning_in_content = */ false,
107+
/* .thinking_forced_open = */ false,
108+
});
109+
110+
assert_equals(true, builder.try_parse_reasoning("<tnk>", "</tnk>"));
111+
assert_equals("Cogito", builder.result().reasoning_content);
112+
assert_equals("", builder.consume_rest());
113+
}
114+
{
115+
common_chat_msg_parser builder("<t", /* is_partial= */ true, {
116+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
117+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
118+
/* .reasoning_in_content = */ false,
119+
/* .thinking_forced_open = */ false,
120+
});
121+
122+
try {
123+
builder.try_parse_reasoning("<tnk>", "</tnk>");
124+
throw std::runtime_error("Expected exception");
125+
} catch (const std::exception & e) {
126+
if (std::string(e.what()).find("<tnk>") == std::string::npos) {
127+
throw std::runtime_error("Expected exception about partial <tnk>");
128+
}
129+
}
130+
}
131+
{
132+
common_chat_msg_parser builder("<think>Cogito", /* is_partial= */ true, {
133+
/* .format = */ COMMON_CHAT_FORMAT_CONTENT_ONLY,
134+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
135+
/* .reasoning_in_content = */ false,
136+
/* .thinking_forced_open = */ false,
137+
});
138+
139+
assert_equals(false, builder.try_parse_reasoning("<tnk>", "</tnk>"));
140+
assert_equals("", builder.result().reasoning_content);
141+
assert_equals("<think>Cogito", builder.consume_rest());
142+
}
102143
}
103144

104145
static void test_regex() {

tests/test-chat.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1386,6 +1386,77 @@ static void test_template_output_parsers() {
13861386
"{\"arg1\": 1}\n"
13871387
"```<|tool▁call▁end|><|tool▁calls▁end|>");
13881388
}
1389+
{
1390+
auto tmpls = read_templates("models/templates/nvidia-Llama-3_3-Nemotron-Super-49B-v1_5.jinja");
1391+
std::vector<std::string> end_tokens{ "<|eot_id|>" };
1392+
1393+
assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON, common_chat_templates_apply(tmpls.get(), inputs_tools).format);
1394+
1395+
test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
1396+
test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false);
1397+
1398+
assert_msg_equals(message_assist_thoughts_unparsed_deepseek,
1399+
common_chat_parse(
1400+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1401+
/* is_partial= */ false,
1402+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1403+
assert_msg_equals(message_assist_thoughts,
1404+
common_chat_parse(
1405+
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
1406+
/* is_partial= */ false,
1407+
{
1408+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1409+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1410+
}));
1411+
assert_msg_equals(message_assist_thoughts,
1412+
common_chat_parse(
1413+
"I'm\nthinking</think>Hello, world!\nWhat's up?",
1414+
/* is_partial= */ false,
1415+
{
1416+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1417+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1418+
/* .reasoning_in_content = */ false,
1419+
/* .thinking_forced_open = */ true,
1420+
}));
1421+
1422+
assert_msg_equals(message_assist_call_thoughts_unparsed,
1423+
common_chat_parse(
1424+
"<think>I'm\nthinking</think>\n\n"
1425+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1426+
/* is_partial= */ false,
1427+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1428+
assert_msg_equals(message_assist_call,
1429+
common_chat_parse(
1430+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1431+
/* is_partial= */ false,
1432+
{COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON}));
1433+
assert_msg_equals(message_assist_call_thoughts,
1434+
common_chat_parse(
1435+
"<think>I'm\nthinking</think>\n\n"
1436+
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
1437+
/* is_partial= */ false,
1438+
{
1439+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1440+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1441+
}));
1442+
1443+
assert_msg_equals(message_assist_empty,
1444+
common_chat_parse(
1445+
"<th",
1446+
/* is_partial= */ true,
1447+
{
1448+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1449+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1450+
}));
1451+
assert_msg_equals(message_assist_thoughts_no_content,
1452+
common_chat_parse(
1453+
"<think>I'm\nthinking",
1454+
/* is_partial= */ true,
1455+
{
1456+
/* .format = */ COMMON_CHAT_FORMAT_LLAMA_3_X_NEMOTRON,
1457+
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
1458+
}));
1459+
}
13891460
}
13901461

13911462
static void test_msg_diffs_compute() {

0 commit comments

Comments
 (0)