Skip to content

Commit 23d4bb7

Browse files
committed
Add proper handling of optional parameters with test
1 parent 9481289 commit 23d4bb7

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

common/chat.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,22 @@ static void foreach_function(const json & tools, const std::function<void(const
791791
}
792792
}
793793

794+
static std::set<std::string> get_required_parameters(const json & params) {
795+
std::set<std::string> retval;
796+
if (!params.empty()) {
797+
for (const auto& element : params.array()) {
798+
if (element.is_string()) {
799+
retval.emplace(element.get<std::string>());
800+
}
801+
}
802+
}
803+
return retval;
804+
}
805+
806+
static std::string gr_optional(std::string rule) {
807+
return "( " + rule + " )?";
808+
}
809+
794810
static std::string apply(
795811
const common_chat_template & tmpl,
796812
const struct templates_params & inputs,
@@ -2821,14 +2837,19 @@ static common_chat_params common_chat_params_init_minimax_m2(
28212837
// Create rule for Seed-OSS function call format
28222838
std::string param_rules;
28232839
if (parameters.contains("properties")) {
2840+
std::set<std::string> requiredParameters;
2841+
if (parameters.contains("required")) {
2842+
requiredParameters = get_required_parameters(parameters.at("required"));
2843+
}
28242844
for (const auto & [key, value] : parameters.at("properties").items()) {
2825-
param_rules += "\"<parameter name=\\\"" + key + "\\\">\" " + builder.add_schema(name + "-arg-" + key, value) + " \"</parameter>\" space ";
2845+
bool required = requiredParameters.count(key) > 0;
2846+
std::string specific_param_rules = "\"<parameter name=\\\"" + key + "\\\">\" " + builder.add_schema(name + "-arg-" + key, value) + " \"</parameter>\" space ";
2847+
param_rules += required ? specific_param_rules : gr_optional(specific_param_rules);
28262848
}
28272849
}
28282850
tool_rules.push_back(builder.add_rule(name + "-call",
2829-
"\"<minimax:tool_call>\" space \"<invoke name=\\\"" + name + "\\\">\" space " +
2830-
param_rules +
2831-
" \"</invoke>\" space \"</minimax:tool_call>\""));
2851+
"\"<minimax:tool_call>\" space \"<invoke name=\\\"" + name + "\\\">\" space " +
2852+
param_rules + " \"</invoke>\" space \"</minimax:tool_call>\""));
28322853
});
28332854

28342855
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<minimax:tool_call>" });

tests/test-chat.cpp

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,24 @@ common_chat_tool special_function_tool {
198198
"required": ["arg1"]
199199
})",
200200
};
201+
common_chat_tool special_function_tool_with_optional_param {
202+
/* .name = */ "special_function_with_opt",
203+
/* .description = */ "I'm special but have optional stuff",
204+
/* .parameters = */ R"({
205+
"type": "object",
206+
"properties": {
207+
"arg1": {
208+
"type": "integer",
209+
"description": "The arg."
210+
},
211+
"arg2": {
212+
"type": "integer",
213+
"description": "The optional arg."
214+
}
215+
},
216+
"required": ["arg1"]
217+
})",
218+
};
201219
common_chat_tool python_tool {
202220
/* .name = */ "python",
203221
/* .description = */ "an ipython interpreter",
@@ -226,7 +244,7 @@ common_chat_tool code_interpreter_tool {
226244
"required": ["code"]
227245
})",
228246
};
229-
std::vector<common_chat_tool> tools { special_function_tool, python_tool };
247+
std::vector<common_chat_tool> tools { special_function_tool, special_function_tool_with_optional_param, python_tool };
230248
std::vector<common_chat_tool> llama_3_1_tools { special_function_tool, code_interpreter_tool };
231249

232250
struct delta_data {
@@ -437,6 +455,8 @@ const common_chat_msg message_assist_thoughts = simple_assist
437455
const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
438456
const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking");
439457
const common_chat_msg message_assist_call = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}");
458+
const common_chat_msg message_assist_call_noopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1}");
459+
const common_chat_msg message_assist_call_withopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1, \"arg2\": 2}");
440460
const common_chat_msg message_assist_call_content = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}");
441461
const common_chat_msg message_assist_call_empty_args = simple_assist_msg("", "", "special_function");
442462
const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg");
@@ -2381,6 +2401,21 @@ Hey there!<|im_end|>
23812401
/* ignore_whitespace_differences= */ true
23822402
);
23832403

2404+
// Test template generation for tools with optional parameters
2405+
test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools,
2406+
"<minimax:tool_call>\n<invoke name=\"special_function_with_opt\">\n<parameter name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>",
2407+
/* expect_grammar_triggered= */ true,
2408+
/* test_grammar_if_triggered= */ true,
2409+
/* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE,
2410+
/* ignore_whitespace_differences= */ true
2411+
);
2412+
test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools,
2413+
"<minimax:tool_call>\n<invoke name=\"special_function_with_opt\">\n<parameter name=\"arg1\">1</parameter>\n<parameter name=\"arg2\">2</parameter>\n</invoke>\n</minimax:tool_call>",
2414+
/* expect_grammar_triggered= */ true,
2415+
/* test_grammar_if_triggered= */ true,
2416+
/* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE,
2417+
/* ignore_whitespace_differences= */ true
2418+
);
23842419
}
23852420

23862421
}

0 commit comments

Comments
 (0)