diff --git a/engine/commands/chat_cmd.cc b/engine/commands/chat_cmd.cc index da232a321..e4d0eda3d 100644 --- a/engine/commands/chat_cmd.cc +++ b/engine/commands/chat_cmd.cc @@ -6,6 +6,7 @@ #include "server_start_cmd.h" #include "trantor/utils/Logger.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { namespace { @@ -36,23 +37,36 @@ struct ChunkParser { } }; -ChatCmd::ChatCmd(std::string host, int port, const config::ModelConfig& mc) - : host_(std::move(host)), port_(port), mc_(mc) {} +void ChatCmd::Exec(const std::string& host, int port, + const std::string& model_handle, std::string msg) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + Exec(host, port, mc, std::move(msg)); + } catch (const std::exception& e) { + CLI_LOG("Fail to start model information with ID '" + model_handle + + "': " + e.what()); + } +} -void ChatCmd::Exec(std::string msg) { +void ChatCmd::Exec(const std::string& host, int port, + const config::ModelConfig& mc, std::string msg) { + auto address = host + ":" + std::to_string(port); // Check if server is started { - if (!commands::IsServerAlive(host_, port_)) { + if (!commands::IsServerAlive(host, port)) { CLI_LOG("Server is not started yet, please run `" << commands::GetCortexBinary() << " start` to start server!"); return; } } - auto address = host_ + ":" + std::to_string(port_); // Only check if llamacpp engine - if ((mc_.engine.find("llamacpp") != std::string::npos) && - !commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) { + if ((mc.engine.find("llamacpp") != std::string::npos) && + !commands::ModelStatusCmd().IsLoaded(host, port, mc)) { CLI_LOG("Model is not loaded yet!"); return; } @@ -78,12 +92,12 @@ void ChatCmd::Exec(std::string msg) { new_data["role"] = kUser; 
new_data["content"] = user_input; histories_.push_back(std::move(new_data)); - json_data["engine"] = mc_.engine; + json_data["engine"] = mc.engine; json_data["messages"] = histories_; - json_data["model"] = mc_.name; + json_data["model"] = mc.name; //TODO: support non-stream json_data["stream"] = true; - json_data["stop"] = mc_.stop; + json_data["stop"] = mc.stop; auto data_str = json_data.dump(); // std::cout << data_str << std::endl; cli.set_read_timeout(std::chrono::seconds(60)); diff --git a/engine/commands/chat_cmd.h b/engine/commands/chat_cmd.h index d5b48927c..596cfce2d 100644 --- a/engine/commands/chat_cmd.h +++ b/engine/commands/chat_cmd.h @@ -7,13 +7,12 @@ namespace commands { class ChatCmd { public: - ChatCmd(std::string host, int port, const config::ModelConfig& mc); - void Exec(std::string msg); + void Exec(const std::string& host, int port, const std::string& model_handle, + std::string msg); + void Exec(const std::string& host, int port, const config::ModelConfig& mc, + std::string msg); private: - std::string host_; - int port_; - const config::ModelConfig& mc_; std::vector histories_; }; } // namespace commands \ No newline at end of file diff --git a/engine/commands/model_del_cmd.cc b/engine/commands/model_del_cmd.cc index f2023f5c1..7f6b6d32a 100644 --- a/engine/commands/model_del_cmd.cc +++ b/engine/commands/model_del_cmd.cc @@ -2,55 +2,47 @@ #include "cmd_info.h" #include "config/yaml_config.h" #include "utils/file_manager_utils.h" +#include "utils/modellist_utils.h" namespace commands { -bool ModelDelCmd::Exec(const std::string& model_id) { - // TODO this implentation may be changed after we have a decision - // on https://github.com/janhq/cortex.cpp/issues/1154 but the logic should be similar - CmdInfo ci(model_id); - std::string model_file = - ci.branch == "main" ? 
ci.model_name : ci.model_name + "-" + ci.branch; - auto models_path = file_manager_utils::GetModelsContainerPath(); - if (std::filesystem::exists(models_path) && - std::filesystem::is_directory(models_path)) { - // Iterate through directory - for (const auto& entry : std::filesystem::directory_iterator(models_path)) { - if (entry.is_regular_file() && entry.path().extension() == ".yaml") { - try { - config::YamlHandler handler; - handler.ModelConfigFromFile(entry.path().string()); - auto cfg = handler.GetModelConfig(); - if (entry.path().stem().string() == model_file) { - // Delete data - if (cfg.files.size() > 0) { - std::filesystem::path f(cfg.files[0]); - auto rel = std::filesystem::relative(f, models_path); - // Only delete model data if it is stored in our models folder - if (!rel.empty()) { - if (cfg.engine == "cortex.llamacpp") { - std::filesystem::remove_all(f.parent_path()); - } else { - std::filesystem::remove_all(f); - } - } - } +bool ModelDelCmd::Exec(const std::string& model_handle) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; - // Delete yaml file - std::filesystem::remove(entry); - CLI_LOG("The model " << model_id << " was deleted"); - return true; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + // Remove yaml file + std::filesystem::remove(model_entry.path_to_model_yaml); + // Remove model files if they are not imported locally + if (model_entry.branch_name != "imported") { + if (mc.files.size() > 0) { + if (mc.engine == "cortex.llamacpp") { + for (auto& file : mc.files) { + std::filesystem::path gguf_p(file); + std::filesystem::remove(gguf_p); } - } catch (const std::exception& e) { - CTL_WRN("Error reading yaml file '" << entry.path().string() - << "': " << e.what()); - return false; + } else { + std::filesystem::path f(mc.files[0]); + 
std::filesystem::remove_all(f); + } + } else { + CTL_WRN("model config files are empty!"); + } + } + } - - CLI_LOG("Model does not exist: " << model_id); - return false; + // update model.list + if (modellist_handler.DeleteModelEntry(model_handle)) { + CLI_LOG("The model " << model_handle << " was deleted"); + return true; + } else { + CTL_ERR("Could not delete model: " << model_handle); + return false; + } + } catch (const std::exception& e) { + CLI_LOG("Fail to delete model with ID '" + model_handle + "': " + e.what()); + return false; + } } } // namespace commands \ No newline at end of file diff --git a/engine/commands/model_del_cmd.h b/engine/commands/model_del_cmd.h index 0dd41f74e..437564208 100644 --- a/engine/commands/model_del_cmd.h +++ b/engine/commands/model_del_cmd.h @@ -6,6 +6,6 @@ namespace commands { class ModelDelCmd { public: - bool Exec(const std::string& model_id); + bool Exec(const std::string& model_handle); }; } \ No newline at end of file diff --git a/engine/commands/model_import_cmd.cc b/engine/commands/model_import_cmd.cc index 193b2488b..3fb047a9d 100644 --- a/engine/commands/model_import_cmd.cc +++ b/engine/commands/model_import_cmd.cc @@ -1,10 +1,8 @@ #include "model_import_cmd.h" #include -#include #include #include "config/gguf_parser.h" #include "config/yaml_config.h" -#include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" #include "utils/modellist_utils.h" @@ -45,7 +43,7 @@ void ModelImportCmd::Exec() { } } catch (const std::exception& e) { - // don't need to remove yml file here, because it's written only if model entry is successfully added, + // don't need to remove yml file here, because it's written only if model entry is successfully added, // remove file here can make it fail with edge case when user try to import new model with existed model_id CLI_LOG("Error importing model path '" + model_path_ + "' with model_id '" + model_handle_ + "': " + e.what()); diff --git 
a/engine/commands/model_start_cmd.cc b/engine/commands/model_start_cmd.cc index 1a96b4fee..1340614d9 100644 --- a/engine/commands/model_start_cmd.cc +++ b/engine/commands/model_start_cmd.cc @@ -7,43 +7,59 @@ #include "trantor/utils/Logger.h" #include "utils/file_manager_utils.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { -ModelStartCmd::ModelStartCmd(std::string host, int port, - const config::ModelConfig& mc) - : host_(std::move(host)), port_(port), mc_(mc) {} +bool ModelStartCmd::Exec(const std::string& host, int port, + const std::string& model_handle) { -bool ModelStartCmd::Exec() { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + return Exec(host, port, mc); + } catch (const std::exception& e) { + CLI_LOG("Fail to start model information with ID '" + model_handle + + "': " + e.what()); + return false; + } +} + +bool ModelStartCmd::Exec(const std::string& host, int port, + const config::ModelConfig& mc) { // Check if server is started - if (!commands::IsServerAlive(host_, port_)) { + if (!commands::IsServerAlive(host, port)) { CLI_LOG("Server is not started yet, please run `" << commands::GetCortexBinary() << " start` to start server!"); return false; } + // Only check for llamacpp for now - if ((mc_.engine.find("llamacpp") != std::string::npos) && - commands::ModelStatusCmd().IsLoaded(host_, port_, mc_)) { + if ((mc.engine.find("llamacpp") != std::string::npos) && + commands::ModelStatusCmd().IsLoaded(host, port, mc)) { CLI_LOG("Model has already been started!"); return true; } - httplib::Client cli(host_ + ":" + std::to_string(port_)); + httplib::Client cli(host + ":" + std::to_string(port)); nlohmann::json json_data; - if (mc_.files.size() > 0) { + if (mc.files.size() > 0) { // TODO(sang) 
support multiple files - json_data["model_path"] = mc_.files[0]; + json_data["model_path"] = mc.files[0]; } else { LOG_WARN << "model_path is empty"; return false; } - json_data["model"] = mc_.name; - json_data["system_prompt"] = mc_.system_template; - json_data["user_prompt"] = mc_.user_template; - json_data["ai_prompt"] = mc_.ai_template; - json_data["ctx_len"] = mc_.ctx_len; - json_data["stop"] = mc_.stop; - json_data["engine"] = mc_.engine; + json_data["model"] = mc.name; + json_data["system_prompt"] = mc.system_template; + json_data["user_prompt"] = mc.user_template; + json_data["ai_prompt"] = mc.ai_template; + json_data["ctx_len"] = mc.ctx_len; + json_data["stop"] = mc.stop; + json_data["engine"] = mc.engine; auto data_str = json_data.dump(); cli.set_read_timeout(std::chrono::seconds(60)); @@ -52,13 +68,17 @@ bool ModelStartCmd::Exec() { if (res) { if (res->status == httplib::StatusCode::OK_200) { CLI_LOG("Model loaded!"); + return true; + } else { + CTL_ERR("Model failed to load with status code: " << res->status); + return false; } } else { auto err = res.error(); CTL_ERR("HTTP error: " << httplib::to_string(err)); return false; } - return true; + return false; } }; // namespace commands diff --git a/engine/commands/model_start_cmd.h b/engine/commands/model_start_cmd.h index 26daf9d0e..fbf3c0645 100644 --- a/engine/commands/model_start_cmd.h +++ b/engine/commands/model_start_cmd.h @@ -6,13 +6,8 @@ namespace commands { class ModelStartCmd { public: - explicit ModelStartCmd(std::string host, int port, - const config::ModelConfig& mc); - bool Exec(); + bool Exec(const std::string& host, int port, const std::string& model_handle); - private: - std::string host_; - int port_; - const config::ModelConfig& mc_; + bool Exec(const std::string& host, int port, const config::ModelConfig& mc); }; } // namespace commands diff --git a/engine/commands/model_status_cmd.cc b/engine/commands/model_status_cmd.cc index f54aa9100..e6ba9bbe0 100644 --- 
a/engine/commands/model_status_cmd.cc +++ b/engine/commands/model_status_cmd.cc @@ -3,8 +3,25 @@ #include "httplib.h" #include "nlohmann/json.hpp" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace commands { +bool ModelStatusCmd::IsLoaded(const std::string& host, int port, + const std::string& model_handle) { + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; + try { + auto model_entry = modellist_handler.GetModelInfo(model_handle); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + return IsLoaded(host, port, mc); + } catch (const std::exception& e) { + CLI_LOG("Fail to get model status with ID '" + model_handle + + "': " + e.what()); + return false; + } +} + bool ModelStatusCmd::IsLoaded(const std::string& host, int port, const config::ModelConfig& mc) { httplib::Client cli(host + ":" + std::to_string(port)); diff --git a/engine/commands/model_status_cmd.h b/engine/commands/model_status_cmd.h index 2ef44a41d..273d73ef9 100644 --- a/engine/commands/model_status_cmd.h +++ b/engine/commands/model_status_cmd.h @@ -6,6 +6,8 @@ namespace commands { class ModelStatusCmd { public: + bool IsLoaded(const std::string& host, int port, + const std::string& model_handle); bool IsLoaded(const std::string& host, int port, const config::ModelConfig& mc); }; diff --git a/engine/commands/model_upd_cmd.cc b/engine/commands/model_upd_cmd.cc new file mode 100644 index 000000000..65883def3 --- /dev/null +++ b/engine/commands/model_upd_cmd.cc @@ -0,0 +1,300 @@ +#include "model_upd_cmd.h" + +#include "utils/logging_utils.h" + +namespace commands { + +ModelUpdCmd::ModelUpdCmd(std::string model_handle) + : model_handle_(std::move(model_handle)) {} + +void ModelUpdCmd::Exec( + const std::unordered_map& options) { + try { + auto model_entry = model_list_utils_.GetModelInfo(model_handle_); + yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml); + 
model_config_ = yaml_handler_.GetModelConfig(); + + for (const auto& [key, value] : options) { + if (!value.empty()) { + UpdateConfig(key, value); + } + } + + yaml_handler_.UpdateModelConfig(model_config_); + yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml); + CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); + } catch (const std::exception& e) { + CLI_LOG("Failed to update model with model ID '" + model_handle_ + + "': " + e.what()); + } +} + +void ModelUpdCmd::UpdateConfig(const std::string& key, + const std::string& value) { + static const std::unordered_map< + std::string, + std::function> + updaters = { + {"name", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.name = v; + }}, + {"model", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.model = v; + }}, + {"version", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.version = v; + }}, + {"engine", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.engine = v; + }}, + {"prompt_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.prompt_template = v; + }}, + {"system_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.system_template = v; + }}, + {"user_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.user_template = v; + }}, + {"ai_template", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.ai_template = v; + }}, + {"os", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.os = v; + }}, + {"gpu_arch", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.gpu_arch = v; + }}, + {"quantization_method", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { 
+ self->model_config_.quantization_method = v; + }}, + {"precision", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.precision = v; + }}, + {"trtllm_version", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.trtllm_version = v; + }}, + {"object", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.object = v; + }}, + {"owned_by", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.owned_by = v; + }}, + {"grammar", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.grammar = v; + }}, + {"stop", &ModelUpdCmd::UpdateVectorField}, + {"files", &ModelUpdCmd::UpdateVectorField}, + {"top_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.top_p = f; }); + }}, + {"temperature", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.temperature = f; + }); + }}, + {"frequency_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.frequency_penalty = f; + }); + }}, + {"presence_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.presence_penalty = f; + }); + }}, + {"dynatemp_range", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.dynatemp_range = f; + }); + }}, + {"dynatemp_exponent", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.dynatemp_exponent = f; + }); + }}, + {"min_p", + [](ModelUpdCmd* self, const std::string& 
k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.min_p = f; }); + }}, + {"tfs_z", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.tfs_z = f; }); + }}, + {"typ_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.typ_p = f; }); + }}, + {"repeat_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.repeat_penalty = f; + }); + }}, + {"mirostat_tau", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.mirostat_tau = f; + }); + }}, + {"mirostat_eta", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.mirostat_eta = f; + }); + }}, + {"max_tokens", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.max_tokens = static_cast(f); + }); + }}, + {"ngl", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.ngl = static_cast(f); + }); + }}, + {"ctx_len", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.ctx_len = static_cast(f); + }); + }}, + {"tp", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.tp = static_cast(f); + }); + }}, + {"seed", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.seed = 
static_cast(f); + }); + }}, + {"top_k", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.top_k = static_cast(f); + }); + }}, + {"repeat_last_n", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.repeat_last_n = static_cast(f); + }); + }}, + {"n_probs", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.n_probs = static_cast(f); + }); + }}, + {"min_keep", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.min_keep = static_cast(f); + }); + }}, + {"stream", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.stream = b; }); + }}, + {"text_model", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.text_model = b; }); + }}, + {"mirostat", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.mirostat = b; }); + }}, + {"penalize_nl", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.penalize_nl = b; }); + }}, + {"ignore_eos", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.ignore_eos = b; }); + }}, + {"created", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.created = static_cast(f); + }); + }}, + }; + + if (auto it = updaters.find(key); it != updaters.end()) { 
+ it->second(this, key, value); + LogUpdate(key, value); + } else { + CLI_LOG("Warning: Unknown configuration key '" << key << "' ignored."); + } +} + +void ModelUpdCmd::UpdateVectorField(const std::string& key, + const std::string& value) { + std::vector tokens; + std::istringstream iss(value); + std::string token; + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + (key == "files" ? model_config_.files : model_config_.stop) = tokens; +} + +void ModelUpdCmd::UpdateNumericField(const std::string& key, + const std::string& value, + std::function setter) { + try { + float numericValue = std::stof(value); + setter(numericValue); + } catch (const std::exception& e) { + CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what()); + } +} + +void ModelUpdCmd::UpdateBooleanField(const std::string& key, + const std::string& value, + std::function setter) { + bool boolValue = (value == "true" || value == "1"); + setter(boolValue); +} + +void ModelUpdCmd::LogUpdate(const std::string& key, const std::string& value) { + CLI_LOG("Updated " << key << " to: " << value); +} + +} // namespace commands \ No newline at end of file diff --git a/engine/commands/model_upd_cmd.h b/engine/commands/model_upd_cmd.h new file mode 100644 index 000000000..51f5a88d3 --- /dev/null +++ b/engine/commands/model_upd_cmd.h @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include +#include +#include +#include "config/model_config.h" +#include "utils/modellist_utils.h" +#include "config/yaml_config.h" +namespace commands { +class ModelUpdCmd { + public: + ModelUpdCmd(std::string model_handle); + void Exec(const std::unordered_map& options); + + private: + std::string model_handle_; + config::ModelConfig model_config_; + config::YamlHandler yaml_handler_; + modellist_utils::ModelListUtils model_list_utils_; + + void UpdateConfig(const std::string& key, const std::string& value); + void UpdateVectorField(const std::string& key, const std::string& value); + void UpdateNumericField(const std::string& key, 
const std::string& value, + std::function setter); + void UpdateBooleanField(const std::string& key, const std::string& value, + std::function setter); + void LogUpdate(const std::string& key, const std::string& value); +}; +} // namespace commands \ No newline at end of file diff --git a/engine/commands/run_cmd.cc b/engine/commands/run_cmd.cc index 16b496b0d..d17d91e9f 100644 --- a/engine/commands/run_cmd.cc +++ b/engine/commands/run_cmd.cc @@ -5,71 +5,76 @@ #include "model_start_cmd.h" #include "model_status_cmd.h" #include "server_start_cmd.h" +#include "utils/cortex_utils.h" #include "utils/file_manager_utils.h" - +#include "utils/modellist_utils.h" namespace commands { void RunCmd::Exec() { + std::optional model_id = model_handle_; + + modellist_utils::ModelListUtils modellist_handler; + config::YamlHandler yaml_handler; auto address = host_ + ":" + std::to_string(port_); - CmdInfo ci(model_id_); - std::string model_file = - ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch; - // TODO should we clean all resource if something fails? - // Check if model existed. If not, download it - { - auto model_conf = model_service_.GetDownloadedModel(model_file + ".yaml"); - if (!model_conf.has_value()) { - model_service_.DownloadModel(model_id_); - } - } - // Check if engine existed. 
If not, download it + // Download model if it does not exist { - auto required_engine = engine_service_.GetEngineInfo(ci.engine_name); - if (!required_engine.has_value()) { - throw std::runtime_error("Engine not found: " + ci.engine_name); - } - if (required_engine.value().status == EngineService::kIncompatible) { - throw std::runtime_error("Engine " + ci.engine_name + " is incompatible"); - } - if (required_engine.value().status == EngineService::kNotInstalled) { - engine_service_.InstallEngine(ci.engine_name); + if (!modellist_handler.HasModel(model_handle_)) { + model_id = model_service_.DownloadModel(model_handle_); + if (!model_id.has_value()) { + CTL_ERR("Error: Could not get model_id from handle: " << model_handle_); + return; + } else { + CTL_INF("model_id: " << model_id.value()); + } } } - // Start server if it is not running - { - if (!commands::IsServerAlive(host_, port_)) { - CLI_LOG("Starting server ..."); - commands::ServerStartCmd ssc; - if (!ssc.Exec(host_, port_)) { - return; + try { + auto model_entry = modellist_handler.GetModelInfo(*model_id); + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + auto mc = yaml_handler.GetModelConfig(); + + // Check if engine existed. 
If not, download it + { + auto required_engine = engine_service_.GetEngineInfo(mc.engine); + if (!required_engine.has_value()) { + throw std::runtime_error("Engine not found: " + mc.engine); + } + if (required_engine.value().status == EngineService::kIncompatible) { + throw std::runtime_error("Engine " + mc.engine + " is incompatible"); + } + if (required_engine.value().status == EngineService::kNotInstalled) { + engine_service_.InstallEngine(mc.engine); } } - } - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + model_file + - ".yaml"); - auto mc = yaml_handler.GetModelConfig(); + // Start server if it is not running + { + if (!commands::IsServerAlive(host_, port_)) { + CLI_LOG("Starting server ..."); + commands::ServerStartCmd ssc; + if (!ssc.Exec(host_, port_)) { + return; + } + } + } - // Always start model if not llamacpp - // If it is llamacpp, then check model status first - { - if ((mc.engine.find("llamacpp") == std::string::npos) || - !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) { - ModelStartCmd msc(host_, port_, mc); - if (!msc.Exec()) { - return; + // Always start model if not llamacpp + // If it is llamacpp, then check model status first + { + if ((mc.engine.find("llamacpp") == std::string::npos) || + !commands::ModelStatusCmd().IsLoaded(host_, port_, mc)) { + if (!ModelStartCmd().Exec(host_, port_, mc)) { + return; + } } } - } - // Chat - { - ChatCmd cc(host_, port_, mc); - cc.Exec(""); + // Chat + ChatCmd().Exec(host_, port_, mc, ""); + } catch (const std::exception& e) { + CLI_LOG("Fail to run model with ID '" + model_handle_ + "': " + e.what()); } } }; // namespace commands diff --git a/engine/commands/run_cmd.h b/engine/commands/run_cmd.h index c862926a6..136800102 100644 --- a/engine/commands/run_cmd.h +++ b/engine/commands/run_cmd.h @@ -6,10 +6,10 @@ namespace commands { class RunCmd { public: - explicit RunCmd(std::string host, int port, std::string 
model_id) + explicit RunCmd(std::string host, int port, std::string model_handle) : host_{std::move(host)}, port_{port}, - model_id_{std::move(model_id)}, + model_handle_{std::move(model_handle)}, model_service_{ModelService()} {}; void Exec(); @@ -17,7 +17,7 @@ class RunCmd { private: std::string host_; int port_; - std::string model_id_; + std::string model_handle_; ModelService model_service_; EngineService engine_service_; diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 74410db52..a65114ca7 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -58,7 +58,115 @@ struct ModelConfig { int n_probs = 0; int min_keep = 0; std::string grammar; + + void FromJson(const Json::Value& json) { + // do now allow to update ID and model field because it is unique identifier + // if (json.isMember("id")) + // id = json["id"].asString(); + if (json.isMember("name")) + name = json["name"].asString(); + // if (json.isMember("model")) + // model = json["model"].asString(); + if (json.isMember("version")) + version = json["version"].asString(); + if (json.isMember("stop") && json["stop"].isArray()) { + stop.clear(); + for (const auto& s : json["stop"]) { + stop.push_back(s.asString()); + } + } + + if (json.isMember("stream")) + stream = json["stream"].asBool(); + if (json.isMember("top_p")) + top_p = json["top_p"].asFloat(); + if (json.isMember("temperature")) + temperature = json["temperature"].asFloat(); + if (json.isMember("frequency_penalty")) + frequency_penalty = json["frequency_penalty"].asFloat(); + if (json.isMember("presence_penalty")) + presence_penalty = json["presence_penalty"].asFloat(); + if (json.isMember("max_tokens")) + max_tokens = json["max_tokens"].asInt(); + if (json.isMember("seed")) + seed = json["seed"].asInt(); + if (json.isMember("dynatemp_range")) + dynatemp_range = json["dynatemp_range"].asFloat(); + if (json.isMember("dynatemp_exponent")) + dynatemp_exponent = json["dynatemp_exponent"].asFloat(); + 
if (json.isMember("top_k")) + top_k = json["top_k"].asInt(); + if (json.isMember("min_p")) + min_p = json["min_p"].asFloat(); + if (json.isMember("tfs_z")) + tfs_z = json["tfs_z"].asFloat(); + if (json.isMember("typ_p")) + typ_p = json["typ_p"].asFloat(); + if (json.isMember("repeat_last_n")) + repeat_last_n = json["repeat_last_n"].asInt(); + if (json.isMember("repeat_penalty")) + repeat_penalty = json["repeat_penalty"].asFloat(); + if (json.isMember("mirostat")) + mirostat = json["mirostat"].asBool(); + if (json.isMember("mirostat_tau")) + mirostat_tau = json["mirostat_tau"].asFloat(); + if (json.isMember("mirostat_eta")) + mirostat_eta = json["mirostat_eta"].asFloat(); + if (json.isMember("penalize_nl")) + penalize_nl = json["penalize_nl"].asBool(); + if (json.isMember("ignore_eos")) + ignore_eos = json["ignore_eos"].asBool(); + if (json.isMember("n_probs")) + n_probs = json["n_probs"].asInt(); + if (json.isMember("min_keep")) + min_keep = json["min_keep"].asInt(); + if (json.isMember("ngl")) + ngl = json["ngl"].asInt(); + if (json.isMember("ctx_len")) + ctx_len = json["ctx_len"].asInt(); + if (json.isMember("engine")) + engine = json["engine"].asString(); + if (json.isMember("prompt_template")) + prompt_template = json["prompt_template"].asString(); + if (json.isMember("system_template")) + system_template = json["system_template"].asString(); + if (json.isMember("user_template")) + user_template = json["user_template"].asString(); + if (json.isMember("ai_template")) + ai_template = json["ai_template"].asString(); + if (json.isMember("os")) + os = json["os"].asString(); + if (json.isMember("gpu_arch")) + gpu_arch = json["gpu_arch"].asString(); + if (json.isMember("quantization_method")) + quantization_method = json["quantization_method"].asString(); + if (json.isMember("precision")) + precision = json["precision"].asString(); + + if (json.isMember("files") && json["files"].isArray()) { + files.clear(); + for (const auto& file : json["files"]) { + 
files.push_back(file.asString()); + } + } + + if (json.isMember("created")) + created = json["created"].asUInt64(); + if (json.isMember("object")) + object = json["object"].asString(); + if (json.isMember("owned_by")) + owned_by = json["owned_by"].asString(); + if (json.isMember("text_model")) + text_model = json["text_model"].asBool(); + + if (engine == "cortex.tensorrt-llm") { + if (json.isMember("trtllm_version")) + trtllm_version = json["trtllm_version"].asString(); + if (json.isMember("tp")) + tp = json["tp"].asInt(); + } + } Json::Value ToJson() const { Json::Value obj; diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index d64104197..74155a316 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -14,6 +14,7 @@ #include "commands/model_pull_cmd.h" #include "commands/model_start_cmd.h" #include "commands/model_stop_cmd.h" +#include "commands/model_upd_cmd.h" #include "commands/run_cmd.h" #include "commands/server_start_cmd.h" #include "commands/server_stop_cmd.h" @@ -131,17 +132,10 @@ void CommandLineParser::SetupCommonCommands() { CLI_LOG(chat_cmd->help()); return; } - commands::CmdInfo ci(cml_data_.model_id); - std::string model_file = - ci.branch == "main" ? 
ci.model_name : ci.model_name + "-" + ci.branch; - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + - model_file + ".yaml"); - commands::ChatCmd cc(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - yaml_handler.GetModelConfig()); - cc.Exec(cml_data_.msg); + + commands::ChatCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), cml_data_.model_id, + cml_data_.msg); }); } @@ -177,17 +171,9 @@ void CommandLineParser::SetupModelCommands() { CLI_LOG(model_start_cmd->help()); return; }; - commands::CmdInfo ci(cml_data_.model_id); - std::string model_file = - ci.branch == "main" ? ci.model_name : ci.model_name + "-" + ci.branch; - config::YamlHandler yaml_handler; - yaml_handler.ModelConfigFromFile( - file_manager_utils::GetModelsContainerPath().string() + "/" + - model_file + ".yaml"); - commands::ModelStartCmd msc(cml_data_.config.apiServerHost, - std::stoi(cml_data_.config.apiServerPort), - yaml_handler.GetModelConfig()); - msc.Exec(); + commands::ModelStartCmd().Exec(cml_data_.config.apiServerHost, + std::stoi(cml_data_.config.apiServerPort), + cml_data_.model_id); }); auto stop_model_cmd = @@ -271,10 +257,8 @@ void CommandLineParser::SetupModelCommands() { commands::ModelAliasCmd mdc; mdc.Exec(cml_data_.model_id, cml_data_.model_alias); }); - - auto model_update_cmd = - models_cmd->add_subcommand("update", "Update configuration of a model"); - model_update_cmd->group(kSubcommands); + // Model update parameters comment + ModelUpdate(models_cmd); std::string model_path; auto model_import_cmd = models_cmd->add_subcommand( @@ -463,3 +447,69 @@ void CommandLineParser::EngineGet(CLI::App* parent) { [engine_name] { commands::EngineGetCmd().Exec(engine_name); }); } } + +void CommandLineParser::ModelUpdate(CLI::App* parent) { + auto model_update_cmd = + parent->add_subcommand("update", "Update configuration of a model"); + 
model_update_cmd->group(kSubcommands); + model_update_cmd->add_option("--model_id", cml_data_.model_id, "Model ID") + ->required(); + + // Add options dynamically + std::vector option_names = {"name", + "model", + "version", + "stop", + "top_p", + "temperature", + "frequency_penalty", + "presence_penalty", + "max_tokens", + "stream", + "ngl", + "ctx_len", + "engine", + "prompt_template", + "system_template", + "user_template", + "ai_template", + "os", + "gpu_arch", + "quantization_method", + "precision", + "tp", + "trtllm_version", + "text_model", + "files", + "created", + "object", + "owned_by", + "seed", + "dynatemp_range", + "dynatemp_exponent", + "top_k", + "min_p", + "tfs_z", + "typ_p", + "repeat_last_n", + "repeat_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_nl", + "ignore_eos", + "n_probs", + "min_keep", + "grammar"}; + + for (const auto& option_name : option_names) { + model_update_cmd->add_option("--" + option_name, + cml_data_.model_update_options[option_name], + option_name); + } + + model_update_cmd->callback([this]() { + commands::ModelUpdCmd command(cml_data_.model_id); + command.Exec(cml_data_.model_update_options); + }); +} \ No newline at end of file diff --git a/engine/controllers/command_line_parser.h b/engine/controllers/command_line_parser.h index 87a8063fd..aaa24e064 100644 --- a/engine/controllers/command_line_parser.h +++ b/engine/controllers/command_line_parser.h @@ -1,9 +1,9 @@ #pragma once #include "CLI/CLI.hpp" +#include "commands/model_upd_cmd.h" #include "services/engine_service.h" #include "utils/config_yaml_utils.h" - class CommandLineParser { public: CommandLineParser(); @@ -11,13 +11,13 @@ class CommandLineParser { private: void SetupCommonCommands(); - + void SetupInferenceCommands(); - + void SetupModelCommands(); - + void SetupEngineCommands(); - + void SetupSystemCommands(); void EngineInstall(CLI::App* parent, const std::string& engine_name, @@ -26,10 +26,11 @@ class CommandLineParser { void 
EngineUninstall(CLI::App* parent, const std::string& engine_name); void EngineGet(CLI::App* parent); + void ModelUpdate(CLI::App* parent); CLI::App app_; EngineService engine_service_; - struct CmlData{ + struct CmlData { std::string model_id; std::string msg; std::string model_alias; @@ -40,6 +41,7 @@ class CommandLineParser { bool check_upd = true; int port; config_yaml_utils::CortexConfig config; + std::unordered_map model_update_options; }; CmlData cml_data_; }; diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index e857d89da..4660b50e5 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -114,7 +114,7 @@ void Models::GetModel( auto model_config = yaml_handler.GetModelConfig(); Json::Value obj = model_config.ToJson(); - + data.append(std::move(obj)); ret["data"] = data; ret["result"] = "OK"; @@ -155,7 +155,49 @@ void Models::DeleteModel(const HttpRequestPtr& req, callback(resp); } } +void Models::UpdateModel( + const HttpRequestPtr& req, + std::function&& callback) const { + if (!http_util::HasFieldInReq(req, callback, "modelId")) { + return; + } + auto model_id = (*(req->getJsonObject())).get("modelId", "").asString(); + auto json_body = *(req->getJsonObject()); + try { + modellist_utils::ModelListUtils model_list_utils; + auto model_entry = model_list_utils.GetModelInfo(model_id); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + config::ModelConfig model_config = yaml_handler.GetModelConfig(); + model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(model_entry.path_to_model_yaml); + std::string message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + LOG_INFO << message; + Json::Value ret; + ret["result"] = "Updated successfully!"; + ret["modelHandle"] = model_id; + ret["message"] = message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + 
resp->setStatusCode(k200OK); + callback(resp); + + } catch (const std::exception& e) { + std::string error_message = + "Error updating with model_id '" + model_id + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Update failed!"; + ret["modelHandle"] = model_id; + ret["message"] = error_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} void Models::ImportModel( const HttpRequestPtr& req, std::function&& callback) const { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 4ae1ff41f..8d652c86a 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -15,6 +15,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::ListModel, "/list", Get); METHOD_ADD(Models::GetModel, "/get", Post); + METHOD_ADD(Models::UpdateModel, "/update/", Post); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); METHOD_ADD(Models::SetModelAlias, "/alias", Post); @@ -26,8 +27,11 @@ class Models : public drogon::HttpController { std::function&& callback) const; void GetModel(const HttpRequestPtr& req, std::function&& callback) const; - void ImportModel(const HttpRequestPtr& req, - std::function&& callback) const; + void UpdateModel(const HttpRequestPtr& req, + std::function&& callback) const; + void ImportModel( + const HttpRequestPtr& req, + std::function&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function&& callback, const std::string& model_id) const; diff --git a/engine/main.cc b/engine/main.cc index e7fe9bd22..c461342c9 100644 --- a/engine/main.cc +++ b/engine/main.cc @@ -29,8 +29,8 @@ void RunServer() { auto config = file_manager_utils::GetCortexConfig(); - LOG_INFO << "Host: " << config.apiServerHost - << " Port: " << config.apiServerPort << "\n"; + std::cout << "Host: " << 
config.apiServerHost + << " Port: " << config.apiServerPort << "\n"; // Create logs/ folder and setup log to file std::filesystem::create_directories( @@ -46,6 +46,8 @@ void RunServer() { asyncFileLogger.output_(msg, len); }, [&]() { asyncFileLogger.flush(); }); + LOG_INFO << "Host: " << config.apiServerHost + << " Port: " << config.apiServerPort << "\n"; // Number of cortex.cpp threads // if (argc > 1) { // thread_num = std::atoi(argv[1]); diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 471a70013..e3754fa76 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -133,7 +133,7 @@ void DownloadService::Download(const std::string& download_id, << " need to be downloaded."); std::cout << "Continue download [Y/n]: " << std::flush; std::string answer{""}; - std::cin >> answer; + std::getline(std::cin, answer); if (answer == "Y" || answer == "y" || answer.empty()) { mode = "ab"; CLI_LOG("Resuming download.."); @@ -146,7 +146,7 @@ void DownloadService::Download(const std::string& download_id, std::cout << "Re-download? 
[Y/n]: " << std::flush; std::string answer = ""; - std::cin >> answer; + std::getline(std::cin, answer); if (answer == "Y" || answer == "y" || answer.empty()) { CLI_LOG("Re-downloading.."); } else { diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 29575dfab..485bec869 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -2,15 +2,18 @@ #include #include #include +#include "config/gguf_parser.h" +#include "config/yaml_config.h" #include "utils/cli_selection_utils.h" #include "utils/cortexso_parser.h" #include "utils/file_manager_utils.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" -#include "utils/model_callback_utils.h" +#include "utils/modellist_utils.h" #include "utils/string_utils.h" -void ModelService::DownloadModel(const std::string& input) { +std::optional ModelService::DownloadModel( + const std::string& input) { if (input.empty()) { throw std::runtime_error( "Input must be Cortex Model Hub handle or HuggingFace url!"); @@ -32,15 +35,15 @@ void ModelService::DownloadModel(const std::string& input) { return DownloadModelByModelName(model_name); } - DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); CLI_LOG("Model " << model_name << " downloaded successfully!") - return; + return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); } return DownloadModelByModelName(input); } -void ModelService::DownloadModelByModelName(const std::string& modelName) { +std::optional ModelService::DownloadModelByModelName( + const std::string& modelName) { try { auto branches = huggingface_utils::GetModelRepositoryBranches("cortexso", modelName); @@ -52,12 +55,13 @@ void ModelService::DownloadModelByModelName(const std::string& modelName) { } if (options.empty()) { CLI_LOG("No variant found"); - return; + return std::nullopt; } auto selection = cli_selection_utils::PrintSelection(options); - DownloadModelFromCortexso(modelName, selection.value()); + return 
DownloadModelFromCortexso(modelName, selection.value()); } catch (const std::runtime_error& e) { CLI_LOG("Error downloading model, " << e.what()); + return std::nullopt; } } @@ -87,7 +91,8 @@ std::optional ModelService::GetDownloadedModel( return std::nullopt; } -void ModelService::DownloadModelByDirectUrl(const std::string& url) { +std::optional ModelService::DownloadModelByDirectUrl( + const std::string& url) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -95,12 +100,19 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { url_obj.pathParams[2] = "resolve"; } } - + auto author{url_obj.pathParams[0]}; auto model_id{url_obj.pathParams[1]}; auto file_name{url_obj.pathParams.back()}; - auto local_path = - file_manager_utils::GetModelsContainerPath() / model_id / model_id; + if (author == "cortexso") { + return DownloadModelFromCortexso(model_id); + } + + std::string huggingFaceHost{kHuggingFaceHost}; + std::string unique_model_id{huggingFaceHost + "/" + author + "/" + model_id + + "/" + file_name}; + auto local_path{file_manager_utils::GetModelsContainerPath() / + "huggingface.co" / author / model_id / file_name}; try { std::filesystem::create_directories(local_path.parent_path()); @@ -115,33 +127,68 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { auto downloadTask{DownloadTask{.id = model_id, .type = DownloadType::Model, .items = {DownloadItem{ - .id = url_obj.pathParams.back(), + .id = unique_model_id, .downloadUrl = download_url, .localPath = local_path, }}}}; - auto on_finished = [](const DownloadTask& finishedTask) { + auto on_finished = [&](const DownloadTask& finishedTask) { CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; - model_callback_utils::ParseGguf(gguf_download_item); + ParseGguf(gguf_download_item, author); }; download_service_.AddDownloadTask(downloadTask, on_finished); + return unique_model_id; 
} -void ModelService::DownloadModelFromCortexso(const std::string& name, - const std::string& branch) { +std::optional ModelService::DownloadModelFromCortexso( + const std::string& name, const std::string& branch) { + auto downloadTask = cortexso_parser::getDownloadTask(name, branch); if (downloadTask.has_value()) { - DownloadService().AddDownloadTask(downloadTask.value(), - model_callback_utils::DownloadModelCb); - CLI_LOG("Model " << name << " downloaded successfully!") + std::string model_id{name + ":" + branch}; + DownloadService().AddDownloadTask( + downloadTask.value(), [&](const DownloadTask& finishedTask) { + const DownloadItem* model_yml_item = nullptr; + auto need_parse_gguf = true; + + for (const auto& item : finishedTask.items) { + if (item.localPath.filename().string() == "model.yml") { + model_yml_item = &item; + } + } + + if (model_yml_item != nullptr) { + auto url_obj = + url_parser::FromUrlString(model_yml_item->downloadUrl); + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile( + model_yml_item->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = model_yml_item->localPath.string(), + .model_alias = model_id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } + }); + + CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; } else { CTL_ERR("Model not found"); + return std::nullopt; } } -void ModelService::DownloadHuggingFaceGgufModel( +std::optional ModelService::DownloadHuggingFaceGgufModel( const std::string& author, const std::string& modelName, std::optional fileName) { auto repo_info = @@ -149,7 +196,7 @@ void ModelService::DownloadHuggingFaceGgufModel( if 
(!repo_info.has_value()) { // throw is better? CTL_ERR("Model not found"); - return; + return std::nullopt; } if (!repo_info->gguf.has_value()) { @@ -168,5 +215,40 @@ void ModelService::DownloadHuggingFaceGgufModel( auto download_url = huggingface_utils::GetDownloadableUrl(author, modelName, selection.value()); - DownloadModelByDirectUrl(download_url); + return DownloadModelByDirectUrl(download_url); +} + +void ModelService::ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author) const { + + config::GGUFHandler gguf_handler; + config::YamlHandler yaml_handler; + gguf_handler.Parse(ggufDownloadItem.localPath.string()); + config::ModelConfig model_config = gguf_handler.GetModelConfig(); + model_config.id = + ggufDownloadItem.localPath.parent_path().filename().string(); + model_config.files = {ggufDownloadItem.localPath.string()}; + yaml_handler.UpdateModelConfig(model_config); + + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + + if (!std::filesystem::exists(yaml_path)) { + yaml_handler.WriteYamlFile(yaml_path.string()); + } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? 
author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = ggufDownloadItem.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry, true); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 06212aaee..4237f1b17 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -8,27 +8,34 @@ class ModelService { public: ModelService() : download_service_{DownloadService()} {}; - void DownloadModel(const std::string& input); + /** + * Return model id if download successfully + */ + std::optional DownloadModel(const std::string& input); std::optional GetDownloadedModel( const std::string& modelId) const; private: - void DownloadModelByDirectUrl(const std::string& url); + std::optional DownloadModelByDirectUrl(const std::string& url); - void DownloadModelFromCortexso(const std::string& name, - const std::string& branch = "main"); + std::optional DownloadModelFromCortexso( + const std::string& name, const std::string& branch = "main"); /** * Handle downloading model which have following pattern: author/model_name */ - void DownloadHuggingFaceGgufModel(const std::string& author, - const std::string& modelName, - std::optional fileName); + std::optional DownloadHuggingFaceGgufModel( + const std::string& author, const std::string& modelName, + std::optional fileName); - void DownloadModelByModelName(const std::string& modelName); + std::optional DownloadModelByModelName( + const std::string& modelName); DownloadService download_service_; + void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) const; + constexpr auto static kHuggingFaceHost = "huggingface.co"; }; diff --git a/engine/test/components/test_modellist_utils.cc 
b/engine/test/components/test_modellist_utils.cc index 2a7abc05a..d1dbf91e3 100644 --- a/engine/test/components/test_modellist_utils.cc +++ b/engine/test/components/test_modellist_utils.cc @@ -19,6 +19,7 @@ class ModelListUtilsTestSuite : public ::testing::Test { void TearDown() { // Clean up the temporary directory + std::remove((file_manager_utils::GetModelsContainerPath() / "model.list").string().c_str()); } TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel)); @@ -120,4 +121,14 @@ TEST_F(ModelListUtilsTestSuite, TestUpdateModelAlias) { // Clean up model_list_.DeleteModelEntry("test_model_id"); model_list_.DeleteModelEntry("another_model_id"); +} + +TEST_F(ModelListUtilsTestSuite, TestHasModel) { + model_list_.AddModelEntry(kTestModel); + + EXPECT_TRUE(model_list_.HasModel("test_model_id")); + EXPECT_TRUE(model_list_.HasModel("test_alias")); + EXPECT_FALSE(model_list_.HasModel("non_existent_model")); + // Clean up + model_list_.DeleteModelEntry("test_model_id"); } \ No newline at end of file diff --git a/engine/utils/cli_selection_utils.h b/engine/utils/cli_selection_utils.h index d3848c5bb..0c2453478 100644 --- a/engine/utils/cli_selection_utils.h +++ b/engine/utils/cli_selection_utils.h @@ -20,7 +20,7 @@ inline std::optional PrintSelection( std::string selection{""}; PrintMenu(options); std::cout << "Select an option (" << 1 << "-" << options.size() << "): "; - std::cin >> selection; + std::getline(std::cin, selection); if (selection.empty()) { return std::nullopt; diff --git a/engine/utils/cortexso_parser.h b/engine/utils/cortexso_parser.h index d4e85bee9..af3372022 100644 --- a/engine/utils/cortexso_parser.h +++ b/engine/utils/cortexso_parser.h @@ -1,5 +1,4 @@ #include -#include #include #include @@ -7,57 +6,57 @@ #include #include "httplib.h" #include "utils/file_manager_utils.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" namespace cortexso_parser { -constexpr static auto 
kHuggingFaceHost = "https://huggingface.co"; +constexpr static auto kHuggingFaceHost = "huggingface.co"; inline std::optional getDownloadTask( const std::string& modelId, const std::string& branch = "main") { using namespace nlohmann; - std::ostringstream oss; - oss << "/api/models/cortexso/" << modelId << "/tree/" << branch; - const std::string url = oss.str(); + url_parser::Url url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"api", "models", "cortexso", modelId, "tree", branch}}; - std::ostringstream repoAndModelId; - repoAndModelId << "cortexso/" << modelId; - const std::string repoAndModelIdStr = repoAndModelId.str(); - - httplib::Client cli(kHuggingFaceHost); - if (auto res = cli.Get(url)) { + httplib::Client cli(url.GetProtocolAndHost()); + if (auto res = cli.Get(url.GetPathAndQuery())) { if (res->status == httplib::StatusCode::OK_200) { try { auto jsonResponse = json::parse(res->body); - std::vector downloadItems{}; - std::filesystem::path model_container_path = - file_manager_utils::GetModelsContainerPath() / modelId; + std::vector download_items{}; + auto model_container_path = + file_manager_utils::GetModelsContainerPath() / "cortex.so" / + modelId / branch; file_manager_utils::CreateDirectoryRecursively( model_container_path.string()); for (const auto& [key, value] : jsonResponse.items()) { - std::ostringstream downloadUrlOutput; auto path = value["path"].get(); if (path == ".gitattributes" || path == ".gitignore" || path == "README.md") { continue; } - downloadUrlOutput << kHuggingFaceHost << "/" << repoAndModelIdStr - << "/resolve/" << branch << "/" << path; - const std::string download_url = downloadUrlOutput.str(); - auto local_path = model_container_path / path; + url_parser::Url download_url = { + .protocol = "https", + .host = kHuggingFaceHost, + .pathParams = {"cortexso", modelId, "resolve", branch, path}}; - downloadItems.push_back(DownloadItem{.id = path, - .downloadUrl = download_url, - .localPath = local_path}); + 
auto local_path = model_container_path / path; + download_items.push_back( + DownloadItem{.id = path, + .downloadUrl = download_url.ToFullPath(), + .localPath = local_path}); } - DownloadTask downloadTask{ + DownloadTask download_tasks{ .id = branch == "main" ? modelId : modelId + "-" + branch, .type = DownloadType::Model, - .items = downloadItems}; + .items = download_items}; - return downloadTask; + return download_tasks; } catch (const json::parse_error& e) { CTL_ERR("JSON parse error: {}" << e.what()); } diff --git a/engine/utils/model_callback_utils.h b/engine/utils/model_callback_utils.h index 3a3b0f288..c6e98dd48 100644 --- a/engine/utils/model_callback_utils.h +++ b/engine/utils/model_callback_utils.h @@ -6,27 +6,14 @@ #include "config/gguf_parser.h" #include "config/yaml_config.h" #include "services/download_service.h" +#include "utils/huggingface_utils.h" #include "utils/logging_utils.h" +#include "utils/modellist_utils.h" namespace model_callback_utils { -inline void WriteYamlOutput(const DownloadItem& modelYmlDownloadItem) { - config::YamlHandler handler; - handler.ModelConfigFromFile(modelYmlDownloadItem.localPath.string()); - config::ModelConfig model_config = handler.GetModelConfig(); - model_config.id = - modelYmlDownloadItem.localPath.parent_path().filename().string(); - - CTL_INF("Updating model config in " - << modelYmlDownloadItem.localPath.string()); - handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - modelYmlDownloadItem.localPath.parent_path().parent_path() / - yaml_filename; - handler.WriteYamlFile(yaml_output.string()); -} -inline void ParseGguf(const DownloadItem& ggufDownloadItem) { +inline void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) { config::GGUFHandler gguf_handler; config::YamlHandler yaml_handler; gguf_handler.Parse(ggufDownloadItem.localPath.string()); @@ -36,17 +23,27 @@ inline void ParseGguf(const 
DownloadItem& ggufDownloadItem) { model_config.files = {ggufDownloadItem.localPath.string()}; yaml_handler.UpdateModelConfig(model_config); - std::string yaml_filename{model_config.id + ".yaml"}; - std::filesystem::path yaml_output = - ggufDownloadItem.localPath.parent_path().parent_path() / yaml_filename; - std::filesystem::path yaml_path(ggufDownloadItem.localPath.parent_path() / - "model.yml"); - if (!std::filesystem::exists(yaml_output)) { // if model.yml doesn't exist - yaml_handler.WriteYamlFile(yaml_output.string()); - } + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + if (!std::filesystem::exists(yaml_path)) { yaml_handler.WriteYamlFile(yaml_path.string()); } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_config.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = model_config.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); } inline void DownloadModelCb(const DownloadTask& finishedTask) { @@ -67,12 +64,27 @@ inline void DownloadModelCb(const DownloadTask& finishedTask) { } } - if (model_yml_di != nullptr) { - WriteYamlOutput(*model_yml_di); - } - if (need_parse_gguf && gguf_di != nullptr) { ParseGguf(*gguf_di); } + + if (model_yml_di != nullptr) { + auto url_obj = url_parser::FromUrlString(model_yml_di->downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_yml_di->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + 
+ modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = mc.name, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = model_yml_di->localPath.string(), + .model_alias = mc.name, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } } } // namespace model_callback_utils diff --git a/engine/utils/modellist_utils.cc b/engine/utils/modellist_utils.cc index 261bf58d5..d577519f3 100644 --- a/engine/utils/modellist_utils.cc +++ b/engine/utils/modellist_utils.cc @@ -3,10 +3,10 @@ #include #include #include -#include #include #include #include "file_manager_utils.h" + namespace modellist_utils { const std::string ModelListUtils::kModelListPath = (file_manager_utils::GetModelsContainerPath() / @@ -208,7 +208,8 @@ bool ModelListUtils::UpdateModelAlias(const std::string& model_id, }); bool check_alias_unique = std::none_of( entries.begin(), entries.end(), [&](const ModelEntry& entry) { - return (entry.model_id == new_model_alias && entry.model_id != model_id) || + return (entry.model_id == new_model_alias && + entry.model_id != model_id) || entry.model_alias == new_model_alias; }); if (it != entries.end() && check_alias_unique) { @@ -237,4 +238,19 @@ bool ModelListUtils::DeleteModelEntry(const std::string& identifier) { } return false; // Entry not found or not in READY state } -} // namespace modellist_utils \ No newline at end of file + +bool ModelListUtils::HasModel(const std::string& identifier) const { + std::lock_guard lock(mutex_); + auto entries = LoadModelList(); + auto it = std::find_if( + entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) { + return entry.model_id == identifier || entry.model_alias == identifier; + }); + + if (it != entries.end()) { + return true; + } else { + return false; + } +} +} // namespace modellist_utils diff --git a/engine/utils/modellist_utils.h b/engine/utils/modellist_utils.h index 
75a41d880..113591f25 100644 --- a/engine/utils/modellist_utils.h +++ b/engine/utils/modellist_utils.h @@ -1,9 +1,10 @@ #pragma once + #include #include #include #include -#include "logging_utils.h" + namespace modellist_utils { enum class ModelStatus { READY, RUNNING }; @@ -22,7 +23,7 @@ class ModelListUtils { private: mutable std::mutex mutex_; // For thread safety - bool IsUnique(const std::vector& entries, + bool IsUnique(const std::vector& entries, const std::string& model_id, const std::string& model_alias) const; void SaveModelList(const std::vector& entries) const; @@ -40,6 +41,8 @@ class ModelListUtils { bool UpdateModelEntry(const std::string& identifier, const ModelEntry& updated_entry); bool DeleteModelEntry(const std::string& identifier); - bool UpdateModelAlias(const std::string& model_id, const std::string& model_alias); + bool UpdateModelAlias(const std::string& model_id, + const std::string& model_alias); + bool HasModel(const std::string& identifier) const; }; -} // namespace modellist_utils \ No newline at end of file +} // namespace modellist_utils diff --git a/engine/utils/url_parser.h b/engine/utils/url_parser.h index 6a6e01179..90b62143e 100644 --- a/engine/utils/url_parser.h +++ b/engine/utils/url_parser.h @@ -1,3 +1,5 @@ +#pragma once + #include #include #include @@ -54,6 +56,10 @@ struct Url { } return path; }; + + std::string ToFullPath() const { + return GetProtocolAndHost() + GetPathAndQuery(); + } }; const std::regex url_regex(