Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.

Commit b8078af

Browse files
Init model.list utils (#1240)
* Init model.list utils * Add cmakelist compile * Add cmakelist compile * Fix CI build windows * add unitest * Add test * Fix fail unitest
1 parent 142adf0 commit b8078af

File tree

6 files changed

+358
-2
lines changed

6 files changed

+358
-2
lines changed

engine/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ find_package(CURL REQUIRED)
8282
add_executable(${TARGET_NAME} main.cc
8383
${CMAKE_CURRENT_SOURCE_DIR}/utils/cpuid/cpu_info.cc
8484
${CMAKE_CURRENT_SOURCE_DIR}/utils/file_logger.cc
85+
${CMAKE_CURRENT_SOURCE_DIR}/utils/modellist_utils.cc
8586
)
8687

8788
target_link_libraries(${TARGET_NAME} PRIVATE httplib::httplib)

engine/test/components/CMakeLists.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ project(test-components)
33

44
enable_testing()
55

6-
add_executable(${PROJECT_NAME} ${SRCS})
6+
add_executable(${PROJECT_NAME} ${SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/../../utils/modellist_utils.cc)
77

88
find_package(Drogon CONFIG REQUIRED)
99
find_package(GTest CONFIG REQUIRED)
10+
find_package(yaml-cpp CONFIG REQUIRED)
1011

11-
target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main
12+
target_link_libraries(${PROJECT_NAME} PRIVATE Drogon::Drogon GTest::gtest GTest::gtest_main yaml-cpp::yaml-cpp
1213
${CMAKE_THREAD_LIBS_INIT})
1314
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../)
1415

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
#include <filesystem>
2+
#include <iostream>
3+
#include "gtest/gtest.h"
4+
#include "utils/modellist_utils.h"
5+
#include "utils/file_manager_utils.h"
6+
class ModelListUtilsTestSuite : public ::testing::Test {
7+
protected:
8+
modellist_utils::ModelListUtils model_list_;
9+
10+
const modellist_utils::ModelEntry kTestModel{
11+
"test_model_id", "test_author",
12+
"main", "/path/to/model.yaml",
13+
"test_alias", modellist_utils::ModelStatus::READY};
14+
};
15+
void SetUp() {
16+
// Create a temporary directory for tests
17+
file_manager_utils::CreateConfigFileIfNotExist();
18+
}
19+
20+
void TearDown() {
21+
// Clean up the temporary directory
22+
}
23+
TEST_F(ModelListUtilsTestSuite, TestAddModelEntry) {
24+
EXPECT_TRUE(model_list_.AddModelEntry(kTestModel));
25+
26+
auto retrieved_model = model_list_.GetModelInfo("test_model_id");
27+
EXPECT_EQ(retrieved_model.model_id, kTestModel.model_id);
28+
EXPECT_EQ(retrieved_model.author_repo_id, kTestModel.author_repo_id);
29+
}
30+
31+
TEST_F(ModelListUtilsTestSuite, TestGetModelInfo) {
32+
model_list_.AddModelEntry(kTestModel);
33+
34+
auto model_by_id = model_list_.GetModelInfo("test_model_id");
35+
EXPECT_EQ(model_by_id.model_id, kTestModel.model_id);
36+
37+
auto model_by_alias = model_list_.GetModelInfo("test_alias");
38+
EXPECT_EQ(model_by_alias.model_id, kTestModel.model_id);
39+
40+
EXPECT_THROW(model_list_.GetModelInfo("non_existent_model"),
41+
std::runtime_error);
42+
}
43+
44+
TEST_F(ModelListUtilsTestSuite, TestUpdateModelEntry) {
45+
model_list_.AddModelEntry(kTestModel);
46+
47+
modellist_utils::ModelEntry updated_model = kTestModel;
48+
updated_model.status = modellist_utils::ModelStatus::RUNNING;
49+
50+
EXPECT_TRUE(model_list_.UpdateModelEntry("test_model_id", updated_model));
51+
52+
auto retrieved_model = model_list_.GetModelInfo("test_model_id");
53+
EXPECT_EQ(retrieved_model.status, modellist_utils::ModelStatus::RUNNING);
54+
updated_model.status = modellist_utils::ModelStatus::READY;
55+
model_list_.UpdateModelEntry("test_model_id", updated_model);
56+
}
57+
58+
TEST_F(ModelListUtilsTestSuite, TestDeleteModelEntry) {
59+
model_list_.AddModelEntry(kTestModel);
60+
61+
EXPECT_TRUE(model_list_.DeleteModelEntry("test_model_id"));
62+
EXPECT_THROW(model_list_.GetModelInfo("test_model_id"), std::runtime_error);
63+
}
64+
65+
TEST_F(ModelListUtilsTestSuite, TestGenerateShortenedAlias) {
66+
auto alias = model_list_.GenerateShortenedAlias(
67+
"huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf", {});
68+
EXPECT_EQ(alias, "model_id_xxx");
69+
70+
// Test with existing entries to force longer alias
71+
modellist_utils::ModelEntry existing_model = kTestModel;
72+
existing_model.model_alias = "model_id_xxx";
73+
std::vector<modellist_utils::ModelEntry> existing_entries = {existing_model};
74+
75+
alias = model_list_.GenerateShortenedAlias(
76+
"huggingface.co/bartowski/llama3.1-7b-gguf/Model_ID_Xxx.gguf",
77+
existing_entries);
78+
EXPECT_EQ(alias, "llama3.1-7b-gguf:model_id_xxx");
79+
}
80+
81+
TEST_F(ModelListUtilsTestSuite, TestPersistence) {
82+
model_list_.AddModelEntry(kTestModel);
83+
84+
// Create a new ModelListUtils instance to test if it loads from file
85+
modellist_utils::ModelListUtils new_model_list;
86+
auto retrieved_model = new_model_list.GetModelInfo("test_model_id");
87+
88+
EXPECT_EQ(retrieved_model.model_id, kTestModel.model_id);
89+
EXPECT_EQ(retrieved_model.author_repo_id, kTestModel.author_repo_id);
90+
model_list_.DeleteModelEntry("test_model_id");
91+
}

engine/utils/file_manager_utils.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ inline void CreateDirectoryRecursively(const std::string& path) {
202202
}
203203

204204
inline std::filesystem::path GetModelsContainerPath() {
205+
CreateConfigFileIfNotExist();
205206
auto cortex_path = GetCortexDataPath();
206207
auto models_container_path = cortex_path / "models";
207208

engine/utils/modellist_utils.cc

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
#include "modellist_utils.h"
2+
#include <algorithm>
3+
#include <filesystem>
4+
#include <fstream>
5+
#include <iostream>
6+
#include <regex>
7+
#include <sstream>
8+
#include <stdexcept>
9+
#include "file_manager_utils.h"
10+
namespace modellist_utils {
11+
const std::string ModelListUtils::kModelListPath =
12+
(file_manager_utils::GetModelsContainerPath() /
13+
std::filesystem::path("model.list"))
14+
.string();
15+
16+
std::vector<ModelEntry> ModelListUtils::LoadModelList() const {
17+
std::vector<ModelEntry> entries;
18+
std::filesystem::path file_path(kModelListPath);
19+
20+
// Check if the file exists, if not, create it
21+
if (!std::filesystem::exists(file_path)) {
22+
std::ofstream create_file(kModelListPath);
23+
if (!create_file) {
24+
throw std::runtime_error("Unable to create model.list file: " +
25+
kModelListPath);
26+
}
27+
create_file.close();
28+
return entries; // Return empty vector for newly created file
29+
}
30+
31+
std::ifstream file(kModelListPath);
32+
if (!file.is_open()) {
33+
throw std::runtime_error("Unable to open model.list file: " +
34+
kModelListPath);
35+
}
36+
37+
std::string line;
38+
while (std::getline(file, line)) {
39+
std::istringstream iss(line);
40+
ModelEntry entry;
41+
std::string status_str;
42+
if (!(iss >> entry.model_id >> entry.author_repo_id >> entry.branch_name >>
43+
entry.path_to_model_yaml >> entry.model_alias >> status_str)) {
44+
LOG_WARN << "Invalid entry in model.list: " << line;
45+
} else {
46+
entry.status =
47+
(status_str == "RUNNING") ? ModelStatus::RUNNING : ModelStatus::READY;
48+
entries.push_back(entry);
49+
}
50+
}
51+
return entries;
52+
}
53+
54+
bool ModelListUtils::IsUnique(const std::vector<ModelEntry>& entries,
55+
const std::string& model_id,
56+
const std::string& model_alias) const {
57+
return std::none_of(
58+
entries.begin(), entries.end(), [&](const ModelEntry& entry) {
59+
return entry.model_id == model_id || entry.model_alias == model_id ||
60+
entry.model_id == model_alias ||
61+
entry.model_alias == model_alias;
62+
});
63+
}
64+
65+
void ModelListUtils::SaveModelList(
66+
const std::vector<ModelEntry>& entries) const {
67+
std::ofstream file(kModelListPath);
68+
if (!file.is_open()) {
69+
throw std::runtime_error("Unable to open model.list file for writing: " +
70+
kModelListPath);
71+
}
72+
73+
for (const auto& entry : entries) {
74+
file << entry.model_id << " " << entry.author_repo_id << " "
75+
<< entry.branch_name << " " << entry.path_to_model_yaml << " "
76+
<< entry.model_alias << " "
77+
<< (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY")
78+
<< std::endl;
79+
}
80+
}
81+
82+
std::string ModelListUtils::GenerateShortenedAlias(
83+
const std::string& model_id, const std::vector<ModelEntry>& entries) const {
84+
std::vector<std::string> parts;
85+
std::istringstream iss(model_id);
86+
std::string part;
87+
while (std::getline(iss, part, '/')) {
88+
parts.push_back(part);
89+
}
90+
91+
if (parts.empty()) {
92+
return model_id; // Return original if no parts
93+
}
94+
95+
// Extract the filename without extension
96+
std::string filename = parts.back();
97+
size_t last_dot_pos = filename.find_last_of('.');
98+
if (last_dot_pos != std::string::npos) {
99+
filename = filename.substr(0, last_dot_pos);
100+
}
101+
102+
// Convert to lowercase
103+
std::transform(filename.begin(), filename.end(), filename.begin(),
104+
[](unsigned char c) { return std::tolower(c); });
105+
106+
// Generate alias candidates
107+
std::vector<std::string> candidates;
108+
candidates.push_back(filename);
109+
110+
if (parts.size() >= 2) {
111+
candidates.push_back(parts[parts.size() - 2] + ":" + filename);
112+
}
113+
114+
if (parts.size() >= 3) {
115+
candidates.push_back(parts[parts.size() - 3] + ":" +
116+
parts[parts.size() - 2] + "/" + filename);
117+
}
118+
119+
if (parts.size() >= 4) {
120+
candidates.push_back(parts[0] + ":" + parts[1] + "/" +
121+
parts[parts.size() - 2] + "/" + filename);
122+
}
123+
124+
// Find the first unique candidate
125+
for (const auto& candidate : candidates) {
126+
if (IsUnique(entries, model_id, candidate)) {
127+
return candidate;
128+
}
129+
}
130+
131+
// If all candidates are taken, append a number to the last candidate
132+
std::string base_candidate = candidates.back();
133+
int suffix = 1;
134+
std::string unique_candidate = base_candidate;
135+
while (!IsUnique(entries, model_id, unique_candidate)) {
136+
unique_candidate = base_candidate + "-" + std::to_string(suffix++);
137+
}
138+
139+
return unique_candidate;
140+
}
141+
142+
ModelEntry ModelListUtils::GetModelInfo(const std::string& identifier) const {
143+
std::lock_guard<std::mutex> lock(mutex_);
144+
auto entries = LoadModelList();
145+
auto it = std::find_if(
146+
entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) {
147+
return entry.model_id == identifier || entry.model_alias == identifier;
148+
});
149+
150+
if (it != entries.end()) {
151+
return *it;
152+
} else {
153+
throw std::runtime_error("Model not found: " + identifier);
154+
}
155+
}
156+
157+
void ModelListUtils::PrintModelInfo(const ModelEntry& entry) const {
158+
LOG_INFO << "Model ID: " << entry.model_id;
159+
LOG_INFO << "Author/Repo ID: " << entry.author_repo_id;
160+
LOG_INFO << "Branch Name: " << entry.branch_name;
161+
LOG_INFO << "Path to model.yaml: " << entry.path_to_model_yaml;
162+
LOG_INFO << "Model Alias: " << entry.model_alias;
163+
LOG_INFO << "Status: "
164+
<< (entry.status == ModelStatus::RUNNING ? "RUNNING" : "READY");
165+
}
166+
167+
bool ModelListUtils::AddModelEntry(ModelEntry new_entry, bool use_short_alias) {
168+
std::lock_guard<std::mutex> lock(mutex_);
169+
auto entries = LoadModelList();
170+
171+
if (IsUnique(entries, new_entry.model_id, new_entry.model_alias)) {
172+
if (use_short_alias) {
173+
new_entry.model_alias =
174+
GenerateShortenedAlias(new_entry.model_id, entries);
175+
}
176+
new_entry.status = ModelStatus::READY; // Set default status to READY
177+
entries.push_back(std::move(new_entry));
178+
SaveModelList(entries);
179+
return true;
180+
}
181+
return false; // Entry not added due to non-uniqueness
182+
}
183+
184+
bool ModelListUtils::UpdateModelEntry(const std::string& identifier,
185+
const ModelEntry& updated_entry) {
186+
std::lock_guard<std::mutex> lock(mutex_);
187+
auto entries = LoadModelList();
188+
auto it = std::find_if(
189+
entries.begin(), entries.end(), [&identifier](const ModelEntry& entry) {
190+
return entry.model_id == identifier || entry.model_alias == identifier;
191+
});
192+
193+
if (it != entries.end()) {
194+
*it = updated_entry;
195+
SaveModelList(entries);
196+
return true;
197+
}
198+
return false; // Entry not found
199+
}
200+
201+
bool ModelListUtils::DeleteModelEntry(const std::string& identifier) {
202+
std::lock_guard<std::mutex> lock(mutex_);
203+
auto entries = LoadModelList();
204+
auto it = std::find_if(entries.begin(), entries.end(),
205+
[&identifier](const ModelEntry& entry) {
206+
return (entry.model_id == identifier ||
207+
entry.model_alias == identifier) &&
208+
entry.status == ModelStatus::READY;
209+
});
210+
211+
if (it != entries.end()) {
212+
entries.erase(it);
213+
SaveModelList(entries);
214+
return true;
215+
}
216+
return false; // Entry not found or not in READY state
217+
}
218+
} // namespace modellist_utils

engine/utils/modellist_utils.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#pragma once
2+
#include <trantor/utils/Logger.h>
3+
#include <mutex>
4+
#include <string>
5+
#include <vector>
6+
#include "logging_utils.h"
7+
namespace modellist_utils {
8+
9+
enum class ModelStatus { READY, RUNNING };
10+
11+
struct ModelEntry {
12+
std::string model_id;
13+
std::string author_repo_id;
14+
std::string branch_name;
15+
std::string path_to_model_yaml;
16+
std::string model_alias;
17+
ModelStatus status;
18+
};
19+
20+
class ModelListUtils {
21+
22+
private:
23+
mutable std::mutex mutex_; // For thread safety
24+
25+
std::vector<ModelEntry> LoadModelList() const;
26+
bool IsUnique(const std::vector<ModelEntry>& entries,
27+
const std::string& model_id,
28+
const std::string& model_alias) const;
29+
void SaveModelList(const std::vector<ModelEntry>& entries) const;
30+
31+
public:
32+
static const std::string kModelListPath;
33+
ModelListUtils() = default;
34+
std::string GenerateShortenedAlias(
35+
const std::string& model_id,
36+
const std::vector<ModelEntry>& entries) const;
37+
ModelEntry GetModelInfo(const std::string& identifier) const;
38+
void PrintModelInfo(const ModelEntry& entry) const;
39+
bool AddModelEntry(ModelEntry new_entry, bool use_short_alias = false);
40+
bool UpdateModelEntry(const std::string& identifier,
41+
const ModelEntry& updated_entry);
42+
bool DeleteModelEntry(const std::string& identifier);
43+
};
44+
} // namespace modellist_utils

0 commit comments

Comments
 (0)