Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions tensorrt_llm/serve/chat_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import uuid
from functools import partial
from typing import (Any, Callable, Coroutine, Dict, Iterable, List, Literal,
Expand Down Expand Up @@ -185,6 +186,36 @@ def parse_chat_message_content(
content,
mm_data_tracker,
)
if role == "assistant":
result.update(_parse_assistant_message_content(message))
elif role == "tool":
result.update(_parse_tool_message_content(message))
return result


# Adapted from: https://github.com/vllm-project/vllm/blob/4574d48bab9c4e38b7c0a830eeefc8f0980e8c58/vllm/entrypoints/chat_utils.py#L1406
def _parse_assistant_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
result = {}
tool_calls = message.get("tool_calls")
if tool_calls is not None:
result["tool_calls"] = []
for item in tool_calls:
if content := item["function"].get("arguments"):
if isinstance(content, str):
item["function"]["arguments"] = json.loads(content)
else:
item["function"]["arguments"] = content
else:
item["function"]["arguments"] = {}
result["tool_calls"].append(item)

return result


def _parse_tool_message_content(message: Dict[str, Any]) -> Dict[str, Any]:
result = {}
if "tool_call_id" in message:
result["tool_call_id"] = message["tool_call_id"]
return result


Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,12 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):

class CustomChatCompletionMessageParam(TypedDict, total=False):
"""Enables custom roles in the Chat Completion API."""

# This is so custom fields not in any of the `ChatCompletionMessage<XYZ>Param` defined by OpenAI
# are still allowed.
# Examples include: assistant messages with `reasoning` / `reasoning_content`.
__pydantic_config__ = ConfigDict(extra="allow") # type: ignore

role: Required[str]
"""The role of the message's author."""

Expand Down
344 changes: 344 additions & 0 deletions tensorrt_llm/serve/tool_parser/qwen3_coder_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,344 @@
# Adapted from: https://raw.githubusercontent.com/sgl-project/sglang/d8fcbaa38da95201914a1277971044ee66837b26/python/sglang/srt/function_call/qwen3_coder_detector.py

import ast
import html
import json
import re
from typing import Any, Dict, List, Tuple

from tensorrt_llm.logger import logger
from tensorrt_llm.serve.openai_protocol import ChatCompletionToolsParam as Tool
from tensorrt_llm.serve.tool_parser.base_tool_parser import BaseToolParser
from tensorrt_llm.serve.tool_parser.core_types import (
StreamingParseResult,
ToolCallItem,
_GetInfoFunc,
)


def _safe_val(raw: str) -> Any:
raw = html.unescape(raw.strip())
try:
return json.loads(raw)
except Exception:
try:
return ast.literal_eval(raw)
except Exception:
return raw


class Qwen3CoderToolParser(BaseToolParser):
"""Tool parser for Qwen 3 models.

Assumes function call format:
<tool_call>
<function=execute_bash>
<parameter=command>
pwd && ls
</parameter>
</function>
</tool_call>
"""

def __init__(self):
super().__init__()
self.tool_call_start_token: str = "<tool_call>"
self.tool_call_end_token: str = "</tool_call>"
self.tool_call_prefix: str = "<function="
self.tool_call_regex = re.compile(
r"<tool_call>(.*?)</tool_call>|<tool_call>(.*?)$", re.DOTALL
)
self.tool_call_function_regex = re.compile(
r"<function=(.*?)</function>|<function=(.*)$", re.DOTALL
)
self.tool_call_parameter_regex = re.compile(
r"<parameter=(.*?)</parameter>|<parameter=(.*?)$", re.DOTALL
)
self._buf: str = ""

# Streaming state variables
self._current_function_name: str = ""
self._current_parameters: Dict[str, Any] = {}
self._streamed_parameters: Dict[
str, str
] = {} # Track what parameter content we've streamed
self._in_tool_call: bool = False
self._function_name_sent: bool = False

def has_tool_call(self, text: str) -> bool:
return self.tool_call_start_token in text

def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
normal, calls = self._extract(text, tools)
return StreamingParseResult(normal_text=normal, calls=calls)

def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> StreamingParseResult:
self._buf += new_text
normal = ""
calls: List[ToolCallItem] = []

# Build tool indices for validation
if not hasattr(self, "_tool_indices"):
self._tool_indices = self._get_tool_indices(tools)

while True:
# If we're not in a tool call and don't see a start token, return normal text
if not self._in_tool_call and self.tool_call_start_token not in self._buf:
normal += self._buf
self._buf = ""
break

# Look for tool call start
if not self._in_tool_call:
s = self._buf.find(self.tool_call_start_token)
if s == -1:
normal += self._buf
self._buf = ""
break

normal += self._buf[:s]
self._buf = self._buf[s:]

self._in_tool_call = True
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
self._streamed_parameters = {}

# Remove the start token
self._buf = self._buf[len(self.tool_call_start_token) :]
continue

# We're in a tool call, try to parse function name if not sent yet
if not self._function_name_sent:
# Look for function name pattern: <function=name>
function_match = re.search(r"<function=([^>]+)>", self._buf)
if function_match:
function_name = function_match.group(1).strip()

# Validate function name
if function_name in self._tool_indices:
self._current_function_name = function_name
self._function_name_sent = True

# Initialize tool call tracking
if self.current_tool_id == -1:
self.current_tool_id = 0

# Ensure tracking arrays are large enough
while len(self.prev_tool_call_arr) <= self.current_tool_id:
self.prev_tool_call_arr.append({})
while len(self.streamed_args_for_tool) <= self.current_tool_id:
self.streamed_args_for_tool.append("")

# Store tool call info
self.prev_tool_call_arr[self.current_tool_id] = {
"name": function_name,
"arguments": {},
}

