Skip to content
Merged
97 changes: 97 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default:
throw std::runtime_error("Unknown chat format");
}
Expand Down Expand Up @@ -1184,6 +1185,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
});
return data;
}

static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;

// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;

// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}

// When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = true;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{ "type", "object" },
{ "properties",
{
{ "name",
{
{ "type", "string" },
{ "const", function.at("name") },
} },
{ "arguments", function.at("parameters") },
} },
{ "required", json::array({ "name", "arguments" }) },
});
});
auto schema = json{
{ "type", "array" },
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
{ "minItems", 1 },
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
" \"</TOOLCALL>\"");
});
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ?
"[\\s\\S]*?(</think>\\s*)" :
"(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<TOOLCALL>)[\\s\\S]*" });
}
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
Expand Down Expand Up @@ -2060,6 +2122,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}

static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);

// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}

static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
Expand Down Expand Up @@ -2293,6 +2382,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}

// Nemotron v2
if (src.find("<SPECIAL_10>") != std::string::npos) {
return common_chat_params_init_nemotron_v2(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
Expand Down Expand Up @@ -2454,6 +2548,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
Expand Down
162 changes: 162 additions & 0 deletions models/templates/NVIDIA-Nemotron-Nano-v2.jinja
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
{%- set ns = namespace(enable_thinking=true) -%}
{%- for message in messages -%}
{%- set content = message['content'] -%}
{%- if message['role'] == 'user' or message['role'] == 'system' -%}
{%- if '/think' in content -%}
{%- set ns.enable_thinking = true -%}
{%- elif '/no_think' in content -%}
{%- set ns.enable_thinking = false -%}
{%- endif -%}
{%- endif -%}
{%- endfor -%}

{%- if messages[0]['role'] != 'system' -%}
{%- set ns.non_tool_system_content = '' -%}
{{- '<SPECIAL_10>System
' -}}
{%- else -%}
{%- set ns.non_tool_system_content = (messages[0]['content'] | default('', true)).replace('/think', '').replace('/no_think', '').strip() -%}
{{- '<SPECIAL_10>System
' + ns.non_tool_system_content }}
{%- endif -%}

{%- if tools -%}
{%- if ns.non_tool_system_content is defined and ns.non_tool_system_content != '' -%}
{{- '

' -}}
{%- endif -%}
{{- 'You can use the following tools to assist the user if required:' -}}
{{- '
<AVAILABLE_TOOLS>[' -}}
{%- for tool in tools -%}
{{- (tool.function if tool.function is defined else tool) | tojson -}}
{{- ', ' if not loop.last else '' -}}
{%- endfor -%}
{{- ']</AVAILABLE_TOOLS>

' -}}
{{- 'If you decide to call any tool(s), use the following format:
' -}}
{{- '<TOOLCALL>[{{"name": "tool_name1", "arguments": "tool_args1"}}, ' -}}
{{- '{{"name": "tool_name2", "arguments": "tool_args2"}}]</TOOLCALL>

' -}}
{{- 'The user will execute tool-calls and return responses from tool(s) in this format:
' -}}
{{- '<TOOL_RESPONSE>[{{"tool_response1"}}, {{"tool_response2"}}]</TOOL_RESPONSE>

' -}}
{{- 'Based on the tool responses, you can call additional tools if needed, correct tool calls if any errors are found, or just respond to the user.' -}}
{%- endif -%}
{{- '

' -}}
{%- set messages = messages[1:] if messages[0]['role'] == 'system' else messages -%}
{%- if messages[-1]['role'] == 'assistant' -%}
{%- set ns.last_turn_assistant_content = (messages[-1]['content'] | default('', true)).strip() -%}
{%- set ns.last_turn_assistant_tool_calls = messages[-1]['tool_calls'] if 'tool_calls' in messages[-1] else [] -%}
{%- set messages = messages[:-1] -%}
{%- endif -%}

{%- for message in messages %}
{%- set content = message['content'] %}
{%- if message['role'] == 'user' -%}
{{- '<SPECIAL_11>User
' + (content | default('', true)).replace('/think', '').replace('/no_think', '').strip() + '
' }}
{%- elif message['role'] == 'tool' -%}
{%- if loop.first or (messages[loop.index0 - 1].role != 'tool') -%}
{{- '<SPECIAL_11>User
' + '<TOOL_RESPONSE>[' }}
{%- endif -%}
{{- message['content'] -}}
{{- ', ' if not loop.last and (messages[loop.index0 + 1].role == 'tool') else '' -}}
{%- if loop.last or (messages[loop.index0 + 1].role != 'tool') -%}
{{- ']</TOOL_RESPONSE>' -}}
{%- endif -%}
{%- elif message['role'] == 'assistant' -%}
{%- if content and '</think>' in content -%}
{%- set content = (content.split('</think>')[1] | default('', true)).strip() %}
{%- endif -%}
{{- '<SPECIAL_11>Assistant
' + ((content | default('', true)).strip() if content is not none else '') }}
{%- if message.tool_calls -%}
{%- if (content | default('', true)).strip() != '' -%}
{{- '
' -}}
{%- endif -%}
{{- '<TOOLCALL>[' -}}
{%- for call in message.tool_calls -%}
{%- set fn = call.function if call.function is defined else call -%}
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
{%- if fn.arguments is string -%}
{{- fn.arguments -}}
{%- else -%}
{{- fn.arguments | tojson -}}
{%- endif -%}
{{- '}' + (', ' if not loop.last else '') -}}
{%- endfor -%}
{{- ']</TOOLCALL>' -}}
{%- endif -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- endfor -%}

{%- if add_generation_prompt -%}
{{- '<SPECIAL_11>Assistant
' -}}
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
{{- '<think></think>' -}}
{%- else -%}
{{- '<think>
' -}}
{%- endif -%}
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
{{- ns.last_turn_assistant_content -}}
{%- endif -%}
{%- else -%}
{%- if ns.last_turn_assistant_content is defined and ns.last_turn_assistant_content != '' -%}
{{- '<SPECIAL_11>Assistant
' -}}
{%- if ns.enable_thinking is defined and ns.enable_thinking is false -%}
{{- '<think></think>' -}}
{%- else -%}
{{- '<think>
' -}}
{%- endif -%}
{{- ns.last_turn_assistant_content -}}
{%- if continue_final_message is defined -%}
{%- if continue_final_message is false -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- else -%}
{{- '
<SPECIAL_12>
' -}}
{%- endif -%}
{%- endif -%}
{%- if ns.last_turn_assistant_tool_calls is defined and ns.last_turn_assistant_tool_calls | length > 0 -%}
{{- '<SPECIAL_11>Assistant
' -}}
{{- '<TOOLCALL>[' -}}
{%- for call in ns.last_turn_assistant_tool_calls -%}
{%- set fn = call.function if call.function is defined else call -%}
{{- '{"name": "' + fn.name + '", "arguments": ' -}}
{%- if fn.arguments is string -%}
{{- fn.arguments -}}
{%- else -%}
{{- fn.arguments | tojson -}}
{%- endif -%}
{{- '}' + (', ' if not loop.last else '') -}}
{%- endfor -%}
{{- ']</TOOLCALL>' -}}
{{- '<SPECIAL_12>

' -}}
{%- endif -%}
{%- endif -%}
73 changes: 73 additions & 0 deletions tests/test-chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ const common_chat_msg message_assist_call_empty_args = simple_assist
const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg");
const common_chat_msg message_assist_call_thoughts = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}");
const common_chat_msg message_assist_call_thoughts_unparsed = simple_assist_msg("<think>I'm\nthinking</think>\n\n", "", "special_function", "{\"arg1\": 1}");
const common_chat_msg message_assist_call_thoughts_content = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}");
const common_chat_msg message_assist_call_id = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789");
const common_chat_msg message_assist_call_idx = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0");
const common_chat_msg message_assist_thoughts_call_idx = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0");
Expand All @@ -436,6 +437,7 @@ static void test_msgs_oaicompat_json_conversion() {
message_assist_call,
message_assist_call_thoughts,
message_assist_call_thoughts_unparsed,
message_assist_call_thoughts_content,
message_assist_call_id,
message_assist_call_idx,
message_assist_call_python,
Expand Down Expand Up @@ -1755,6 +1757,77 @@ static void test_template_output_parsers() {
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_SEED_OSS}));
}

{
auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja");
std::vector<std::string> end_tokens{ "<SPECIAL_12>" };

assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format);

// Test parsing regular content
assert_msg_equals(message_assist,
common_chat_parse(
"Hello, world!\nWhat's up?",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));

// Test parsing content with thinking
assert_msg_equals(message_assist_thoughts,
common_chat_parse(
"<think>I'm\nthinking</think>Hello, world!\nWhat's up?",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK,
}));

// Test parsing tool calls
assert_msg_equals(message_assist_call,
common_chat_parse(
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_NEMOTRON_V2}));

// Test parsing tool calls with thinking
assert_msg_equals(message_assist_call_thoughts,
common_chat_parse(
"<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
}));

// Test tool calls with extra content
assert_msg_equals(message_assist_call_content,
common_chat_parse(
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
/* is_partial= */ false,
{COMMON_CHAT_FORMAT_NEMOTRON_V2}
));

// Test tool calls with extra content AND thinking
assert_msg_equals(message_assist_call_thoughts_content,
common_chat_parse(
"<think>I'm\nthinking</think><TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>Hello, world!\nWhat's up?",
/* is_partial= */ false,
{
/* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2,
/* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK
}));

// Test template generation for regular content
test_templates(tmpls.get(), end_tokens, message_assist, tools,
"Hello, world!\nWhat's up?\n",
/* expect_grammar_triggered= */ false);

// Test template generation for tool calls
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
"<TOOLCALL>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]</TOOLCALL>",
/* expect_grammar_triggered= */ true
);
}
}

static void test_msg_diffs_compute() {
Expand Down
Loading