[None][fixes] Add tool call parsing fixes and Qwen3 coder parser #8817
Merged
New file, +344 lines:

```python
# Adapted from: https://raw.githubusercontent.com/sgl-project/sglang/d8fcbaa38da95201914a1277971044ee66837b26/python/sglang/srt/function_call/qwen3_coder_detector.py

import ast
import html
import json
import re
from typing import Any, Dict, List, Tuple

from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import (
    StreamingParseResult,
    ToolCallItem,
    _GetInfoFunc,
)


def _safe_val(raw: str) -> Any:
    raw = html.unescape(raw.strip())
    try:
        return json.loads(raw)
    except Exception:
        try:
            return ast.literal_eval(raw)
        except Exception:
            return raw


class Qwen3CoderToolParser(BaseToolParser):
    """Tool parser for Qwen 3 models.

    Assumes function call format:
    <tool_call>
    <function=execute_bash>
    <parameter=command>
    pwd && ls
    </parameter>
    </function>
    </tool_call>
    """

    def __init__(self):
        super().__init__()
        self.tool_call_start_token: str = "<tool_call>"
        self.tool_call_end_token: str = "</tool_call>"
        self.tool_call_prefix: str = "<function="
        self.tool_call_regex = re.compile(
            r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
        )
        self.tool_call_function_regex = re.compile(
            r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
        )
        self.tool_call_parameter_regex = re.compile(
            r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
        )
        self._buf: str = ""

        # Streaming state variables
        self._current_function_name: str = ""
        self._current_parameters: Dict[str, Any] = {}
        self._streamed_parameters: Dict[
            str, str
        ] = {}  # Track what parameter content we've streamed
        self._in_tool_call: bool = False
        self._function_name_sent: bool = False

    def has_tool_call(self, text: str) -> bool:
        return self.tool_call_start_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        normal, calls = self._extract(text, tools)
        return StreamingParseResult(normal_text=normal, calls=calls)

    def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
        self._buf += new_text
        normal = ""
        calls: List[ToolCallItem] = []

        # Build tool indices for validation
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = self._get_tool_indices(tools)

        while True:
            # If we're not in a tool call and don't see a start token, return normal text
            if not self._in_tool_call and self.tool_call_start_token not in self._buf:
                normal += self._buf
                self._buf = ""
                break

            # Look for tool call start
            if not self._in_tool_call:
                s = self._buf.find(self.tool_call_start_token)
                if s == -1:
                    normal += self._buf
                    self._buf = ""
                    break

                normal += self._buf[:s]
                self._buf = self._buf[s:]

                self._in_tool_call = True
                self._function_name_sent = False
                self._current_function_name = ""
                self._current_parameters = {}
                self._streamed_parameters = {}

                # Remove the start token
                self._buf = self._buf[len(self.tool_call_start_token) :]
                continue

            # We're in a tool call, try to parse function name if not sent yet
            if not self._function_name_sent:
                # Look for function name pattern: <function=name>
                function_match = re.search(r"<function=([^>]+)>", self._buf)
                if function_match:
                    function_name = function_match.group(1).strip()

                    # Validate function name
                    if function_name in self._tool_indices:
                        self._current_function_name = function_name
                        self._function_name_sent = True

                        # Initialize tool call tracking
                        if self.current_tool_id == -1:
                            self.current_tool_id = 0

                        # Ensure tracking arrays are large enough
                        while len(self.prev_tool_call_arr) <= self.current_tool_id:
                            self.prev_tool_call_arr.append({})
                        while len(self.streamed_args_for_tool) <= self.current_tool_id:
                            self.streamed_args_for_tool.append("")

                        # Store tool call info
                        self.prev_tool_call_arr[self.current_tool_id] = {
                            "name": function_name,
                            "arguments": {},
                        }

                        # Send tool name with empty parameters
                        calls.append(
                            ToolCallItem(
                                tool_index=self.current_tool_id,
                                name=function_name,
                                parameters="",
                            )
                        )

                        # Remove the processed function declaration
                        self._buf = self._buf[function_match.end() :]
                        continue
                    else:
                        # Invalid function name, reset state
                        logger.warning(f"Invalid function name: {function_name}")
                        self._reset_streaming_state()
                        normal += self._buf
                        self._buf = ""
                        break
                else:
                    # Function name not complete yet, wait for more text
                    break

            # Parse parameters incrementally
            if self._function_name_sent:
                # Process parameters and get any calls to emit
                parameter_calls = self._parse_and_stream_parameters(self._buf)
                calls.extend(parameter_calls)

                # Check if tool call is complete
                if self.tool_call_end_token in self._buf:
                    end_pos = self._buf.find(self.tool_call_end_token)

                    # Add closing brace to complete the JSON object
                    current_streamed = self.streamed_args_for_tool[self.current_tool_id]
                    if current_streamed:
                        # Count opening and closing braces to check if JSON is complete
                        open_braces = current_streamed.count("{")
                        close_braces = current_streamed.count("}")
                        if open_braces > close_braces:
                            calls.append(
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name=None,
                                    parameters="}",
                                )
                            )
                            self.streamed_args_for_tool[self.current_tool_id] = (
                                current_streamed + "}"
                            )

                    # Complete the tool call
                    self._buf = self._buf[end_pos + len(self.tool_call_end_token) :]
                    self._reset_streaming_state()
                    self.current_tool_id += 1
                    continue
                else:
                    # Tool call not complete yet, wait for more text
                    break

        return StreamingParseResult(normal_text=normal, calls=calls)

    def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
        """Parse complete parameter blocks from text and return any tool call items to emit.

        This method:
        1. Finds all complete <parameter> blocks
        2. Parses them into a dictionary
        3. Compares with current parameters and generates diff if needed
        4. Updates internal state

        Args:
            text_to_parse: The text to search for parameter blocks

        Returns:
            List of ToolCallItem objects to emit (may be empty)
        """
        calls: List[ToolCallItem] = []

        # Find all complete parameter patterns
        param_matches = list(
            re.finditer(r"<parameter=([^>]+)>(.*?)</parameter>", text_to_parse, re.DOTALL)
        )

        # Build new parameters dictionary
        new_params = {}
        for match in param_matches:
            param_name = match.group(1).strip()
            param_value = match.group(2)
            new_params[param_name] = _safe_val(param_value)

        # Calculate parameter diff to stream with proper incremental JSON building
        if new_params != self._current_parameters:
            previous_args_json = self.streamed_args_for_tool[self.current_tool_id]

            # Build incremental JSON properly
            if not self._current_parameters:
                # First parameter(s) - start JSON object but don't close it yet
                items = []
                for key, value in new_params.items():
                    items.append(
                        f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
                    )
                json_fragment = "{" + ", ".join(items)

                calls.append(
                    ToolCallItem(
                        tool_index=self.current_tool_id,
                        name=None,
                        parameters=json_fragment,
                    )
                )
                self.streamed_args_for_tool[self.current_tool_id] = json_fragment

            else:
                # Additional parameters - add them incrementally
                new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
                if new_keys:
                    # Build the continuation part (no closing brace yet)
                    continuation_parts = []
                    for key in new_keys:
                        value = new_params[key]
                        continuation_parts.append(
                            f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
                        )

                    json_fragment = ", " + ", ".join(continuation_parts)

                    calls.append(
                        ToolCallItem(
                            tool_index=self.current_tool_id,
                            name=None,
                            parameters=json_fragment,
                        )
                    )
                    self.streamed_args_for_tool[self.current_tool_id] = (
                        previous_args_json + json_fragment
                    )

            # Update current state
            self._current_parameters = new_params
            self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

        return calls

    def _reset_streaming_state(self):
        """Reset streaming state for the next tool call."""
        self._in_tool_call = False
        self._function_name_sent = False
        self._current_function_name = ""
        self._current_parameters = {}
        self._streamed_parameters = {}
        self.current_tool_name_sent = False

    def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
        normal_parts: List[str] = []
        calls: List[ToolCallItem] = []
        cursor = 0
        while True:
            s = text.find(self.tool_call_start_token, cursor)
            if s == -1:
                normal_parts.append(text[cursor:])
                break
            normal_parts.append(text[cursor:s])
            e = text.find(self.tool_call_end_token, s)
            if e == -1:
                normal_parts.append(text[s:])
                break
            block = text[s : e + len(self.tool_call_end_token)]
            cursor = e + len(self.tool_call_end_token)
            calls.extend(self._parse_block(block, tools))
        return "".join(normal_parts), calls

    def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
        res: List[ToolCallItem] = []
        for m in self.tool_call_function_regex.findall(block):
            txt = m[0] if m[0] else m[1]
            if ">" not in txt:
                continue
            idx = txt.index(">")
            fname = txt[:idx].strip()
            body = txt[idx + 1 :]
            params: Dict[str, Any] = {}
            for pm in self.tool_call_parameter_regex.findall(body):
                ptxt = pm[0] if pm[0] else pm[1]
                if ">" not in ptxt:
                    continue
                pidx = ptxt.index(">")
                pname = ptxt[:pidx].strip()
                pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
                params[pname] = _safe_val(pval)
            raw = {"name": fname, "arguments": params}
            try:
                # TODO: fix idx in function call, the index for a function
                # call will always be -1 in parse_base_json
                res.extend(self.parse_base_json(raw, tools))
            except Exception:
                logger.warning("invalid tool call for %s dropped", fname)
        return res

    def supports_structural_tag(self) -> bool:
        return False

    def structure_info(self) -> _GetInfoFunc:
        raise NotImplementedError
```
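For reference, the `<tool_call>` format this parser targets (shown in the class docstring) can be illustrated with a small standalone sketch. This is not part of the PR and does not use the tensorrt_llm classes; it only applies simplified versions of the same regular expressions to show what the parser extracts from a completion.

```python
import re

# Sample completion in the format the Qwen3 coder parser expects,
# taken from the Qwen3CoderToolParser docstring above.
sample = """Let me check the directory.
<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>"""

# Simplified versions of the parser's function/parameter regexes.
function_re = re.compile(r"<function=(.*?)</function>", re.DOTALL)
parameter_re = re.compile(r"<parameter=(.*?)</parameter>", re.DOTALL)

for fn_block in function_re.findall(sample):
    # Everything before the first ">" is the function name; the rest is the body.
    name, _, body = fn_block.partition(">")
    print("function:", name.strip())  # -> execute_bash
    for param_block in parameter_re.findall(body):
        pname, _, pval = param_block.partition(">")
        print("  param:", pname.strip(), "=", pval.strip())  # -> command = pwd && ls
```

Running this prints the function name `execute_bash` with its single `command` parameter, which is the same information `detect_and_parse` turns into `ToolCallItem` objects via `parse_base_json`.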