diff --git a/engine/commands/model_upd_cmd.cc b/engine/commands/model_upd_cmd.cc new file mode 100644 index 000000000..eb7edd3df --- /dev/null +++ b/engine/commands/model_upd_cmd.cc @@ -0,0 +1,127 @@ +#include "model_upd_cmd.h" + +#include "utils/logging_utils.h" + +namespace commands { + +ModelUpdCmd::ModelUpdCmd(std::string model_handle) + : model_handle_(std::move(model_handle)) {} + +void ModelUpdCmd::Exec( + const std::unordered_map& options) { + try { + auto model_entry = model_list_utils_.GetModelInfo(model_handle_); + yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml); + model_config_ = yaml_handler_.GetModelConfig(); + + for (const auto& [key, value] : options) { + if (!value.empty()) { + UpdateConfig(key, value); + } + } + + yaml_handler_.UpdateModelConfig(model_config_); + yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml); + CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!"); + } catch (const std::exception& e) { + CLI_LOG("Failed to update model with model ID '" + model_handle_ + + "': " + e.what()); + } +} + +void ModelUpdCmd::UpdateConfig(const std::string& key, + const std::string& value) { + static const std::unordered_map< + std::string, + std::function> + updaters = { + {"name", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.name = v; + }}, + {"model", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.model = v; + }}, + {"version", + [](ModelUpdCmd* self, const std::string&, const std::string& v) { + self->model_config_.version = v; + }}, + {"stop", &ModelUpdCmd::UpdateVectorField}, + {"top_p", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField( + k, v, [self](float f) { self->model_config_.top_p = f; }); + }}, + {"temperature", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.temperature = f; + }); + }}, + {"frequency_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.frequency_penalty = f; + }); + }}, + {"presence_penalty", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.presence_penalty = f; + }); + }}, + {"max_tokens", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateNumericField(k, v, [self](float f) { + self->model_config_.max_tokens = static_cast(f); + }); + }}, + {"stream", + [](ModelUpdCmd* self, const std::string& k, const std::string& v) { + self->UpdateBooleanField( + k, v, [self](bool b) { self->model_config_.stream = b; }); + }}, + // Add more fields here... + }; + + if (auto it = updaters.find(key); it != updaters.end()) { + it->second(this, key, value); + LogUpdate(key, value); + } +} + +void ModelUpdCmd::UpdateVectorField(const std::string& key, + const std::string& value) { + std::vector tokens; + std::istringstream iss(value); + std::string token; + while (std::getline(iss, token, ',')) { + tokens.push_back(token); + } + model_config_.stop = tokens; +} + +void ModelUpdCmd::UpdateNumericField(const std::string& key, + const std::string& value, + std::function setter) { + try { + float numericValue = std::stof(value); + setter(numericValue); + } catch (const std::exception& e) { + CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what()); + } +} + +void ModelUpdCmd::UpdateBooleanField(const std::string& key, + const std::string& value, + std::function setter) { + bool boolValue = (value == "true" || value == "1"); + setter(boolValue); +} + +void ModelUpdCmd::LogUpdate(const std::string& key, const std::string& value) { + CLI_LOG("Updated " << key << " to: " << value); +} + +} // namespace commands \ No newline at end of file diff --git a/engine/commands/model_upd_cmd.h b/engine/commands/model_upd_cmd.h new file mode 100644 index 000000000..51f5a88d3 --- /dev/null +++ b/engine/commands/model_upd_cmd.h @@ -0,0 +1,30 @@ +#pragma once +#include +#include +#include +#include +#include +#include "config/model_config.h" +#include "utils/modellist_utils.h" +#include "config/yaml_config.h" +namespace commands { +class ModelUpdCmd { + public: + ModelUpdCmd(std::string model_handle); + void Exec(const std::unordered_map& options); + + private: + std::string model_handle_; + config::ModelConfig model_config_; + config::YamlHandler yaml_handler_; + modellist_utils::ModelListUtils model_list_utils_; + + void UpdateConfig(const std::string& key, const std::string& value); + void UpdateVectorField(const std::string& key, const std::string& value); + void UpdateNumericField(const std::string& key, const std::string& value, + std::function setter); + void UpdateBooleanField(const std::string& key, const std::string& value, + std::function setter); + void LogUpdate(const std::string& key, const std::string& value); +}; +} // namespace commands \ No newline at end of file diff --git a/engine/config/model_config.h b/engine/config/model_config.h index 74410db52..a65114ca7 100644 --- a/engine/config/model_config.h +++ b/engine/config/model_config.h @@ -58,7 +58,115 @@ struct ModelConfig { int n_probs = 0; int min_keep = 0; std::string grammar; + + void FromJson(const Json::Value& json) { + // do now allow to update ID and model field because it is unique identifier + // if (json.isMember("id")) + // id = json["id"].asString(); + if (json.isMember("name")) + name = json["name"].asString(); + // if (json.isMember("model")) + // model = json["model"].asString(); + if (json.isMember("version")) + version = json["version"].asString(); + if (json.isMember("stop") && json["stop"].isArray()) { + stop.clear(); + for (const auto& s : json["stop"]) { + stop.push_back(s.asString()); + } + } + + if (json.isMember("stream")) + stream = json["stream"].asBool(); + if (json.isMember("top_p")) + top_p = json["top_p"].asFloat(); + if (json.isMember("temperature")) + temperature = json["temperature"].asFloat(); + if (json.isMember("frequency_penalty")) + frequency_penalty = json["frequency_penalty"].asFloat(); + if (json.isMember("presence_penalty")) + presence_penalty = json["presence_penalty"].asFloat(); + if (json.isMember("max_tokens")) + max_tokens = json["max_tokens"].asInt(); + if (json.isMember("seed")) + seed = json["seed"].asInt(); + if (json.isMember("dynatemp_range")) + dynatemp_range = json["dynatemp_range"].asFloat(); + if (json.isMember("dynatemp_exponent")) + dynatemp_exponent = json["dynatemp_exponent"].asFloat(); + if (json.isMember("top_k")) + top_k = json["top_k"].asInt(); + if (json.isMember("min_p")) + min_p = json["min_p"].asFloat(); + if (json.isMember("tfs_z")) + tfs_z = json["tfs_z"].asFloat(); + if (json.isMember("typ_p")) + typ_p = json["typ_p"].asFloat(); + if (json.isMember("repeat_last_n")) + repeat_last_n = json["repeat_last_n"].asInt(); + if (json.isMember("repeat_penalty")) + repeat_penalty = json["repeat_penalty"].asFloat(); + if (json.isMember("mirostat")) + mirostat = json["mirostat"].asBool(); + if (json.isMember("mirostat_tau")) + mirostat_tau = json["mirostat_tau"].asFloat(); + if (json.isMember("mirostat_eta")) + mirostat_eta = json["mirostat_eta"].asFloat(); + if (json.isMember("penalize_nl")) + penalize_nl = json["penalize_nl"].asBool(); + if (json.isMember("ignore_eos")) + ignore_eos = json["ignore_eos"].asBool(); + if (json.isMember("n_probs")) + n_probs = json["n_probs"].asInt(); + if (json.isMember("min_keep")) + min_keep = json["min_keep"].asInt(); + if (json.isMember("ngl")) + ngl = json["ngl"].asInt(); + if (json.isMember("ctx_len")) + ctx_len = json["ctx_len"].asInt(); + if (json.isMember("engine")) + engine = json["engine"].asString(); + if (json.isMember("prompt_template")) + prompt_template = json["prompt_template"].asString(); + if (json.isMember("system_template")) + system_template = json["system_template"].asString(); + if (json.isMember("user_template")) + user_template = json["user_template"].asString(); + if (json.isMember("ai_template")) + ai_template = json["ai_template"].asString(); + if (json.isMember("os")) + os = json["os"].asString(); + if (json.isMember("gpu_arch")) + gpu_arch = json["gpu_arch"].asString(); + if (json.isMember("quantization_method")) + quantization_method = json["quantization_method"].asString(); + if (json.isMember("precision")) + precision = json["precision"].asString(); + + if (json.isMember("files") && json["files"].isArray()) { + files.clear(); + for (const auto& file : json["files"]) { + files.push_back(file.asString()); + } + } + + if (json.isMember("created")) + created = json["created"].asUInt64(); + if (json.isMember("object")) + object = json["object"].asString(); + if (json.isMember("owned_by")) + owned_by = json["owned_by"].asString(); + if (json.isMember("text_model")) + text_model = json["text_model"].asBool(); + + if (engine == "cortex.tensorrt-llm") { + if (json.isMember("trtllm_version")) + trtllm_version = json["trtllm_version"].asString(); + if (json.isMember("tp")) + tp = json["tp"].asInt(); + } + } Json::Value ToJson() const { Json::Value obj; diff --git a/engine/controllers/command_line_parser.cc b/engine/controllers/command_line_parser.cc index 6073cbbb3..74155a316 100644 --- a/engine/controllers/command_line_parser.cc +++ b/engine/controllers/command_line_parser.cc @@ -14,6 +14,7 @@ #include "commands/model_pull_cmd.h" #include "commands/model_start_cmd.h" #include "commands/model_stop_cmd.h" +#include "commands/model_upd_cmd.h" #include "commands/run_cmd.h" #include "commands/server_start_cmd.h" #include "commands/server_stop_cmd.h" @@ -256,10 +257,8 @@ void CommandLineParser::SetupModelCommands() { commands::ModelAliasCmd mdc; mdc.Exec(cml_data_.model_id, cml_data_.model_alias); }); - - auto model_update_cmd = - models_cmd->add_subcommand("update", "Update configuration of a model"); - model_update_cmd->group(kSubcommands); + // Model update parameters comment + ModelUpdate(models_cmd); std::string model_path; auto model_import_cmd = models_cmd->add_subcommand( @@ -373,6 +372,12 @@ void CommandLineParser::SetupSystemCommands() { update_cmd->group(kSystemGroup); update_cmd->add_option("-v", cml_data_.cortex_version, ""); update_cmd->callback([this] { +#if !defined(_WIN32) + if (getuid()) { + CLI_LOG("Error: Not root user. Please run with sudo."); + return; + } +#endif commands::CortexUpdCmd cuc; cuc.Exec(cml_data_.cortex_version); cml_data_.check_upd = false; @@ -442,3 +447,69 @@ void CommandLineParser::EngineGet(CLI::App* parent) { [engine_name] { commands::EngineGetCmd().Exec(engine_name); }); } } + +void CommandLineParser::ModelUpdate(CLI::App* parent) { + auto model_update_cmd = + parent->add_subcommand("update", "Update configuration of a model"); + model_update_cmd->group(kSubcommands); + model_update_cmd->add_option("--model_id", cml_data_.model_id, "Model ID") + ->required(); + + // Add options dynamically + std::vector option_names = {"name", + "model", + "version", + "stop", + "top_p", + "temperature", + "frequency_penalty", + "presence_penalty", + "max_tokens", + "stream", + "ngl", + "ctx_len", + "engine", + "prompt_template", + "system_template", + "user_template", + "ai_template", + "os", + "gpu_arch", + "quantization_method", + "precision", + "tp", + "trtllm_version", + "text_model", + "files", + "created", + "object", + "owned_by", + "seed", + "dynatemp_range", + "dynatemp_exponent", + "top_k", + "min_p", + "tfs_z", + "typ_p", + "repeat_last_n", + "repeat_penalty", + "mirostat", + "mirostat_tau", + "mirostat_eta", + "penalize_nl", + "ignore_eos", + "n_probs", + "min_keep", + "grammar"}; + + for (const auto& option_name : option_names) { + model_update_cmd->add_option("--" + option_name, + cml_data_.model_update_options[option_name], + option_name); + } + + model_update_cmd->callback([this]() { + commands::ModelUpdCmd command(cml_data_.model_id); + command.Exec(cml_data_.model_update_options); + }); +} \ No newline at end of file diff --git a/engine/controllers/command_line_parser.h b/engine/controllers/command_line_parser.h index 87a8063fd..aaa24e064 100644 --- a/engine/controllers/command_line_parser.h +++ b/engine/controllers/command_line_parser.h @@ -1,9 +1,9 @@ #pragma once #include "CLI/CLI.hpp" +#include "commands/model_upd_cmd.h" #include "services/engine_service.h" #include "utils/config_yaml_utils.h" - class CommandLineParser { public: CommandLineParser(); @@ -11,13 +11,13 @@ class CommandLineParser { private: void SetupCommonCommands(); - + void SetupInferenceCommands(); - + void SetupModelCommands(); - + void SetupEngineCommands(); - + void SetupSystemCommands(); void EngineInstall(CLI::App* parent, const std::string& engine_name, @@ -26,10 +26,11 @@ class CommandLineParser { void EngineUninstall(CLI::App* parent, const std::string& engine_name); void EngineGet(CLI::App* parent); + void ModelUpdate(CLI::App* parent); CLI::App app_; EngineService engine_service_; - struct CmlData{ + struct CmlData { std::string model_id; std::string msg; std::string model_alias; @@ -40,6 +41,7 @@ class CommandLineParser { bool check_upd = true; int port; config_yaml_utils::CortexConfig config; + std::unordered_map model_update_options; }; CmlData cml_data_; }; diff --git a/engine/controllers/models.cc b/engine/controllers/models.cc index e857d89da..4660b50e5 100644 --- a/engine/controllers/models.cc +++ b/engine/controllers/models.cc @@ -114,7 +114,7 @@ void Models::GetModel( auto model_config = yaml_handler.GetModelConfig(); Json::Value obj = model_config.ToJson(); - + data.append(std::move(obj)); ret["data"] = data; ret["result"] = "OK"; @@ -155,7 +155,49 @@ void Models::DeleteModel(const HttpRequestPtr& req, callback(resp); } } +void Models::UpdateModel( + const HttpRequestPtr& req, + std::function&& callback) const { + if (!http_util::HasFieldInReq(req, callback, "modelId")) { + return; + } + auto model_id = (*(req->getJsonObject())).get("modelId", "").asString(); + auto json_body = *(req->getJsonObject()); + try { + modellist_utils::ModelListUtils model_list_utils; + auto model_entry = model_list_utils.GetModelInfo(model_id); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile(model_entry.path_to_model_yaml); + config::ModelConfig model_config = yaml_handler.GetModelConfig(); + model_config.FromJson(json_body); + yaml_handler.UpdateModelConfig(model_config); + yaml_handler.WriteYamlFile(model_entry.path_to_model_yaml); + std::string message = "Successfully update model ID '" + model_id + + "': " + json_body.toStyledString(); + LOG_INFO << message; + Json::Value ret; + ret["result"] = "Updated successfully!"; + ret["modelHandle"] = model_id; + ret["message"] = message; + + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + + } catch (const std::exception& e) { + std::string error_message = + "Error updating with model_id '" + model_id + "': " + e.what(); + LOG_ERROR << error_message; + Json::Value ret; + ret["result"] = "Updated failed!"; + ret["modelHandle"] = model_id; + ret["message"] = error_message; + auto resp = cortex_utils::CreateCortexHttpJsonResponse(ret); + resp->setStatusCode(k400BadRequest); + callback(resp); + } +} void Models::ImportModel( const HttpRequestPtr& req, std::function&& callback) const { diff --git a/engine/controllers/models.h b/engine/controllers/models.h index 4ae1ff41f..8d652c86a 100644 --- a/engine/controllers/models.h +++ b/engine/controllers/models.h @@ -15,6 +15,7 @@ class Models : public drogon::HttpController { METHOD_ADD(Models::PullModel, "/pull", Post); METHOD_ADD(Models::ListModel, "/list", Get); METHOD_ADD(Models::GetModel, "/get", Post); + METHOD_ADD(Models::UpdateModel, "/update/", Post); METHOD_ADD(Models::ImportModel, "/import", Post); METHOD_ADD(Models::DeleteModel, "/{1}", Delete); METHOD_ADD(Models::SetModelAlias, "/alias", Post); @@ -26,8 +27,11 @@ class Models : public drogon::HttpController { std::function&& callback) const; void GetModel(const HttpRequestPtr& req, std::function&& callback) const; - void ImportModel(const HttpRequestPtr& req, - std::function&& callback) const; + void UpdateModel(const HttpRequestPtr& req, + std::function&& callback) const; + void ImportModel( + const HttpRequestPtr& req, + std::function&& callback) const; void DeleteModel(const HttpRequestPtr& req, std::function&& callback, const std::string& model_id) const; diff --git a/engine/services/download_service.cc b/engine/services/download_service.cc index 496d01116..e3754fa76 100644 --- a/engine/services/download_service.cc +++ b/engine/services/download_service.cc @@ -12,6 +12,14 @@ #include "utils/format_utils.h" #include "utils/logging_utils.h" +#ifdef _WIN32 +#define ftell64(f) _ftelli64(f) +#define fseek64(f, o, w) _fseeki64(f, o, w) +#else +#define ftell64(f) ftello(f) +#define fseek64(f, o, w) fseeko(f, o, w) +#endif + namespace { size_t WriteCallback(void* ptr, size_t size, size_t nmemb, FILE* stream) { size_t written = fwrite(ptr, size, nmemb, stream); @@ -37,12 +45,19 @@ void DownloadService::AddDownloadTask( } // all items are valid, start downloading + bool download_successfully = true; for (const auto& item : task.items) { CLI_LOG("Start downloading: " + item.localPath.filename().string()); - Download(task.id, item, true); + try { + Download(task.id, item, true); + } catch (const std::runtime_error& e) { + CTL_ERR("Failed to download: " << item.downloadUrl << " - " << e.what()); + download_successfully = false; + break; + } } - if (callback.has_value()) { + if (download_successfully && callback.has_value()) { callback.value()(task); } } @@ -102,10 +117,15 @@ void DownloadService::Download(const std::string& download_id, std::string mode = "wb"; if (allow_resume && std::filesystem::exists(download_item.localPath) && download_item.bytes.has_value()) { - FILE* existing_file = fopen(download_item.localPath.string().c_str(), "r"); - fseek(existing_file, 0, SEEK_END); - curl_off_t existing_file_size = ftell(existing_file); - fclose(existing_file); + curl_off_t existing_file_size = GetLocalFileSize(download_item.localPath); + if (existing_file_size == -1) { + CLI_LOG("Cannot get file size: " << download_item.localPath.string() + << " . Start download over!"); + return; + } + CTL_INF("Existing file size: " << download_item.downloadUrl << " - " + << download_item.localPath.string() << " - " + << existing_file_size); auto missing_bytes = download_item.bytes.value() - existing_file_size; if (missing_bytes > 0) { CLI_LOG("Found unfinished download! Additional " @@ -149,9 +169,13 @@ void DownloadService::Download(const std::string& download_id, curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); if (mode == "ab") { - fseek(file, 0, SEEK_END); - curl_off_t local_file_size = ftell(file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, local_file_size); + auto local_file_size = GetLocalFileSize(download_item.localPath); + if (local_file_size != -1) { + curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, + GetLocalFileSize(download_item.localPath)); + } else { + CTL_ERR("Cannot get file size: " << download_item.localPath.string()); + } } res = curl_easy_perform(curl); @@ -159,8 +183,26 @@ void DownloadService::Download(const std::string& download_id, if (res != CURLE_OK) { fprintf(stderr, "curl_easy_perform() failed: %s\n", curl_easy_strerror(res)); + throw std::runtime_error("Failed to download file " + + download_item.localPath.filename().string()); } fclose(file); curl_easy_cleanup(curl); } + +curl_off_t DownloadService::GetLocalFileSize( + const std::filesystem::path& path) const { + FILE* file = fopen(path.string().c_str(), "r"); + if (!file) { + return -1; + } + + if (fseek64(file, 0, SEEK_END) != 0) { + return -1; + } + + curl_off_t file_size = ftell64(file); + fclose(file); + return file_size; +} \ No newline at end of file diff --git a/engine/services/download_service.h b/engine/services/download_service.h index 7063be74c..b9f93ee82 100644 --- a/engine/services/download_service.h +++ b/engine/services/download_service.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include #include @@ -73,4 +74,6 @@ class DownloadService { private: void Download(const std::string& download_id, const DownloadItem& download_item, bool allow_resume); + + curl_off_t GetLocalFileSize(const std::filesystem::path& path) const; }; diff --git a/engine/services/engine_service.cc b/engine/services/engine_service.cc index 1b1f1d278..289bebd68 100644 --- a/engine/services/engine_service.cc +++ b/engine/services/engine_service.cc @@ -119,9 +119,12 @@ void EngineService::UnzipEngine(const std::string& engine, CTL_INF("engine: " << engine); CTL_INF("CUDA version: " << hw_inf_.cuda_driver_version); std::string cuda_variant = "cuda-"; - cuda_variant += GetSuitableCudaVersion(engine, hw_inf_.cuda_driver_version) + - "-" + hw_inf_.sys_inf->os + "-" + hw_inf_.sys_inf->arch + - ".tar.gz"; + auto cuda_github = + GetSuitableCudaVersion(engine, hw_inf_.cuda_driver_version); + // Github release cuda example: cuda-12-0-windows-amd64.tar.gz + std::replace(cuda_github.begin(), cuda_github.end(), '.', '-'); + cuda_variant += cuda_github + "-" + hw_inf_.sys_inf->os + "-" + + hw_inf_.sys_inf->arch + ".tar.gz"; CTL_INF("cuda_variant: " << cuda_variant); std::vector variants; diff --git a/engine/test/components/test_modellist_utils.cc b/engine/test/components/test_modellist_utils.cc index 68b06483d..d1dbf91e3 100644 --- a/engine/test/components/test_modellist_utils.cc +++ b/engine/test/components/test_modellist_utils.cc @@ -19,6 +19,7 @@ class ModelListUtilsTestSuite : public ::testing::Test { void TearDown() { // Clean up the temporary directory + std::remove((file_manager_utils::GetModelsContainerPath() / "model.list").string().c_str()); } TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) { EXPECT_TRUE(model_list_.AddModelEntry(kTestModel)); @@ -128,4 +129,6 @@ TEST_F(ModelListUtilsTestSuite, TestHasModel) { EXPECT_TRUE(model_list_.HasModel("test_model_id")); EXPECT_TRUE(model_list_.HasModel("test_alias")); EXPECT_FALSE(model_list_.HasModel("non_existent_model")); + // Clean up + model_list_.DeleteModelEntry("test_model_id"); } \ No newline at end of file