diff --git a/engine/commands/chat_cmd.cc b/engine/commands/chat_cmd.cc
index bb44b476b..3f5f221bc 100644
--- a/engine/commands/chat_cmd.cc
+++ b/engine/commands/chat_cmd.cc
@@ -1,144 +1,17 @@
 #include "chat_cmd.h"
 #include "httplib.h"
 
-#include "cortex_upd_cmd.h"
 #include "database/models.h"
 #include "model_status_cmd.h"
 #include "server_start_cmd.h"
 #include "trantor/utils/Logger.h"
 #include "utils/logging_utils.h"
+#include "run_cmd.h"
 
 namespace commands {
-namespace {
-constexpr const char* kExitChat = "exit()";
-constexpr const auto kMinDataChunkSize = 6u;
-constexpr const char* kUser = "user";
-constexpr const char* kAssistant = "assistant";
-
-}  // namespace
-
-struct ChunkParser {
-  std::string content;
-  bool is_done = false;
-
-  ChunkParser(const char* data, size_t data_length) {
-    if (data && data_length > kMinDataChunkSize) {
-      std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize);
-      if (s.find("[DONE]") != std::string::npos) {
-        is_done = true;
-      } else {
-        try {
-          content = nlohmann::json::parse(s)["choices"][0]["delta"]["content"];
-        } catch (const nlohmann::json::parse_error& e) {
-          CTL_WRN("JSON parse error: " << e.what());
-        }
-      }
-    }
-  }
-};
-
 void ChatCmd::Exec(const std::string& host, int port,
-                   const std::string& model_handle, std::string msg) {
-  cortex::db::Models modellist_handler;
-  config::YamlHandler yaml_handler;
-  try {
-    auto model_entry = modellist_handler.GetModelInfo(model_handle);
-    if (model_entry.has_error()) {
-      CLI_LOG("Error: " + model_entry.error());
-      return;
-    }
-    yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
-    auto mc = yaml_handler.GetModelConfig();
-    Exec(host, port, mc, std::move(msg));
-  } catch (const std::exception& e) {
-    CLI_LOG("Fail to start model information with ID '" + model_handle +
-            "': " + e.what());
-  }
+                   const std::string& model_handle) {
+  RunCmd rc(host, port, model_handle);
+  rc.Exec(true /*chat_flag*/);
 }
-
-void ChatCmd::Exec(const std::string& host, int port,
-                   const config::ModelConfig& mc, std::string msg) {
-  auto address = host + ":" + std::to_string(port);
-  // Check if server is started
-  {
-    if (!commands::IsServerAlive(host, port)) {
-      CLI_LOG("Server is not started yet, please run `"
-              << commands::GetCortexBinary() << " start` to start server!");
-      return;
-    }
-  }
-
-  // Only check if llamacpp engine
-  if ((mc.engine.find("llamacpp") != std::string::npos) &&
-      !commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
-    CLI_LOG("Model is not loaded yet!");
-    return;
-  }
-
-  // Interactive mode or not
-  bool interactive = msg.empty();
-
-  // Some instruction for user here
-  if (interactive) {
-    std::cout << "Inorder to exit, type `exit()`" << std::endl;
-  }
-  // Model is loaded, start to chat
-  {
-    do {
-      std::string user_input = std::move(msg);
-      if (user_input.empty()) {
-        std::cout << "> ";
-        std::getline(std::cin, user_input);
-      }
-      if (user_input == kExitChat) {
-        break;
-      }
-
-      if (!user_input.empty()) {
-        httplib::Client cli(address);
-        nlohmann::json json_data;
-        nlohmann::json new_data;
-        new_data["role"] = kUser;
-        new_data["content"] = user_input;
-        histories_.push_back(std::move(new_data));
-        json_data["engine"] = mc.engine;
-        json_data["messages"] = histories_;
-        json_data["model"] = mc.name;
-        //TODO: support non-stream
-        json_data["stream"] = true;
-        json_data["stop"] = mc.stop;
-        auto data_str = json_data.dump();
-        // std::cout << data_str << std::endl;
-        cli.set_read_timeout(std::chrono::seconds(60));
-        // std::cout << "> ";
-        httplib::Request req;
-        req.headers = httplib::Headers();
-        req.set_header("Content-Type", "application/json");
-        req.method = "POST";
-        req.path = "/v1/chat/completions";
-        req.body = data_str;
-        std::string ai_chat;
-        req.content_receiver = [&](const char* data, size_t data_length,
-                                   uint64_t offset, uint64_t total_length) {
-          ChunkParser cp(data, data_length);
-          if (cp.is_done) {
-            std::cout << std::endl;
-            return false;
-          }
-          std::cout << cp.content << std::flush;
-          ai_chat += cp.content;
-          return true;
-        };
-        cli.send(req);
-
-        nlohmann::json ai_res;
-        ai_res["role"] = kAssistant;
-        ai_res["content"] = ai_chat;
-        histories_.push_back(std::move(ai_res));
-      }
-      // std::cout << "ok Done" << std::endl;
-    } while (interactive);
-  }
-}
-
 };  // namespace commands
