Skip to content
This repository was archived by the owner on Jul 4, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions engine/commands/model_upd_cmd.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#include "model_upd_cmd.h"

#include "utils/logging_utils.h"

namespace commands {

ModelUpdCmd::ModelUpdCmd(std::string model_handle)
: model_handle_(std::move(model_handle)) {}

void ModelUpdCmd::Exec(
const std::unordered_map<std::string, std::string>& options) {
try {
auto model_entry = model_list_utils_.GetModelInfo(model_handle_);
yaml_handler_.ModelConfigFromFile(model_entry.path_to_model_yaml);
model_config_ = yaml_handler_.GetModelConfig();

for (const auto& [key, value] : options) {
if (!value.empty()) {
UpdateConfig(key, value);
}
}

yaml_handler_.UpdateModelConfig(model_config_);
yaml_handler_.WriteYamlFile(model_entry.path_to_model_yaml);
CLI_LOG("Successfully updated model ID '" + model_handle_ + "'!");
} catch (const std::exception& e) {
CLI_LOG("Failed to update model with model ID '" + model_handle_ +
"': " + e.what());
}
}

void ModelUpdCmd::UpdateConfig(const std::string& key,
const std::string& value) {
static const std::unordered_map<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I think there's a better way. We can define a function: UpdateModelConfig(const std::string& attr, std::variant<std::string, explicit_int, explicit_bool>) and handle the update logic there.

However if this works for you then I'm fine with it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think should keep it this way. We don't know which format user will input, so need to handle it for specific cases

std::string,
std::function<void(ModelUpdCmd*, const std::string&, const std::string&)>>
updaters = {
{"name",
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
self->model_config_.name = v;
}},
{"model",
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
self->model_config_.model = v;
}},
{"version",
[](ModelUpdCmd* self, const std::string&, const std::string& v) {
self->model_config_.version = v;
}},
{"stop", &ModelUpdCmd::UpdateVectorField},
{"top_p",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateNumericField(
k, v, [self](float f) { self->model_config_.top_p = f; });
}},
{"temperature",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateNumericField(k, v, [self](float f) {
self->model_config_.temperature = f;
});
}},
{"frequency_penalty",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateNumericField(k, v, [self](float f) {
self->model_config_.frequency_penalty = f;
});
}},
{"presence_penalty",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateNumericField(k, v, [self](float f) {
self->model_config_.presence_penalty = f;
});
}},
{"max_tokens",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateNumericField(k, v, [self](float f) {
self->model_config_.max_tokens = static_cast<int>(f);
});
}},
{"stream",
[](ModelUpdCmd* self, const std::string& k, const std::string& v) {
self->UpdateBooleanField(
k, v, [self](bool b) { self->model_config_.stream = b; });
}},
// Add more fields here...
};

if (auto it = updaters.find(key); it != updaters.end()) {
it->second(this, key, value);
LogUpdate(key, value);
}
}

void ModelUpdCmd::UpdateVectorField(const std::string& key,
const std::string& value) {
std::vector<std::string> tokens;
std::istringstream iss(value);
std::string token;
while (std::getline(iss, token, ',')) {
tokens.push_back(token);
}
model_config_.stop = tokens;
}

void ModelUpdCmd::UpdateNumericField(const std::string& key,
const std::string& value,
std::function<void(float)> setter) {
try {
float numericValue = std::stof(value);
setter(numericValue);
} catch (const std::exception& e) {
CLI_LOG("Failed to parse numeric value for " << key << ": " << e.what());
}
}

void ModelUpdCmd::UpdateBooleanField(const std::string& key,
const std::string& value,
std::function<void(bool)> setter) {
bool boolValue = (value == "true" || value == "1");
setter(boolValue);
}

