From bc17643f5c98906f133f82a39139bfd620bc8fc7 Mon Sep 17 00:00:00 2001 From: James Date: Wed, 2 Oct 2024 02:49:28 +0700 Subject: [PATCH] fix: validate url before pull --- engine/services/model_service.cc | 33 ++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/engine/services/model_service.cc b/engine/services/model_service.cc index 7c55d5adf..0f65f7fbd 100644 --- a/engine/services/model_service.cc +++ b/engine/services/model_service.cc @@ -1,6 +1,7 @@ #include "model_service.h" #include #include +#include #include #include "config/gguf_parser.h" #include "config/yaml_config.h" @@ -58,8 +59,7 @@ cpp::result GetDownloadTask( httplib::Client cli(url.GetProtocolAndHost()); auto res = cli.Get(url.GetPathAndQuery()); if (res->status != httplib::StatusCode::OK_200) { - auto err = res.error(); - return cpp::fail("HTTP error: " + httplib::to_string(err)); + return cpp::fail("Model " + modelId + " not found"); } auto jsonResponse = json::parse(res->body); @@ -103,7 +103,6 @@ cpp::result ModelService::DownloadModel( } if (string_utils::StartsWith(input, "https://")) { - // TODO: better name, for example handle url return HandleUrl(input, async); } @@ -195,13 +194,27 @@ cpp::result ModelService::HandleUrl( auto file_name{url_obj.pathParams.back()}; if (author == "cortexso") { - // TODO: try to get the branch return DownloadModelFromCortexso(model_id); } + if (url_obj.pathParams.size() < 5) { + if (url_obj.pathParams.size() < 2) { + return cpp::fail("Invalid url: " + url); + } + return DownloadHuggingFaceGgufModel(author, model_id, std::nullopt, async); + } + std::string huggingFaceHost{kHuggingFaceHost}; - std::string unique_model_id{huggingFaceHost + "/" + author + "/" + model_id + - "/" + file_name}; + std::string unique_model_id{author + ":" + model_id + ":" + file_name}; + + cortex::db::Models modellist_handler; + auto model_entry = modellist_handler.GetModelInfo(unique_model_id); + + if (model_entry.has_value()) { + CLI_LOG("Model already downloaded: " << unique_model_id); + return cpp::fail("Please delete the model before downloading again"); + } + auto local_path{file_manager_utils::GetModelsContainerPath() / "huggingface.co" / author / model_id / file_name}; @@ -240,7 +253,7 @@ cpp::result ModelService::HandleUrl( } else { auto result = download_service_.AddDownloadTask(downloadTask, on_finished); if (result.has_error()) { - // CTL_ERR(result.error()); + CTL_ERR(result.error()); return cpp::fail(result.error()); } else { CLI_LOG("Model " << model_id << " downloaded successfully!") @@ -415,7 +428,7 @@ cpp::result ModelService::StartModel( auto res = cli.Post("/inferences/server/loadmodel", httplib::Headers(), data_str.data(), data_str.size(), "application/json"); if (res) { - if (res->status == httplib::StatusCode::OK_200) { + if (res->status == httplib::StatusCode::OK_200) { return true; } else { CTL_ERR("Model failed to load with status code: " << res->status); @@ -459,7 +472,7 @@ cpp::result ModelService::StopModel( auto res = cli.Post("/inferences/server/unloadmodel", httplib::Headers(), data_str.data(), data_str.size(), "application/json"); if (res) { - if (res->status == httplib::StatusCode::OK_200) { + if (res->status == httplib::StatusCode::OK_200) { return true; } else { CTL_ERR("Model failed to unload with status code: " << res->status); @@ -519,4 +532,4 @@ cpp::result ModelService::GetModelStatus( return cpp::fail("Fail to get model status with ID '" + model_handle + "': " + e.what()); } -} \ No newline at end of file +}