Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 42 additions & 11 deletions xllm/api_service/chat_service_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,40 @@ void set_logprobs(proto::ChatChoice* choice,
}

struct StreamingState {
std::unique_ptr<function_call::FunctionCallParser> parser;
std::unordered_map<size_t, bool> has_tool_calls;
std::vector<function_call::JsonTool> tools;
std::string parser_format;

std::vector<std::unique_ptr<function_call::FunctionCallParser>> parsers;
std::vector<bool> has_tool_calls;

StreamingState(const std::vector<function_call::JsonTool>& tools,
const std::string& parser_format) {
const std::string& parser_format)
: tools(tools), parser_format(parser_format) {
if (!tools.empty() && !parser_format.empty()) {
parser = std::make_unique<function_call::FunctionCallParser>(
parsers.resize(1);
has_tool_calls.resize(1, false);
parsers[0] = std::make_unique<function_call::FunctionCallParser>(
tools, parser_format);
}
}

function_call::FunctionCallParser* get_parser_for_sequence(size_t index) {
if (tools.empty() || parser_format.empty()) {
return nullptr;
}

if (index >= parsers.size()) {
parsers.resize(index + 1);
has_tool_calls.resize(index + 1, false);
}

if (!parsers[index]) {
parsers[index] = std::make_unique<function_call::FunctionCallParser>(
tools, parser_format);
}

return parsers[index].get();
}
};

template <typename ChatCall>
Expand Down Expand Up @@ -206,11 +230,12 @@ bool process_tool_call_stream(std::shared_ptr<ChatCall> call,
const std::string& request_id,
int64_t created_time,
const std::string& model) {
if (!streaming_state->parser) {
auto* parser = streaming_state->get_parser_for_sequence(index);
if (!parser) {
return true;
}

auto parse_result = streaming_state->parser->parse_streaming_increment(delta);
auto parse_result = parser->parse_streaming_increment(delta);

if (!parse_result.normal_text.empty()) {
if (!send_normal_text_chunk(call,
Expand All @@ -224,6 +249,9 @@ bool process_tool_call_stream(std::shared_ptr<ChatCall> call,
}

for (const auto& call_item : parse_result.calls) {
if (index >= streaming_state->has_tool_calls.size()) {
streaming_state->has_tool_calls.resize(index + 1, false);
}
streaming_state->has_tool_calls[index] = true;

std::string tool_call_id;
Expand Down Expand Up @@ -258,11 +286,12 @@ bool check_for_unstreamed_tool_args(
const std::string& request_id,
int64_t created_time,
const std::string& model) {
if (!streaming_state->parser) {
auto* parser = streaming_state->get_parser_for_sequence(index);
if (!parser) {
return true;
}

auto* detector = streaming_state->parser->get_detector();
auto* detector = parser->get_detector();
if (!detector) {
return true;
}
Expand Down Expand Up @@ -335,7 +364,7 @@ bool send_delta_to_client_brpc(
}

if (!seq_output.text.empty()) {
if (streaming_state && streaming_state->parser) {
if (streaming_state && streaming_state->get_parser_for_sequence(index)) {
if (!process_tool_call_stream(call,
streaming_state,
index,
Expand Down Expand Up @@ -365,7 +394,8 @@ bool send_delta_to_client_brpc(
// Handle finish reason
if (seq_output.finish_reason.has_value()) {
// Check for unstreamed tool args before sending finish reason
if (streaming_state && streaming_state->has_tool_calls[index]) {
if (streaming_state && index < streaming_state->has_tool_calls.size() &&
streaming_state->has_tool_calls[index]) {
if (!check_for_unstreamed_tool_args(call,
streaming_state,
index,
Expand All @@ -385,7 +415,8 @@ bool send_delta_to_client_brpc(
choice->set_index(index);
choice->mutable_delta();

if (streaming_state && streaming_state->has_tool_calls[index] &&
if (streaming_state && index < streaming_state->has_tool_calls.size() &&
streaming_state->has_tool_calls[index] &&
seq_output.finish_reason.value() == "stop") {
choice->set_finish_reason("tool_calls");
} else {
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ DEFINE_bool(enable_shm,
DEFINE_string(tool_call_parser,
"",
"Specify the parser for handling tool-call interactions(e.g. "
"qwen25, qwen3, kimi_k2, deepseekv3).");
"qwen25, qwen3, kimi_k2, deepseekv3, glm45).");

// --- speculative config ---

Expand Down
3 changes: 3 additions & 0 deletions xllm/function_call/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ cc_library (
qwen25_detector.h
kimik2_detector.h
deepseekv3_detector.h
glm45_detector.h
function_call_parser.h
function_call.h
utils.h
Expand All @@ -20,6 +21,7 @@ cc_library (
qwen25_detector.cpp
kimik2_detector.cpp
deepseekv3_detector.cpp
glm45_detector.cpp
function_call_parser.cpp
utils.cpp
DEPS
Expand Down Expand Up @@ -47,4 +49,5 @@ endfunction()
add_detector_test(qwen25_detector_test)
add_detector_test(kimik2_detector_test)
add_detector_test(deepseekv3_detector_test)
add_detector_test(glm45_detector_test)

1 change: 1 addition & 0 deletions xllm/function_call/function_call.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include "core_types.h"
#include "deepseekv3_detector.h"
#include "function_call_parser.h"
#include "glm45_detector.h"
#include "kimik2_detector.h"
#include "qwen25_detector.h"

Expand Down
7 changes: 6 additions & 1 deletion xllm/function_call/function_call_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License.

#include "core/util/uuid.h"
#include "deepseekv3_detector.h"
#include "glm45_detector.h"
#include "kimik2_detector.h"
#include "qwen25_detector.h"
namespace xllm {
Expand All @@ -31,12 +32,12 @@ const std::unordered_map<std::string, std::string>
{"qwen3", "qwen25"},
{"kimi_k2", "kimi_k2"},
{"deepseekv3", "deepseekv3"},
{"glm45", "glm45"},
// TODO
// {"llama3", "llama3"},
// {"mistral", "mistral"},
// {"pythonic", "pythonic"},
// {"qwen3_coder", "qwen3_coder"},
// {"glm45", "glm45"},
// {"step3", "step3"},
};

Expand Down Expand Up @@ -96,6 +97,10 @@ std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector(
return std::make_unique<DeepSeekV3Detector>();
}

if (it->second == "glm45") {
return std::make_unique<Glm45Detector>();
}

// if (tool_call_parser == "llama3") {
// return std::make_unique<Llama32Detector>();
// }
Expand Down
196 changes: 196 additions & 0 deletions xllm/function_call/glm45_detector.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "glm45_detector.h"

#include <algorithm>
#include <iostream>
#include <sstream>

namespace xllm {
namespace function_call {

Glm45Detector::Glm45Detector() : BaseFormatDetector() {
bot_token_ = "<tool_call>";
eot_token_ = "</tool_call>";

// Regex patterns for GLM-4.5 format
func_call_regex_ = std::regex("<tool_call>[\\s\\S]*?</tool_call>",
std::regex_constants::ECMAScript);
func_detail_regex_ =
std::regex("<tool_call>([^\\n]*)\\n([\\s\\S]*?)</tool_call>",
std::regex_constants::ECMAScript);
func_arg_regex_ = std::regex(
"<arg_key>([\\s\\S]*?)</arg_key>\\s*<arg_value>([\\s\\S]*?)</arg_value>",
std::regex_constants::ECMAScript);
}

std::string Glm45Detector::trim_whitespace(std::string_view str) const {
const char* whitespace = " \t\n\r";

size_t start = str.find_first_not_of(whitespace);
if (start == std::string_view::npos) {
return std::string{};
}

size_t end = str.find_last_not_of(whitespace);

return std::string(str.substr(start, end - start + 1));
}

bool Glm45Detector::has_tool_call(const std::string& text) {
return text.find(bot_token_) != std::string::npos;
}

StreamingParseResult Glm45Detector::detect_and_parse(
const std::string& text,
const std::vector<JsonTool>& tools) {
size_t idx = text.find(bot_token_);
std::string normal_text =
(idx != std::string::npos) ? text.substr(0, idx) : text;

// Trim normal text
if (!normal_text.empty()) {
normal_text = trim_whitespace(normal_text);
}

if (idx == std::string::npos) {
return StreamingParseResult(normal_text, {});
}

std::vector<ToolCallItem> calls;

try {
std::sregex_iterator iter(text.begin(), text.end(), func_call_regex_);
std::sregex_iterator end;

for (; iter != end; ++iter) {
std::smatch match = *iter;
std::string match_result = match.str();

// Parse function name and arguments
std::smatch func_detail;
if (std::regex_search(match_result, func_detail, func_detail_regex_)) {
std::string func_name = func_detail[1].str();
std::string func_args = func_detail[2].str();

// Parse arguments using regex
std::unordered_map<std::string, nlohmann::json> arguments;
std::sregex_iterator arg_iter(
func_args.begin(), func_args.end(), func_arg_regex_);
std::sregex_iterator arg_end;

for (; arg_iter != arg_end; ++arg_iter) {
std::smatch arg_match = *arg_iter;
if (arg_match.size() >= 3) {
std::string arg_key = arg_match[1].str();
std::string arg_value = arg_match[2].str();

arg_key = trim_whitespace(arg_key);

arg_value = trim_whitespace(arg_value);

try {
nlohmann::json parsed_value = nlohmann::json::parse(arg_value);
arguments[arg_key] = parsed_value;
} catch (const nlohmann::json::parse_error&) {
arguments[arg_key] = nlohmann::json(arg_value);
}
}
}

// Create JSON object for parse_base_json
nlohmann::json match_json;
match_json["name"] = func_name;
match_json["parameters"] = arguments;

auto parsed_calls = parse_base_json(match_json, tools);
calls.insert(calls.end(), parsed_calls.begin(), parsed_calls.end());
}
}

return StreamingParseResult(normal_text, calls);

} catch (const std::exception& e) {
LOG(ERROR) << "Error in GLM-4.5 detect_and_parse: " << e.what();
return StreamingParseResult(text, {});
}
}

StreamingParseResult Glm45Detector::parse_streaming_increment(
const std::string& new_text,
const std::vector<JsonTool>& tools) {
buffer_ += new_text;
std::string current_text = buffer_;

size_t start = current_text.find(bot_token_);
if (start == std::string::npos) {
buffer_.clear();
if (current_tool_id_ > 0) {
current_text = "";
}
return StreamingParseResult(current_text, {});
}

// Look for complete tool call
size_t end = current_text.find(eot_token_);
if (end != std::string::npos) {
// Initialize state if this is the first tool call
if (current_tool_id_ == -1) {
current_tool_id_ = 0;
prev_tool_call_arr_.clear();
streamed_args_for_tool_.clear();
streamed_args_for_tool_.push_back("");
}

// Ensure we have enough entries in tracking arrays
while (prev_tool_call_arr_.size() <= current_tool_id_) {
prev_tool_call_arr_.push_back({});
}
while (streamed_args_for_tool_.size() <= current_tool_id_) {
streamed_args_for_tool_.push_back("");
}

// Parse the complete tool call
std::string complete_call =
current_text.substr(0, end + eot_token_.length());
StreamingParseResult result = detect_and_parse(complete_call, tools);

if (!result.calls.empty()) {
// Store tool call info for serving layer
prev_tool_call_arr_[current_tool_id_]["name"] =
result.calls[0].name.value_or("");
prev_tool_call_arr_[current_tool_id_]["arguments"] =
result.calls[0].parameters;
streamed_args_for_tool_[current_tool_id_] = result.calls[0].parameters;

// Update tool index
result.calls[0].tool_index = current_tool_id_;
current_tool_id_++;
}

// Update buffer with remaining text
buffer_ = current_text.substr(end + eot_token_.length());
return result;
}

// Return normal text before tool call start
std::string normal_text = current_text.substr(0, start);
buffer_ = current_text.substr(start);
return StreamingParseResult(normal_text, {});
}

} // namespace function_call
} // namespace xllm
Loading