# Send tool name with empty parameters
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=function_name,
parameters="",
)
)

# Remove the processed function declaration
self._buf = self._buf[function_match.end() :]
continue
else:
# Invalid function name, reset state
logger.warning(f"Invalid function name: {function_name}")
self._reset_streaming_state()
normal += self._buf
self._buf = ""
break
else:
# Function name not complete yet, wait for more text
break

# Parse parameters incrementally
if self._function_name_sent:
# Process parameters and get any calls to emit
parameter_calls = self._parse_and_stream_parameters(self._buf)
calls.extend(parameter_calls)

# Check if tool call is complete
if self.tool_call_end_token in self._buf:
end_pos = self._buf.find(self.tool_call_end_token)

# Add closing brace to complete the JSON object
current_streamed = self.streamed_args_for_tool[self.current_tool_id]
if current_streamed:
# Count opening and closing braces to check if JSON is complete
open_braces = current_streamed.count("{")
close_braces = current_streamed.count("}")
if open_braces > close_braces:
calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters="}",
)
)
self.streamed_args_for_tool[self.current_tool_id] = (
current_streamed + "}"
)

# Complete the tool call
self._buf = self._buf[end_pos + len(self.tool_call_end_token) :]
self._reset_streaming_state()
self.current_tool_id += 1
continue
else:
# Tool call not complete yet, wait for more text
break

return StreamingParseResult(normal_text=normal, calls=calls)

def _parse_and_stream_parameters(self, text_to_parse: str) -> List[ToolCallItem]:
"""Parse complete parameter blocks from text and return any tool call items to emit.

This method:
1. Finds all complete <parameter> blocks
2. Parses them into a dictionary
3. Compares with current parameters and generates diff if needed
4. Updates internal state

Args:
text_to_parse: The text to search for parameter blocks

Returns:
List of ToolCallItem objects to emit (may be empty)
"""
calls: List[ToolCallItem] = []

# Find all complete parameter patterns
param_matches = list(
re.finditer(r"<parameter=([^>]+)>(.*?)</parameter>", text_to_parse, re.DOTALL)
)

# Build new parameters dictionary
new_params = {}
for match in param_matches:
param_name = match.group(1).strip()
param_value = match.group(2)
new_params[param_name] = _safe_val(param_value)

# Calculate parameter diff to stream with proper incremental JSON building
if new_params != self._current_parameters:
previous_args_json = self.streamed_args_for_tool[self.current_tool_id]

# Build incremental JSON properly
if not self._current_parameters:
# First parameter(s) - start JSON object but don't close it yet
items = []
for key, value in new_params.items():
items.append(
f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
)
json_fragment = "{" + ", ".join(items)

calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=json_fragment,
)
)
self.streamed_args_for_tool[self.current_tool_id] = json_fragment

else:
# Additional parameters - add them incrementally
new_keys = set(new_params.keys()) - set(self._current_parameters.keys())
if new_keys:
# Build the continuation part (no closing brace yet)
continuation_parts = []
for key in new_keys:
value = new_params[key]
continuation_parts.append(
f"{json.dumps(key, ensure_ascii=False)}: {json.dumps(value, ensure_ascii=False)}"
)

json_fragment = ", " + ", ".join(continuation_parts)

calls.append(
ToolCallItem(
tool_index=self.current_tool_id,
name=None,
parameters=json_fragment,
)
)
self.streamed_args_for_tool[self.current_tool_id] = (
previous_args_json + json_fragment
)

# Update current state
self._current_parameters = new_params
self.prev_tool_call_arr[self.current_tool_id]["arguments"] = new_params

return calls

def _reset_streaming_state(self):
"""Reset streaming state for the next tool call."""
self._in_tool_call = False
self._function_name_sent = False
self._current_function_name = ""
self._current_parameters = {}
self._streamed_parameters = {}
self.current_tool_name_sent = False

def _extract(self, text: str, tools: List[Tool]) -> Tuple[str, List[ToolCallItem]]:
normal_parts: List[str] = []
calls: List[ToolCallItem] = []
cursor = 0
while True:
s = text.find(self.tool_call_start_token, cursor)
if s == -1:
normal_parts.append(text[cursor:])
break
normal_parts.append(text[cursor:s])
e = text.find(self.tool_call_end_token, s)
if e == -1:
normal_parts.append(text[s:])
break
block = text[s : e + len(self.tool_call_end_token)]
cursor = e + len(self.tool_call_end_token)
calls.extend(self._parse_block(block, tools))
return "".join(normal_parts), calls

def _parse_block(self, block: str, tools: List[Tool]) -> List[ToolCallItem]:
res: List[ToolCallItem] = []
for m in self.tool_call_function_regex.findall(block):
txt = m[0] if m[0] else m[1]
if ">" not in txt:
continue
idx = txt.index(">")
fname = txt[:idx].strip()
body = txt[idx + 1 :]
params: Dict[str, Any] = {}
for pm in self.tool_call_parameter_regex.findall(body):
ptxt = pm[0] if pm[0] else pm[1]
if ">" not in ptxt:
continue
pidx = ptxt.index(">")
pname = ptxt[:pidx].strip()
pval = ptxt[pidx + 1 :].lstrip("\n").rstrip("\n")
params[pname] = _safe_val(pval)
raw = {"name": fname, "arguments": params}
try:
# TODO: fix idx in function call, the index for a function
# call will always be -1 in parse_base_json
res.extend(self.parse_base_json(raw, tools))
except Exception:
logger.warning("invalid tool call for %s dropped", fname)
return res

def supports_structural_tag(self) -> bool:
return False

def structure_info(self) -> _GetInfoFunc:
raise NotImplementedError
Loading