Skip to content

Commit 1a351a0

Browse files
committed
Use Unsloth template, add extra test parameters for ignoring additional whitespace
1 parent de67255 commit 1a351a0

File tree

2 files changed

+56
-15
lines changed

2 files changed

+56
-15
lines changed

models/templates/MiniMax-M2.jinja renamed to models/templates/unsloth-MiniMax-M2.jinja

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1+
{# Unsloth template fixes #}
12
{# ----------‑‑‑ special token variables ‑‑‑---------- #}
23
{%- set toolcall_begin_token = '<minimax:tool_call>' -%}
34
{%- set toolcall_end_token = '</minimax:tool_call>' -%}
45
{#- Tool Rendering Functions ============================================== -#}
56
{%- macro render_tool_namespace(namespace_name, tool_list) -%}
67
{%- for tool in tool_list -%}
7-
<tool>{{ tool.function | tojson() }}</tool>
8+
<tool>{{ tool.function | tojson | string }}</tool>
89
{% endfor -%}
910
{%- endmacro -%}
1011
{%- macro visible_text(content) -%}
@@ -76,20 +77,31 @@
7677
{{- '\n' ~ toolcall_end_token }}
7778
{%- endif -%}
7879
{{- '[e~[\n' }}
80+
7981
{#- Render messages -#}
8082
{%- set last_tool_call = namespace(name=none) -%}
8183
{%- for message in conversation_messages -%}
8284
{%- if message.role == 'assistant' -%}
8385
{#- Only render reasoning_content if no user message follows -#}
8486
{{- ']~b]ai' ~ '\n' }}
87+
8588
{%- set reasoning_content = '' %}
8689
{%- set content = visible_text(message.content) %}
8790
{%- if message.reasoning_content is string %}
8891
{%- set reasoning_content = message.reasoning_content %}
8992
{%- else %}
9093
{%- if '</think>' in content %}
91-
{%- set reasoning_content = content.split('</think>')[0].strip('\n').split('<think>')[-1].strip('\n') %}
92-
{%- set content = content.split('</think>')[-1].strip('\n') %}
94+
{# Unsloth template fixes - must change to for loop since llama.cpp will error out if not #}
95+
{%- set parts = content.split('</think>') %}
96+
{%- for part in parts %}
97+
{%- if loop.index0 == 0 -%}
98+
{%- set reasoning_content = part.strip('\n') %}
99+
{%- set reasoning_content = (reasoning_content.split('<think>')|last) %}
100+
{%- set reasoning_content = reasoning_content.strip('\n') -%}
101+
{%- else -%}
102+
{%- set content = part.strip('\n') %}
103+
{%- endif %}
104+
{%- endfor %}
93105
{%- endif %}
94106
{%- endif %}
95107
{%- if reasoning_content and loop.index0 > ns.last_user_index -%}
@@ -99,26 +111,30 @@
99111
{{- content }}
100112
{%- endif -%}
101113
{%- if message.tool_calls -%}
102-
{{- toolcall_begin_token ~ '\n' }}
114+
{{- '\n' ~ toolcall_begin_token ~ '\n' }}
115+
103116
{%- for tool_call in message.tool_calls -%}
104117
{%- if tool_call.function %}
105118
{%- set tool_call = tool_call.function %}
106119
{%- endif %}
107-
{{- '<invoke name="' + tool_call.name + '">' }}
120+
{{- '<invoke name="' + tool_call.name + '">\n' }}
121+
{%- if tool_call.arguments is defined and tool_call.arguments is mapping -%}
108122
{% set _args = tool_call.arguments %}
109-
{%- for k, v in _args | items %}
123+
{%- for k, v in _args|items %}
110124
{{- '<parameter name="' + k + '">' }}
111-
{{- v | tojson if v is not string else v }}
125+
{{- v | tojson | string if v is not string else v }}
112126
{{- '</parameter>' }}
113-
{% endfor %}
127+
{% endfor %}{%- endif -%}
114128
{{- '</invoke>' ~ '\n' }}
115129
{%- endfor -%}
130+
116131
{{- toolcall_end_token}}
117132
{%- set last_tool_call.name = message.tool_calls[-1].name -%}
118133
{%- else -%}
119134
{%- set last_tool_call.name = none -%}
120135
{%- endif -%}
121136
{{- '[e~[' ~ '\n' }}
137+
122138
{%- elif message.role == 'tool' -%}
123139
{%- if last_tool_call.name is none -%}
124140
{{- raise_exception("Message has tool role, but there was no previous assistant message with a tool call!") }}
@@ -147,7 +163,9 @@
147163
{{- '[e~[' ~ '\n' }}
148164
{%- endif -%}
149165
{%- endfor -%}
166+
150167
{#- Generation prompt -#}
151168
{%- if add_generation_prompt -%}
152169
{{- ']~b]ai' ~ '\n' ~ '<think>' ~ '\n' }}
153170
{%- endif -%}
171+
{# Copyright 2025-present Unsloth. Apache 2.0 License. #}

tests/test-chat.cpp

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,21 @@ static common_chat_msg normalize(const common_chat_msg & msg) {
7575
}
7676
return normalized;
7777
}
78+
79+
80+
// trim whitespace from the beginning and end of a string
81+
static std::string trim(const std::string & str) {
82+
size_t start = 0;
83+
size_t end = str.size();
84+
while (start < end && isspace(static_cast<unsigned char>(str[start]))) {
85+
start += 1;
86+
}
87+
while (end > start && isspace(static_cast<unsigned char>(str[end - 1]))) {
88+
end -= 1;
89+
}
90+
return str.substr(start, end - start);
91+
}
92+
7893
template <>
7994
bool equals(const common_chat_msg & expected, const common_chat_msg & actual) {
8095
return normalize(expected) == normalize(actual);
@@ -148,15 +163,15 @@ static std::string renormalize_json(const std::string & json_str) {
148163
return json_str;
149164
}
150165
}
151-
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual) {
166+
static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual, bool ignore_whitespace_differences = false) {
152167
assert_equals(expected.role, actual.role);
153-
assert_equals(expected.content, actual.content);
168+
assert_equals(expected.content, ignore_whitespace_differences ? trim(actual.content) : actual.content);
154169
assert_equals(expected.content_parts.size(), actual.content_parts.size());
155170
for (size_t i = 0; i < expected.content_parts.size(); i++) {
156171
const auto & expected_part = expected.content_parts[i];
157172
const auto & actual_part = actual.content_parts[i];
158173
assert_equals(expected_part.type, actual_part.type);
159-
assert_equals(expected_part.text, actual_part.text);
174+
assert_equals(expected_part.text, ignore_whitespace_differences ? trim(actual_part.text) : actual_part.text);
160175
}
161176
assert_equals(expected.reasoning_content, actual.reasoning_content);
162177
assert_equals(expected.tool_calls.size(), actual.tool_calls.size());
@@ -280,14 +295,19 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
280295
const std::string & expected_delta = "",
281296
bool expect_grammar_triggered = true,
282297
bool test_grammar_if_triggered = true,
283-
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE) {
298+
common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE,
299+
bool ignore_whitespace_differences = false
300+
) {
284301
common_chat_msg user_message;
285302
user_message.role = "user";
286303
user_message.content = "Hello, world!";
287304

288305
for (const auto & tool_choice : std::vector<common_chat_tool_choice> {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) {
289306
auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice);
290307
if (!expected_delta.empty()) {
308+
if (ignore_whitespace_differences) {
309+
data.delta = trim(data.delta);
310+
}
291311
assert_equals(expected_delta, data.delta);
292312
}
293313

@@ -296,7 +316,7 @@ static void test_templates(const struct common_chat_templates * tmpls, const std
296316
syntax.format = data.params.format;
297317
syntax.reasoning_format = reasoning_format;
298318
const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, syntax);
299-
assert_msg_equals(test_message, msg);
319+
assert_msg_equals(test_message, msg, ignore_whitespace_differences);
300320
}
301321

302322
if (!test_message.tool_calls.empty()) {
@@ -2289,7 +2309,7 @@ Hey there!<|im_end|>
22892309
}
22902310

22912311
{
2292-
auto tmpls = read_templates("models/templates/MiniMax-M2.jinja");
2312+
auto tmpls = read_templates("models/templates/unsloth-MiniMax-M2.jinja");
22932313
std::vector<std::string> end_tokens{ "[e~[" };
22942314

22952315
assert_equals(COMMON_CHAT_FORMAT_MINIMAX_M2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format);
@@ -2355,7 +2375,10 @@ Hey there!<|im_end|>
23552375
// Test template generation for tool calls
23562376
test_templates(tmpls.get(), end_tokens, message_assist_call, tools,
23572377
"<minimax:tool_call>\n<invoke name=\"special_function\">\n<parameter name=\"arg1\">1</parameter>\n</invoke>\n</minimax:tool_call>",
2358-
/* expect_grammar_triggered= */ true
2378+
/* expect_grammar_triggered= */ true,
2379+
/* test_grammar_if_triggered= */ true,
2380+
/* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE,
2381+
/* ignore_whitespace_differences= */ true
23592382
);
23602383

23612384
}

0 commit comments

Comments
 (0)