Skip to content

Commit 95f4c09

Browse files
committed
Add parallel tools constraint and update parsing
1 parent 56bb724 commit 95f4c09

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

common/chat-parser.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ bool common_chat_msg_parser::try_consume_partial_literal(const std::string & lit
108108
}
109109
}
110110
}
111-
112111
return try_consume_literal(literal);
113112
}
114113

common/chat.cpp

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,9 @@ static common_chat_params common_chat_params_init_llama_3_x_nemotron(const commo
17231723
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
17241724
{"minItems", 1},
17251725
};
1726+
if (!inputs.parallel_tool_calls) {
1727+
schema["maxItems"] = 1;
1728+
}
17261729
builder.add_rule("root", "\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) + " \"</TOOLCALL>\"");
17271730
});
17281731
data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<TOOLCALL>"});
@@ -1740,14 +1743,17 @@ static void common_chat_parse_llama_3_x_nemotron(common_chat_msg_parser & builde
17401743
builder.add_content(builder.consume_rest());
17411744
return;
17421745
}
1743-
1744-
static const common_regex prefix(regex_escape("<TOOLCALL>"));
1745-
static const common_regex suffix(regex_escape("</TOOLCALL>"));
1746-
1747-
parse_prefixed_json_tool_call_array(builder, prefix);
1748-
if (!builder.try_find_regex(suffix)) {
1749-
builder.consume_rest();
1746+
static const common_regex toolcall_regex("<TOOLCALL>");
1747+
static const common_regex close_regex("</TOOLCALL>");
1748+
static const std::vector<std::vector<std::string>> args_paths = {{"arguments"}};
1749+
if (builder.try_find_regex(toolcall_regex)) {
1750+
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
1751+
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
1752+
throw common_chat_msg_partial_exception("incomplete tool call array");
1753+
}
1754+
builder.consume_regex(close_regex);
17501755
}
1756+
builder.add_content(builder.consume_rest());
17511757
}
17521758

17531759
static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
@@ -1852,7 +1858,7 @@ static common_chat_params common_chat_templates_apply_jinja(
18521858
return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
18531859
}
18541860

1855-
// Llama 3.3 Nemo (w/ tools)
1861+
// Llama 3.x Nemotron (w/ tools)
18561862
if (src.find("<TOOLCALL>") != std::string::npos) {
18571863
return common_chat_params_init_llama_3_x_nemotron(tmpl, params);
18581864
}

0 commit comments

Comments
 (0)