\ No newline at end of file
diff --git a/engine/commands/chat_cmd.h b/engine/commands/chat_cmd.h
index 596cfce2d..5abcb3cd6 100644
--- a/engine/commands/chat_cmd.h
+++ b/engine/commands/chat_cmd.h
@@ -1,18 +1,9 @@
 #pragma once
 #include <string>
-#include <vector>
-#include "config/model_config.h"
-#include "nlohmann/json.hpp"
 
 namespace commands {
 class ChatCmd {
  public:
-  void Exec(const std::string& host, int port, const std::string& model_handle,
-            std::string msg);
-  void Exec(const std::string& host, int port, const config::ModelConfig& mc,
-            std::string msg);
-
- private:
-  std::vector<nlohmann::json> histories_;
+  void Exec(const std::string& host, int port, const std::string& model_handle);
 };
 }  // namespace commands
\ No newline at end of file
diff --git a/engine/commands/chat_completion_cmd.cc b/engine/commands/chat_completion_cmd.cc
new file mode 100644
index 000000000..fb228f021
--- /dev/null
+++ b/engine/commands/chat_completion_cmd.cc
@@ -0,0 +1,145 @@
+#include "chat_completion_cmd.h"
+#include "httplib.h"
+
+#include "cortex_upd_cmd.h"
+#include "database/models.h"
+#include "model_status_cmd.h"
+#include "server_start_cmd.h"
+#include "trantor/utils/Logger.h"
+#include "utils/logging_utils.h"
+#include "run_cmd.h"
+
+namespace commands {
+namespace {
+constexpr const char* kExitChat = "exit()";
+constexpr const auto kMinDataChunkSize = 6u;
+constexpr const char* kUser = "user";
+constexpr const char* kAssistant = "assistant";
+
+}  // namespace
+
+struct ChunkParser {
+  std::string content;
+  bool is_done = false;
+
+  ChunkParser(const char* data, size_t data_length) {
+    if (data && data_length > kMinDataChunkSize) {
+      std::string s(data + kMinDataChunkSize, data_length - kMinDataChunkSize);
+      if (s.find("[DONE]") != std::string::npos) {
+        is_done = true;
+      } else {
+        try {
+          content = nlohmann::json::parse(s)["choices"][0]["delta"]["content"];
+        } catch (const nlohmann::json::parse_error& e) {
+          CTL_WRN("JSON parse error: " << e.what());
+        }
+      }
+    }
+  }
+};
+
+void ChatCompletionCmd::Exec(const std::string& host, int port,
+                             const std::string& model_handle, std::string msg) {
+  cortex::db::Models modellist_handler;
+  config::YamlHandler yaml_handler;
+  try {
+    auto model_entry = modellist_handler.GetModelInfo(model_handle);
+    if (model_entry.has_error()) {
+      CLI_LOG("Error: " + model_entry.error());
+      return;
+    }
+    yaml_handler.ModelConfigFromFile(model_entry.value().path_to_model_yaml);
+    auto mc = yaml_handler.GetModelConfig();
+    Exec(host, port, mc, std::move(msg));
+  } catch (const std::exception& e) {
+    CLI_LOG("Failed to get model information with ID '" + model_handle +
+            "': " + e.what());
+  }
+}
+
+void ChatCompletionCmd::Exec(const std::string& host, int port,
+                             const config::ModelConfig& mc, std::string msg) {
+  auto address = host + ":" + std::to_string(port);
+  // Check if server is started
+  {
+    if (!commands::IsServerAlive(host, port)) {
+      CLI_LOG("Server is not started yet, please run `"
+              << commands::GetCortexBinary() << " start` to start server!");
+      return;
+    }
+  }
+
+  // Only check if llamacpp engine
+  if ((mc.engine.find("llamacpp") != std::string::npos) &&
+      !commands::ModelStatusCmd().IsLoaded(host, port, mc)) {
+    CLI_LOG("Model is not loaded yet!");
+    return;
+  }
+
+  // Interactive mode or not
+  bool interactive = msg.empty();
+
+  // Some instruction for user here
+  if (interactive) {
+    std::cout << "In order to exit, type `exit()`" << std::endl;
+  }
+  // Model is loaded, start to chat
+  {
+    do {
+      std::string user_input = std::move(msg);
+      if (user_input.empty()) {
+        std::cout << "> ";
+        std::getline(std::cin, user_input);
+      }
+      if (user_input == kExitChat) {
+        break;
+      }
+
+      if (!user_input.empty()) {
+        httplib::Client cli(address);
+        nlohmann::json json_data;
+        nlohmann::json new_data;
+        new_data["role"] = kUser;
+        new_data["content"] = user_input;
+        histories_.push_back(std::move(new_data));
+        json_data["engine"] = mc.engine;
+        json_data["messages"] = histories_;
+        json_data["model"] = mc.name;
+        //TODO: support non-stream
+        json_data["stream"] = true;
+        json_data["stop"] = mc.stop;
+        auto data_str = json_data.dump();
+        // std::cout << data_str << std::endl;
+        cli.set_read_timeout(std::chrono::seconds(60));
+        // std::cout << "> ";
+        httplib::Request req;
+        req.headers = httplib::Headers();
+        req.set_header("Content-Type", "application/json");
+        req.method = "POST";
+        req.path = "/v1/chat/completions";
+        req.body = data_str;
+        std::string ai_chat;
+        req.content_receiver = [&](const char* data, size_t data_length,
+                                   uint64_t offset, uint64_t total_length) {
+          ChunkParser cp(data, data_length);
+          if (cp.is_done) {
+            std::cout << std::endl;
+            return false;
+          }
+          std::cout << cp.content << std::flush;
+          ai_chat += cp.content;
+          return true;
+        };
+        cli.send(req);
+
+        nlohmann::json ai_res;
+        ai_res["role"] = kAssistant;
+        ai_res["content"] = ai_chat;
+        histories_.push_back(std::move(ai_res));
+      }
+      // std::cout << "ok Done" << std::endl;
+    } while (interactive);
+  }
+}
+
+};  // namespace commands
\ No newline at end of file
diff --git a/engine/commands/chat_completion_cmd.h b/engine/commands/chat_completion_cmd.h
new file mode 100644
index 000000000..bd488e91f
--- /dev/null
+++ b/engine/commands/chat_completion_cmd.h
@@ -0,0 +1,18 @@
+#pragma once
+#include <string>
+#include <vector>
+#include "config/model_config.h"
+#include "nlohmann/json.hpp"
+
+namespace commands {
+class ChatCompletionCmd {
+ public:
+  void Exec(const std::string& host, int port, const std::string& model_handle,
+            std::string msg);
+  void Exec(const std::string& host, int port, const config::ModelConfig& mc,
+            std::string msg);
+
+ private:
+  std::vector<nlohmann::json> histories_;
+};
+}  // namespace commands
\ No newline at end of file
diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc
index d1a88733d..2fff4c285 100644
--- a/engine/commands/run_cmd.cc
+++ b/engine/commands/run_cmd.cc
@@ -1,5 +1,5 @@
 #include "run_cmd.h"
