Model: Minimax M2 - chat support #16946
Conversation
You should submit it to https://github.com/ochafik/minja

Ah, forgot about that. Done.
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
    // Parse thinking tags first - this handles the main reasoning content
    // Chat template doesn't seem to handle interleaved thinking, so we don't worry about it either
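The effect of this parsing step can be illustrated with a self-contained sketch using plain `std::string` handling. This is not the llama.cpp parser API; it only shows the assumed contract: a single leading `<think>` block is split off as reasoning and the remainder becomes the visible content.

```cpp
#include <string>
#include <utility>

// Split a completion of the form "<think>...</think>rest" into
// (reasoning, content). If no leading think block is present,
// reasoning is empty and the whole output is content.
static std::pair<std::string, std::string> split_reasoning(const std::string & out) {
    const std::string open  = "<think>";
    const std::string close = "</think>";
    if (out.find(open) != 0) {          // think block must lead the message
        return {"", out};
    }
    size_t e = out.find(close, open.size());
    if (e == std::string::npos) {       // unterminated: treat the rest as reasoning
        return {out.substr(open.size()), ""};
    }
    return {out.substr(open.size(), e - open.size()),
            out.substr(e + close.size())};
}
```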
Are we sure we're using the correct definition of interleaved thinking here? I don't think it means the CoT is interleaved with the content during generation, but rather it is interleaved in the entire prompt during multi-turn tool calling sessions. It seems to behave very similarly to gpt-oss. None of my testing, granted at Q2_XL, seems to indicate that the CoT is interleaved during generation. It's also only applied if the last message is a tool response.
Using the proposed fix for tool response support by @ochafik, it works as is if I pass reasoning_content with the assistant messages. Without this fix, the tool messages are transformed to user by the polyfill.
Template Example
curl -X POST http://localhost:8080/apply-template \
-H "Content-Type: application/json" \
-d '{
"messages": [
{
"role": "system",
"content": "You are a weather man"
},
{
"role": "user",
"content": "Can you compare the weather at New York and Los Angeles?"
},
{
"role": "assistant",
"reasoning_content": "I need to get the weather of New York and Los Angeles, let me do New York first.",
"tool_calls": [
{
"id": "1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\": \"New York\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "1",
"content": "50 F"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather for a specified city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city name, e.g. San Francisco"
}
},
"required": ["city"]
}
}
}
]
}'

]~b]system
You are a weather man
# Tools
You may call one or more tools to assist with the user query.
Here are the tools available in JSONSchema format:
<tools>
<tool>{"name": "get_weather", "description": "Get the current weather for a specified city", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "The city name, e.g. San Francisco"}}, "required": ["city"]}}</tool>
</tools>
When making tool calls, use XML format to invoke tools and pass parameters:
<minimax:tool_call>
<invoke name="tool-name-1">
<parameter name="param-key-1">param-value-1</parameter>
<parameter name="param-key-2">param-value-2</parameter>
...
</invoke>
</minimax:tool_call>[e~[
]~b]user
Can you compare the weather at New York and Los Angeles?[e~[
]~b]ai
<think>
I need to get the weather of New York and Los Angeles, let me do New York first.
</think>
<minimax:tool_call>
<invoke name="get_weather">
<parameter name="city">New York</parameter>
</invoke>
</minimax:tool_call>[e~[
]~b]tool
<response>50 F</response>[e~[
]~b]ai
<think>

It does place the burden of returning reasoning_content on the clients.
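For reference, the `<minimax:tool_call>` XML convention shown in the rendered prompt can be handled with simple string scanning. The following is an illustrative standalone sketch, not the parser used in this PR; it extracts the invoke name and parameters from one `<invoke>` block:

```cpp
#include <map>
#include <string>

// Extract the tool name and parameters from one <invoke> block of a
// <minimax:tool_call> section, e.g.
//   <invoke name="get_weather"><parameter name="city">New York</parameter></invoke>
static bool parse_invoke(const std::string & xml, std::string & name,
                         std::map<std::string, std::string> & params) {
    size_t p = xml.find("<invoke name=\"");
    if (p == std::string::npos) return false;
    p += 14;                                   // skip past `<invoke name="`
    size_t q = xml.find('"', p);
    if (q == std::string::npos) return false;
    name = xml.substr(p, q - p);

    size_t pos = q;
    while ((pos = xml.find("<parameter name=\"", pos)) != std::string::npos) {
        pos += 17;                             // skip past `<parameter name="`
        size_t k = xml.find('"', pos);
        if (k == std::string::npos) return false;
        std::string key = xml.substr(pos, k - pos);
        size_t gt = xml.find('>', k);
        if (gt == std::string::npos) return false;
        size_t ve = xml.find("</parameter>", gt + 1);
        if (ve == std::string::npos) return false;
        params[key] = xml.substr(gt + 1, ve - (gt + 1));
        pos = ve;
    }
    return true;
}
```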
@aldehir That's actually a good clarification - I was somehow convinced that interleaved reasoning actually meant content blocks with multiple reasoning / content chunks intertwined (I think the Anthropic protocol allows something like that). We shouldn't have a problem with it if it's just tool calls intertwined with reasoning blocks.
@hksdpc255 please take a look at this discussion, since I feel you're repeating the same error (with using reasoning-format none + literally outputting the opening <think> tag).
@pwilkin Thanks for pointing that out. I actually had the same misunderstanding about interleaved thinking at first.
Because of that, I initially implemented full support for reasoning and normal content being interleaved during generation. Later I realized that this wasn’t really required in our current setup. But since I already had a custom test harness for it, I verified that my implementation can indeed handle such interleaved reasoning/content streams. It might still be useful in the future if models start emitting that pattern more often.
As for --reasoning-format none, my understanding was that it means not to treat reasoning specially, but to include it directly in the normal assistant message. This interpretation seemed consistent with how some chat templates (like GLM 4.5 / 4.6 and MiniMax M2) automatically detect <think> blocks in the main content, extract them into reasoning_content, and remove them from the visible answer. That behavior is quite helpful for clients that don’t support returning reasoning_content back to the server — which I believe is the case for most code agents.
I’m currently using --reasoning-format none to serve the Zed editor, and in that setup, MiniMax M2 performs impressively well on fairly complex tasks.
However, I might have misunderstood the actual purpose of --reasoning-format none. If so, I’d really appreciate clarification. And if it’s not meant for this kind of use case, I think introducing a new --reasoning-format mode to explicitly support it would make a lot of sense.
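Interleaved reasoning/content streams of the kind described above can be segmented with a small state machine. A minimal standalone sketch (illustrative only, not the implementation in #16932) that splits output containing multiple `<think>` blocks into ordered segments:

```cpp
#include <string>
#include <vector>

struct segment {
    bool        reasoning;  // true if the chunk came from inside <think>...</think>
    std::string text;
};

// Split text that may contain several <think> blocks interleaved with
// normal content into an ordered list of segments.
static std::vector<segment> split_interleaved(const std::string & out) {
    std::vector<segment> segs;
    size_t pos = 0;
    while (pos < out.size()) {
        size_t b = out.find("<think>", pos);
        if (b == std::string::npos) {
            segs.push_back({false, out.substr(pos)});
            break;
        }
        if (b > pos) segs.push_back({false, out.substr(pos, b - pos)});
        size_t e = out.find("</think>", b + 7);
        if (e == std::string::npos) {          // unterminated think block
            segs.push_back({true, out.substr(b + 7)});
            break;
        }
        segs.push_back({true, out.substr(b + 7, e - (b + 7))});
        pos = e + 8;
    }
    return segs;
}
```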
Could we take a look at PR #16932? I've already implemented Minimax M2 tool calling there.
commit 23d4bb7
Author: Piotr Wilkin <[email protected]>
Date:   Tue Nov 4 19:07:49 2025 +0100

    Add proper handling of optional parameters with test

commit 9481289
Author: Piotr Wilkin <[email protected]>
Date:   Sun Nov 2 19:30:35 2025 +0100

    Whitespace.

commit 1a351a0
Author: Piotr Wilkin <[email protected]>
Date:   Sun Nov 2 17:34:47 2025 +0100

    Use Unsloth template, add extra test parameters for ignoring additional whitespace

commit de67255
Author: Piotr Wilkin <[email protected]>
Date:   Sat Nov 1 22:33:40 2025 +0100

    On the other hand, this is probably safer

commit 4e58382
Author: Piotr Wilkin <[email protected]>
Date:   Sat Nov 1 22:32:20 2025 +0100

    No newline after <think>

commit e21f87e
Author: Piotr Wilkin <[email protected]>
Date:   Sat Nov 1 22:19:48 2025 +0100

    Minimax M2 chat template support
Superseded by #16932

Adds chat support to Minimax M2 together with tool calling and simple reasoning (non-interleaved).
Uses the fixed Unsloth template (https://huggingface.co/unsloth/MiniMax-M2-GGUF)
Includes upstream minja fix: google/minja#87