From 2ebedda3d9b4e4fc7172df3b93cfe90b39f96d57 Mon Sep 17 00:00:00 2001
From: ngxson
Date: Thu, 8 Feb 2024 23:16:58 +0100
Subject: [PATCH 1/6] server: add mistral chat template

---
 examples/server/oai.hpp    |  9 +++++++--
 examples/server/server.cpp | 16 ++++++++++++++--
 examples/server/utils.hpp  | 30 ++++++++++++++++++++++++++++++
 3 files changed, 51 insertions(+), 4 deletions(-)

diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 43410f803d469..ce87f0aac22f4 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -15,9 +15,14 @@ using json = nlohmann::json;
 
 inline static json oaicompat_completion_params_parse(
-    const json &body /* openai api json semantics */)
+    const json &body, /* openai api json semantics */
+    const std::string &chat_template)
 {
     json llama_params;
+    bool using_chatml = chat_template == "chatml";
+    std::string formated_prompt = using_chatml
+        ? format_chatml(body["messages"])   // OpenAI 'messages' to chatml
+        : format_mistral(body["messages"]); // OpenAI 'messages' to mistral format
 
     llama_params["__oaicompat"] = true;
 
@@ -30,7 +35,7 @@ inline static json oaicompat_completion_params_parse(
     // https://platform.openai.com/docs/api-reference/chat/create
     llama_sampling_params default_sparams;
     llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
+    llama_params["prompt"] = formated_prompt;
     llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
     llama_params["temperature"] = json_value(body, "temperature", 0.0);
     llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index eceda30d05fcc..2183abcb190e4 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -36,6 +36,7 @@ struct server_params
     std::string hostname = "127.0.0.1";
     std::vector<std::string> api_keys;
     std::string public_path = "examples/server/public";
+    std::string chat_template = "chatml";
     int32_t port = 8080;
     int32_t read_timeout = 600;
     int32_t write_timeout = 600;
@@ -1859,6 +1860,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
     printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
     printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  --chat-template FORMAT_NAME");
+    printf("                            set chat template, possible valus is: mistral, chatml (default %s)", sparams.chat_template.c_str());
     printf("\n");
 }
 
@@ -2290,6 +2293,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
             log_set_target(stdout);
             LOG_INFO("logging to file is disabled.", {});
         }
+        else if (arg == "--chat-template")
+        {
+            if (++i >= argc)
+            {
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = argv[i];
+        }
         else if (arg == "--override-kv")
         {
             if (++i >= argc) {
@@ -2743,13 +2755,13 @@ int main(int argc, char **argv)
     // TODO: add mount point without "/v1" prefix -- how?
- svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) + svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } - json data = oaicompat_completion_params_parse(json::parse(req.body)); + json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template); const int task_id = llama.queue_tasks.get_new_id(); llama.queue_results.add_waiting_task_id(task_id); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 70cce0721be08..5ec743fb8b90b 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -167,6 +167,34 @@ static T json_value(const json &body, const std::string &key, const T &default_v : default_value; } +inline std::string format_mistral(std::vector messages) +{ + std::ostringstream output; + bool is_inside_turn = false; + + for (auto it = messages.begin(); it != messages.end(); ++it) { + if (!is_inside_turn) { + output << "[INST] "; + } + std::string role = json_value(*it, "role", std::string("user")); + std::string content = json_value(*it, "content", std::string("")); + if (role == "system") { + output << "<>\n" << content << "\n<>\n\n"; + is_inside_turn = true; + } else if (role == "user") { + output << content << " [/INST]"; + is_inside_turn = true; + } else { + output << " " << content << " "; + is_inside_turn = false; + } + } + + LOG_VERBOSE("format_mistral", {{"text", output.str()}}); + + return output.str(); +} + inline std::string format_chatml(std::vector messages) { std::ostringstream chatml_msgs; @@ -180,6 +208,8 @@ inline std::string format_chatml(std::vector messages) chatml_msgs << "<|im_start|>assistant" << '\n'; + LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}}); + return chatml_msgs.str(); } From 27976c31b6c8293978fff41a951c9759b5e76735 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 9 Feb 2024 09:32:51 +0100 Subject: [PATCH 2/6] server: fix typo --- examples/server/oai.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp index ce87f0aac22f4..c0c1f521ecb61 100644 --- a/examples/server/oai.hpp +++ b/examples/server/oai.hpp @@ -20,7 +20,7 @@ inline static json oaicompat_completion_params_parse( { json llama_params; bool using_chatml = chat_template == "chatml"; - std::string formated_prompt = using_chatml + std::string formatted_prompt = using_chatml ? 
format_chatml(body["messages"]) // OpenAI 'messages' to chatml : format_mistral(body["messages"]); // OpenAI 'messages' to mistral format @@ -35,7 +35,7 @@ inline static json oaicompat_completion_params_parse( // https://platform.openai.com/docs/api-reference/chat/create llama_sampling_params default_sparams; llama_params["model"] = json_value(body, "model", std::string("unknown")); - llama_params["prompt"] = formated_prompt; + llama_params["prompt"] = formatted_prompt; llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); From 269437e4ebb3f793b6c37e357695e358189c8768 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 9 Feb 2024 09:55:20 +0100 Subject: [PATCH 3/6] server: rename template mistral to llama2 --- examples/server/oai.hpp | 4 ++-- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp index c0c1f521ecb61..2199df237b495 100644 --- a/examples/server/oai.hpp +++ b/examples/server/oai.hpp @@ -21,8 +21,8 @@ inline static json oaicompat_completion_params_parse( json llama_params; bool using_chatml = chat_template == "chatml"; std::string formatted_prompt = using_chatml - ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml - : format_mistral(body["messages"]); // OpenAI 'messages' to mistral format + ? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...) + : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...) llama_params["__oaicompat"] = true; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 2183abcb190e4..d4161ed92cd6f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1861,7 +1861,7 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); printf(" --chat-template FORMAT_NAME"); - printf(" set chat template, possible valus is: mistral, chatml (default %s)", sparams.chat_template.c_str()); + printf(" set chat template, possible valus is: llama2, chatml (default %s)", sparams.chat_template.c_str()); printf("\n"); } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 5ec743fb8b90b..b1623c737e9ac 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -167,7 +167,7 @@ static T json_value(const json &body, const std::string &key, const T &default_v : default_value; } -inline std::string format_mistral(std::vector messages) +inline std::string format_llama2(std::vector messages) { std::ostringstream output; bool is_inside_turn = false; @@ -190,7 +190,7 @@ inline std::string format_mistral(std::vector messages) } } - LOG_VERBOSE("format_mistral", {{"text", output.str()}}); + LOG_VERBOSE("format_llama2", {{"text", output.str()}}); return output.str(); } From 7efef47d2eb09160fc612c4edb811d261436e256 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 9 Feb 2024 17:00:36 +0100 Subject: [PATCH 4/6] server: format_llama2: remove BOS --- examples/server/utils.hpp | 2 +- 1 file 

diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index b1623c737e9ac..5485489627d5d 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -174,7 +174,7 @@ inline std::string format_llama2(std::vector<json> messages)
 
     for (auto it = messages.begin(); it != messages.end(); ++it) {
         if (!is_inside_turn) {
-            output << "<s>[INST] ";
+            output << "[INST] ";
         }
         std::string role = json_value(*it, "role", std::string("user"));
         std::string content = json_value(*it, "content", std::string(""));

From ebe30795390a4a5818d5b7c7bf2156dc2fd0312a Mon Sep 17 00:00:00 2001
From: ngxson
Date: Fri, 9 Feb 2024 17:00:53 +0100
Subject: [PATCH 5/6] server: validate "--chat-template" argument

---
 examples/server/server.cpp | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d4161ed92cd6f..70baf8de4df35 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2300,7 +2300,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 invalid_param = true;
                 break;
             }
-            sparams.chat_template = argv[i];
+            std::string value(argv[i]);
+            if (value != "chatml" && value != "llama2") {
+                fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str());
+                invalid_param = true;
+                break;
+            }
+            sparams.chat_template = value;
         }
         else if (arg == "--override-kv")
         {

From 1a27406426947a51e968922605b6f6b51f63ee00 Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 10 Feb 2024 00:11:13 +0100
Subject: [PATCH 6/6] server: clean up using_chatml variable

Co-authored-by: Jared Van Bortel
---
 examples/server/oai.hpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index 2199df237b495..2eca8a9fb4560 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -19,8 +19,7 @@ inline static json oaicompat_completion_params_parse(
     const std::string &chat_template)
 {
     json llama_params;
-    bool using_chatml = chat_template == "chatml";
-    std::string formatted_prompt = using_chatml
+    std::string formatted_prompt = chat_template == "chatml"
         ? format_chatml(body["messages"])  // OpenAI 'messages' to chatml (with <|im_start|>,...)
         : format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
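
Note: not part of the patches, a rough sketch of the prompt string format_llama2 would produce after PATCH 4/6 for a two-message OpenAI-style request; the conversation content is made up, and it assumes the <<SYS>>/<</SYS>> system markers used in the hunks above.

    // illustrative input, OpenAI 'messages' semantics:
    //   [{"role": "system", "content": "You are a helpful assistant."},
    //    {"role": "user",   "content": "Hello"}]
    //
    // resulting 'prompt' passed to llama.cpp (no leading BOS after PATCH 4/6):
    //   "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\n\nHello [/INST]"

The new flag would be exercised as, for example, ./server -m <model.gguf> --chat-template llama2 (model path is a placeholder); existing /v1/chat/completions requests need no changes.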