-#include "chat_cmd.h"
+#include "chat_completion_cmd.h"
 #include "config/yaml_config.h"
 #include "database/models.h"
 #include "model_start_cmd.h"
@@ -7,9 +7,11 @@
 #include "server_start_cmd.h"
 #include "utils/logging_utils.h"
 
+#include "cortex_upd_cmd.h"
+
 namespace commands {
 
-void RunCmd::Exec() {
+void RunCmd::Exec(bool chat_flag) {
   std::optional<std::string> model_id = model_handle_;
   cortex::db::Models modellist_handler;
 
@@ -78,7 +80,13 @@ void RunCmd::Exec() {
     }
 
     // Chat
-    ChatCmd().Exec(host_, port_, mc, "");
+    if (chat_flag) {
+      ChatCompletionCmd().Exec(host_, port_, mc, "");
+    } else {
+      CLI_LOG(*model_id << " model started successfully. Use `"
+              << commands::GetCortexBinary() << " chat " << *model_id
+              << "` for interactive chat shell");
+    }
   } catch (const std::exception& e) {
     CLI_LOG("Fail to run model with ID '" + model_handle_ + "': " + e.what());
   }
diff --git a/engine/commands/run_cmd.h b/engine/commands/run_cmd.h
index 136800102..3d5c77719 100644
--- a/engine/commands/run_cmd.h
+++ b/engine/commands/run_cmd.h
@@ -1,5 +1,6 @@
 #pragma once
 #include <string>
+#include "nlohmann/json.hpp"
 #include "services/engine_service.h"
 #include "services/model_service.h"
 
@@ -12,7 +13,7 @@ class RunCmd {
         model_handle_{std::move(model_handle)},
         model_service_{ModelService()} {};
 
-  void Exec();
+  void Exec(bool chat_flag);
 
  private:
   std::string host_;
diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc
index 61abf7182..d5951906f 100644
--- a/engine/controllers/command_line_parser.cc
+++ b/engine/controllers/command_line_parser.cc
@@ -1,5 +1,6 @@
 #include "command_line_parser.h"
 #include "commands/chat_cmd.h"
+#include "commands/chat_completion_cmd.h"
 #include "commands/cmd_info.h"
 #include "commands/cortex_upd_cmd.h"
 #include "commands/engine_get_cmd.h"
@@ -123,12 +124,12 @@ void CommandLineParser::SetupCommonCommands() {
     }
   });
 
-  auto run_cmd =
-      app_.add_subcommand("run", "Shortcut to start a model and chat");
+  auto run_cmd = app_.add_subcommand("run", "Shortcut to start a model");
   run_cmd->group(kCommonCommandsGroup);
   run_cmd->usage("Usage:\n" + commands::GetCortexBinary() +
                  " run [options] [model_id]");
   run_cmd->add_option("model_id", cml_data_.model_id, "");
+  run_cmd->add_flag("--chat", cml_data_.chat_flag, "Flag for interactive mode");
   run_cmd->callback([this, run_cmd] {
     if (cml_data_.model_id.empty()) {
       CLI_LOG("[model_id] is required\n");
@@ -138,10 +139,12 @@ void CommandLineParser::SetupCommonCommands() {
     commands::RunCmd rc(cml_data_.config.apiServerHost,
                         std::stoi(cml_data_.config.apiServerPort),
                         cml_data_.model_id);
-    rc.Exec();
+    rc.Exec(cml_data_.chat_flag);
   });
 
-  auto chat_cmd = app_.add_subcommand("chat", "Send a chat completion request");
+  auto chat_cmd = app_.add_subcommand(
+      "chat",
+      "Shortcut for `cortex run --chat` or send a chat completion request");
   chat_cmd->group(kCommonCommandsGroup);
   chat_cmd->usage("Usage:\n" + commands::GetCortexBinary() +
                   " chat [model_id] -m [msg]");
@@ -149,15 +152,22 @@ void CommandLineParser::SetupCommonCommands() {
   chat_cmd->add_option("-m,--message", cml_data_.msg,
                        "Message to chat with model");
   chat_cmd->callback([this, chat_cmd] {
-    if (cml_data_.model_id.empty() || cml_data_.msg.empty()) {
-      CLI_LOG("[model_id] and [msg] are required\n");
+    if (cml_data_.model_id.empty()) {
+      CLI_LOG("[model_id] is required\n");
       CLI_LOG(chat_cmd->help());
       return;
     }
 
-    commands::ChatCmd().Exec(cml_data_.config.apiServerHost,
-                             std::stoi(cml_data_.config.apiServerPort),
-                             cml_data_.model_id, cml_data_.msg);
+    if (cml_data_.msg.empty()) {
+      commands::ChatCmd().Exec(cml_data_.config.apiServerHost,
+                               std::stoi(cml_data_.config.apiServerPort),
+                               cml_data_.model_id);
+    } else {
+      commands::ChatCompletionCmd().Exec(
+          cml_data_.config.apiServerHost,
+          std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id,
+          cml_data_.msg);
+    }
   });
 }
 
diff --git a/engine/controllers/command_line_parser.h b/engine/controllers/command_line_parser.h
index 1ca308eef..f93dac1ef 100644
--- a/engine/controllers/command_line_parser.h
+++ b/engine/controllers/command_line_parser.h
@@ -38,6 +38,7 @@ class CommandLineParser {
     std::string engine_src;
     std::string cortex_version;
     bool check_upd = true;
+    bool chat_flag = false;
     int port;
     config_yaml_utils::CortexConfig config;
     std::unordered_map<std::string, std::string> model_update_options;