From 084592b5ce95eff5128f2d1b6a15909493bc4e62 Mon Sep 17 00:00:00 2001 From: James Date: Tue, 24 Sep 2024 09:20:22 +0700 Subject: [PATCH] feat: return model id when download model success --- engine/services/model_service.cc | 116 ++++++++++++++++++++++++++----- engine/services/model_service.h | 23 +++--- 2 files changed, 112 insertions(+), 27 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index dc6fc3f68..485bec869 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -2,15 +2,18 @@ #include #include #include +#include "config/gguf_parser.h" +#include "config/yaml_config.h" #include "utils/cli_selection_utils.h" #include "utils/cortexso_parser.h" #include "utils/file_manager_utils.h" #include "utils/huggingface_utils.h" #include "utils/logging_utils.h" -#include "utils/model_callback_utils.h" +#include "utils/modellist_utils.h" #include "utils/string_utils.h" -void ModelService::DownloadModel(const std::string& input) { +std::optional ModelService::DownloadModel( + const std::string& input) { if (input.empty()) { throw std::runtime_error( "Input must be Cortex Model Hub handle or HuggingFace url!"); @@ -32,15 +35,15 @@ void ModelService::DownloadModel(const std::string& input) { return DownloadModelByModelName(model_name); } - DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); CLI_LOG("Model " << model_name << " downloaded successfully!") - return; + return DownloadHuggingFaceGgufModel(author, model_name, std::nullopt); } return DownloadModelByModelName(input); } -void ModelService::DownloadModelByModelName(const std::string& modelName) { +std::optional ModelService::DownloadModelByModelName( + const std::string& modelName) { try { auto branches = huggingface_utils::GetModelRepositoryBranches("cortexso", modelName); @@ -52,12 +55,13 @@ void ModelService::DownloadModelByModelName(const std::string& modelName) { } if (options.empty()) { CLI_LOG("No variant found"); - return; + return std::nullopt; } auto selection = cli_selection_utils::PrintSelection(options); - DownloadModelFromCortexso(modelName, selection.value()); + return DownloadModelFromCortexso(modelName, selection.value()); } catch (const std::runtime_error& e) { CLI_LOG("Error downloading model, " << e.what()); + return std::nullopt; } } @@ -87,7 +91,8 @@ std::optional ModelService::GetDownloadedModel( return std::nullopt; } -void ModelService::DownloadModelByDirectUrl(const std::string& url) { +std::optional ModelService::DownloadModelByDirectUrl( + const std::string& url) { auto url_obj = url_parser::FromUrlString(url); if (url_obj.host == kHuggingFaceHost) { @@ -103,6 +108,9 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { return DownloadModelFromCortexso(model_id); } + std::string huggingFaceHost{kHuggingFaceHost}; + std::string unique_model_id{huggingFaceHost + "/" + author + "/" + model_id + + "/" + file_name}; auto local_path{file_manager_utils::GetModelsContainerPath() / "huggingface.co" / author / model_id / file_name}; @@ -119,33 +127,68 @@ void ModelService::DownloadModelByDirectUrl(const std::string& url) { auto downloadTask{DownloadTask{.id = model_id, .type = DownloadType::Model, .items = {DownloadItem{ - .id = url_obj.pathParams.back(), + .id = unique_model_id, .downloadUrl = download_url, .localPath = local_path, }}}}; - auto on_finished = [&author](const DownloadTask& finishedTask) { + auto on_finished = [&](const DownloadTask& finishedTask) { CLI_LOG("Model " << finishedTask.id << " downloaded successfully!") auto gguf_download_item = finishedTask.items[0]; - model_callback_utils::ParseGguf(gguf_download_item, author); + ParseGguf(gguf_download_item, author); }; download_service_.AddDownloadTask(downloadTask, on_finished); + return unique_model_id; } -void ModelService::DownloadModelFromCortexso(const std::string& name, - const std::string& branch) { +std::optional ModelService::DownloadModelFromCortexso( + const std::string& name, const std::string& branch) { + auto downloadTask = cortexso_parser::getDownloadTask(name, branch); if (downloadTask.has_value()) { - DownloadService().AddDownloadTask(downloadTask.value(), - model_callback_utils::DownloadModelCb); - CLI_LOG("Model " << name << " downloaded successfully!") + std::string model_id{name + ":" + branch}; + DownloadService().AddDownloadTask( + downloadTask.value(), [&](const DownloadTask& finishedTask) { + const DownloadItem* model_yml_item = nullptr; + auto need_parse_gguf = true; + + for (const auto& item : finishedTask.items) { + if (item.localPath.filename().string() == "model.yml") { + model_yml_item = &item; + } + } + + if (model_yml_item != nullptr) { + auto url_obj = + url_parser::FromUrlString(model_yml_item->downloadUrl); + CTL_INF("Adding model to modellist with branch: " << branch); + config::YamlHandler yaml_handler; + yaml_handler.ModelConfigFromFile( + model_yml_item->localPath.string()); + auto mc = yaml_handler.GetModelConfig(); + + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = model_id, + .author_repo_id = "cortexso", + .branch_name = branch, + .path_to_model_yaml = model_yml_item->localPath.string(), + .model_alias = model_id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry); + } + }); + + CLI_LOG("Model " << model_id << " downloaded successfully!") + return model_id; } else { CTL_ERR("Model not found"); + return std::nullopt; } } -void ModelService::DownloadHuggingFaceGgufModel( +std::optional ModelService::DownloadHuggingFaceGgufModel( const std::string& author, const std::string& modelName, std::optional fileName) { auto repo_info = @@ -153,7 +196,7 @@ void ModelService::DownloadHuggingFaceGgufModel( if (!repo_info.has_value()) { // throw is better? CTL_ERR("Model not found"); - return; + return std::nullopt; } if (!repo_info->gguf.has_value()) { @@ -172,5 +215,40 @@ void ModelService::DownloadHuggingFaceGgufModel( auto download_url = huggingface_utils::GetDownloadableUrl(author, modelName, selection.value()); - DownloadModelByDirectUrl(download_url); + return DownloadModelByDirectUrl(download_url); +} + +void ModelService::ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author) const { + + config::GGUFHandler gguf_handler; + config::YamlHandler yaml_handler; + gguf_handler.Parse(ggufDownloadItem.localPath.string()); + config::ModelConfig model_config = gguf_handler.GetModelConfig(); + model_config.id = + ggufDownloadItem.localPath.parent_path().filename().string(); + model_config.files = {ggufDownloadItem.localPath.string()}; + yaml_handler.UpdateModelConfig(model_config); + + auto yaml_path{ggufDownloadItem.localPath}; + auto yaml_name = yaml_path.replace_extension(".yml"); + + if (!std::filesystem::exists(yaml_path)) { + yaml_handler.WriteYamlFile(yaml_path.string()); + } + + auto url_obj = url_parser::FromUrlString(ggufDownloadItem.downloadUrl); + auto branch = url_obj.pathParams[3]; + CTL_INF("Adding model to modellist with branch: " << branch); + + auto author_id = author.has_value() ? author.value() : "cortexso"; + modellist_utils::ModelListUtils modellist_utils_obj; + modellist_utils::ModelEntry model_entry{ + .model_id = ggufDownloadItem.id, + .author_repo_id = author_id, + .branch_name = branch, + .path_to_model_yaml = yaml_name.string(), + .model_alias = ggufDownloadItem.id, + .status = modellist_utils::ModelStatus::READY}; + modellist_utils_obj.AddModelEntry(model_entry, true); } diff --git a/engine/services/model_service.h b/engine/services/model_service.h index 06212aaee..4237f1b17 100644 --- a/engine/services/model_service.h +++ b/engine/services/model_service.h @@ -8,27 +8,34 @@ class ModelService { public: ModelService() : download_service_{DownloadService()} {}; - void DownloadModel(const std::string& input); + /** + * Return model id if download successfully + */ + std::optional DownloadModel(const std::string& input); std::optional GetDownloadedModel( const std::string& modelId) const; private: - void DownloadModelByDirectUrl(const std::string& url); + std::optional DownloadModelByDirectUrl(const std::string& url); - void DownloadModelFromCortexso(const std::string& name, - const std::string& branch = "main"); + std::optional DownloadModelFromCortexso( + const std::string& name, const std::string& branch = "main"); /** * Handle downloading model which have following pattern: author/model_name */ - void DownloadHuggingFaceGgufModel(const std::string& author, - const std::string& modelName, - std::optional fileName); + std::optional DownloadHuggingFaceGgufModel( + const std::string& author, const std::string& modelName, + std::optional fileName); - void DownloadModelByModelName(const std::string& modelName); + std::optional DownloadModelByModelName( + const std::string& modelName); DownloadService download_service_; + void ParseGguf(const DownloadItem& ggufDownloadItem, + std::optional author = nullptr) const; + constexpr auto static kHuggingFaceHost = "huggingface.co"; };