void ModelUpdCmd::LogUpdate(const std::string& key, const std::string& value) {
CLI_LOG("Updated " << key << " to: " << value);
}

} // namespace commands
30 changes: 30 additions & 0 deletions engine/commands/model_upd_cmd.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once
#include <iostream>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>
#include "config/model_config.h"
#include "utils/modellist_utils.h"
#include "config/yaml_config.h"
namespace commands {
class ModelUpdCmd {
public:
ModelUpdCmd(std::string model_handle);
void Exec(const std::unordered_map<std::string, std::string>& options);

private:
std::string model_handle_;
config::ModelConfig model_config_;
config::YamlHandler yaml_handler_;
modellist_utils::ModelListUtils model_list_utils_;

void UpdateConfig(const std::string& key, const std::string& value);
void UpdateVectorField(const std::string& key, const std::string& value);
void UpdateNumericField(const std::string& key, const std::string& value,
std::function<void(float)> setter);
void UpdateBooleanField(const std::string& key, const std::string& value,
std::function<void(bool)> setter);
void LogUpdate(const std::string& key, const std::string& value);
};
} // namespace commands
108 changes: 108 additions & 0 deletions engine/config/model_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,115 @@ struct ModelConfig {
int n_probs = 0;
int min_keep = 0;
std::string grammar;

void FromJson(const Json::Value& json) {
// do now allow to update ID and model field because it is unique identifier
// if (json.isMember("id"))
// id = json["id"].asString();
if (json.isMember("name"))
name = json["name"].asString();
// if (json.isMember("model"))
// model = json["model"].asString();
if (json.isMember("version"))
version = json["version"].asString();

if (json.isMember("stop") && json["stop"].isArray()) {
stop.clear();
for (const auto& s : json["stop"]) {
stop.push_back(s.asString());
}
}

if (json.isMember("stream"))
stream = json["stream"].asBool();
if (json.isMember("top_p"))
top_p = json["top_p"].asFloat();
if (json.isMember("temperature"))
temperature = json["temperature"].asFloat();
if (json.isMember("frequency_penalty"))
frequency_penalty = json["frequency_penalty"].asFloat();
if (json.isMember("presence_penalty"))
presence_penalty = json["presence_penalty"].asFloat();
if (json.isMember("max_tokens"))
max_tokens = json["max_tokens"].asInt();
if (json.isMember("seed"))
seed = json["seed"].asInt();
if (json.isMember("dynatemp_range"))
dynatemp_range = json["dynatemp_range"].asFloat();
if (json.isMember("dynatemp_exponent"))
dynatemp_exponent = json["dynatemp_exponent"].asFloat();
if (json.isMember("top_k"))
top_k = json["top_k"].asInt();
if (json.isMember("min_p"))
min_p = json["min_p"].asFloat();
if (json.isMember("tfs_z"))
tfs_z = json["tfs_z"].asFloat();
if (json.isMember("typ_p"))
typ_p = json["typ_p"].asFloat();
if (json.isMember("repeat_last_n"))
repeat_last_n = json["repeat_last_n"].asInt();
if (json.isMember("repeat_penalty"))
repeat_penalty = json["repeat_penalty"].asFloat();
if (json.isMember("mirostat"))
mirostat = json["mirostat"].asBool();
if (json.isMember("mirostat_tau"))
mirostat_tau = json["mirostat_tau"].asFloat();
if (json.isMember("mirostat_eta"))
mirostat_eta = json["mirostat_eta"].asFloat();
if (json.isMember("penalize_nl"))
penalize_nl = json["penalize_nl"].asBool();
if (json.isMember("ignore_eos"))
ignore_eos = json["ignore_eos"].asBool();
if (json.isMember("n_probs"))
n_probs = json["n_probs"].asInt();
if (json.isMember("min_keep"))
min_keep = json["min_keep"].asInt();
if (json.isMember("ngl"))
ngl = json["ngl"].asInt();
if (json.isMember("ctx_len"))
ctx_len = json["ctx_len"].asInt();
if (json.isMember("engine"))
engine = json["engine"].asString();
if (json.isMember("prompt_template"))
prompt_template = json["prompt_template"].asString();
if (json.isMember("system_template"))
system_template = json["system_template"].asString();
if (json.isMember("user_template"))
user_template = json["user_template"].asString();
if (json.isMember("ai_template"))
ai_template = json["ai_template"].asString();
if (json.isMember("os"))
os = json["os"].asString();
if (json.isMember("gpu_arch"))
gpu_arch = json["gpu_arch"].asString();
if (json.isMember("quantization_method"))
quantization_method = json["quantization_method"].asString();
if (json.isMember("precision"))
precision = json["precision"].asString();

if (json.isMember("files") && json["files"].isArray()) {
files.clear();
for (const auto& file : json["files"]) {
files.push_back(file.asString());
}
}

if (json.isMember("created"))
created = json["created"].asUInt64();
if (json.isMember("object"))
object = json["object"].asString();
if (json.isMember("owned_by"))
owned_by = json["owned_by"].asString();
if (json.isMember("text_model"))
text_model = json["text_model"].asBool();

if (engine == "cortex.tensorrt-llm") {
if (json.isMember("trtllm_version"))
trtllm_version = json["trtllm_version"].asString();
if (json.isMember("tp"))
tp = json["tp"].asInt();
}
}
Json::Value ToJson() const {
Json::Value obj;

Expand Down
79 changes: 75 additions & 4 deletions engine/controllers/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "commands/model_pull_cmd.h"
#include "commands/model_start_cmd.h"
#include "commands/model_stop_cmd.h"
#include "commands/model_upd_cmd.h"
#include "commands/run_cmd.h"
#include "commands/server_start_cmd.h"
#include "commands/server_stop_cmd.h"
Expand Down Expand Up @@ -256,10 +257,8 @@ void CommandLineParser::SetupModelCommands() {
commands::ModelAliasCmd mdc;
mdc.Exec(cml_data_.model_id, cml_data_.model_alias);
});

auto model_update_cmd =
models_cmd->add_subcommand("update", "Update configuration of a model");
model_update_cmd->group(kSubcommands);
// Model update parameters comment
ModelUpdate(models_cmd);

std::string model_path;
auto model_import_cmd = models_cmd->add_subcommand(
Expand Down Expand Up @@ -373,6 +372,12 @@ void CommandLineParser::SetupSystemCommands() {
update_cmd->group(kSystemGroup);
update_cmd->add_option("-v", cml_data_.cortex_version, "");
update_cmd->callback([this] {
#if !defined(_WIN32)
if (getuid()) {
CLI_LOG("Error: Not root user. Please run with sudo.");
return;
}
#endif
commands::CortexUpdCmd cuc;
cuc.Exec(cml_data_.cortex_version);
cml_data_.check_upd = false;
Expand Down Expand Up @@ -442,3 +447,69 @@ void CommandLineParser::EngineGet(CLI::App* parent) {
[engine_name] { commands::EngineGetCmd().Exec(engine_name); });
}
}

void CommandLineParser::ModelUpdate(CLI::App* parent) {
auto model_update_cmd =
parent->add_subcommand("update", "Update configuration of a model");
model_update_cmd->group(kSubcommands);
model_update_cmd->add_option("--model_id", cml_data_.model_id, "Model ID")
->required();

// Add options dynamically
std::vector<std::string> option_names = {"name",
"model",
"version",
"stop",
"top_p",
"temperature",
"frequency_penalty",
"presence_penalty",
"max_tokens",
"stream",
"ngl",
"ctx_len",
"engine",
"prompt_template",
"system_template",
"user_template",
"ai_template",
"os",
"gpu_arch",
"quantization_method",
"precision",
"tp",
"trtllm_version",
"text_model",
"files",
"created",
"object",
"owned_by",
"seed",
"dynatemp_range",
"dynatemp_exponent",
"top_k",
"min_p",
"tfs_z",
"typ_p",
"repeat_last_n",
"repeat_penalty",
"mirostat",
"mirostat_tau",
"mirostat_eta",
"penalize_nl",
"ignore_eos",
"n_probs",
"min_keep",
"grammar"};

for (const auto& option_name : option_names) {
model_update_cmd->add_option("--" + option_name,
cml_data_.model_update_options[option_name],
option_name);
}

model_update_cmd->callback([this]() {
commands::ModelUpdCmd command(cml_data_.model_id);
command.Exec(cml_data_.model_update_options);
});
}
Loading