5 changes: 1 addition & 4 deletions examples/cli/main.cpp
@@ -1163,10 +1163,6 @@ void parse_args(int argc, const char** argv, SDParams& params) {
exit(1);
}

if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) {
fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n");
}

if (params.mode == VID_GEN && params.video_frames <= 0) {
fprintf(stderr, "warning: --video-frames must be at least 1\n");
exit(1);
@@ -1643,6 +1639,7 @@ int main(int argc, const char* argv[]) {
params.lora_model_dir.c_str(),
params.embedding_dir.c_str(),
params.photo_maker_path.c_str(),
params.tensor_type_rules.c_str(),
vae_decode_only,
true,
params.n_threads,
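With the conversion-only warning removed, the rule string given on the command line now matters for generation runs as well. A small illustration of how it travels from the CLI into the context parameters; the rule contents are invented, and SDParams::tensor_type_rules is the std::string populated by parse_args():

// e.g.  --tensor-type-rules "attn=q8_0,ffn=f16"        (illustrative rules)
// main() forwards the raw string alongside the other per-model settings:
const char* tensor_type_rules = params.tensor_type_rules.c_str();  // "" when the flag is absent
// ...handed into the context parameters next to photo_maker_path, etc.,
// so the loader can apply the rules at load time instead of only during conversion.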
87 changes: 48 additions & 39 deletions model.cpp
@@ -1877,15 +1877,59 @@ std::map<ggml_type, uint32_t> ModelLoader::get_vae_wtype_stat() {
return wtype_stat;
}

void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
static std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
for (const auto& item : split_string(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');
if (pos == std::string::npos) {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
continue;
}
std::string tensor_pattern = item.substr(0, pos);
std::string type_name = item.substr(pos + 1);

ggml_type tensor_type = GGML_TYPE_COUNT;

if (type_name == "f32") {
tensor_type = GGML_TYPE_F32;
} else {
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
auto trait = ggml_get_type_traits((ggml_type)i);
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
tensor_type = (ggml_type)i;
}
}
}

if (tensor_type != GGML_TYPE_COUNT) {
result.emplace_back(tensor_pattern, tensor_type);
} else {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
}
}
return result;
}

void ModelLoader::set_wtype_override(ggml_type wtype, std::string tensor_type_rules) {
auto map_rules = parse_tensor_type_rules(tensor_type_rules);
for (auto& [name, tensor_storage] : tensor_storage_map) {
if (!starts_with(name, prefix)) {
ggml_type dst_type = wtype;
for (const auto& tensor_type_rule : map_rules) {
std::regex pattern(tensor_type_rule.first);
if (std::regex_search(name, pattern)) {
dst_type = tensor_type_rule.second;
break;
}
}
if (dst_type == GGML_TYPE_COUNT) {
continue;
}
if (!tensor_should_be_converted(tensor_storage, wtype)) {
if (!tensor_should_be_converted(tensor_storage, dst_type)) {
continue;
}
tensor_storage.expected_type = wtype;
tensor_storage.expected_type = dst_type;
}
}
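Taken together, the two functions above mean a rule string is a comma-separated list of regex=type entries, matched against each tensor's full name with std::regex_search; the first matching rule wins, and unmatched tensors fall back to the global wtype (or are left untouched when that is GGML_TYPE_COUNT). A minimal, self-contained sketch of that selection logic, assuming ggml.h is available; the pattern and tensor names are purely illustrative:

#include <regex>
#include <string>
#include <utility>
#include <vector>

#include "ggml.h"  // for ggml_type / GGML_TYPE_COUNT

// Sketch of the per-tensor selection performed by set_wtype_override():
// walk the parsed rules in order and return the type of the first regex that
// matches anywhere in the tensor name; otherwise keep the fallback.
static ggml_type pick_type(const std::string& tensor_name,
                           const std::vector<std::pair<std::string, ggml_type>>& rules,
                           ggml_type fallback) {
    for (const auto& [pattern, type] : rules) {
        if (std::regex_search(tensor_name, std::regex(pattern))) {
            return type;
        }
    }
    return fallback;  // GGML_TYPE_COUNT means "leave this tensor as-is"
}

// Illustrative rules, equivalent to parsing "attn=q8_0,ffn_up=f16":
static const std::vector<std::pair<std::string, ggml_type>> example_rules = {
    {"attn", GGML_TYPE_Q8_0},
    {"ffn_up", GGML_TYPE_F16},
};
// pick_type("model.diffusion_model.blocks.0.attn.qkv.weight",
//           example_rules, GGML_TYPE_COUNT) -> GGML_TYPE_Q8_0,
// while a tensor matching neither pattern keeps its on-disk type.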

@@ -2226,41 +2270,6 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
return true;
}

std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
std::vector<std::pair<std::string, ggml_type>> result;
for (const auto& item : split_string(tensor_type_rules, ',')) {
if (item.size() == 0)
continue;
std::string::size_type pos = item.find('=');
if (pos == std::string::npos) {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
continue;
}
std::string tensor_pattern = item.substr(0, pos);
std::string type_name = item.substr(pos + 1);

ggml_type tensor_type = GGML_TYPE_COUNT;

if (type_name == "f32") {
tensor_type = GGML_TYPE_F32;
} else {
for (size_t i = 0; i < GGML_TYPE_COUNT; i++) {
auto trait = ggml_get_type_traits((ggml_type)i);
if (trait->to_float && trait->type_size && type_name == trait->type_name) {
tensor_type = (ggml_type)i;
}
}
}

if (tensor_type != GGML_TYPE_COUNT) {
result.emplace_back(tensor_pattern, tensor_type);
} else {
LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
}
}
return result;
}

bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
const std::string& name = tensor_storage.name;
if (type != GGML_TYPE_COUNT) {
2 changes: 1 addition & 1 deletion model.h
@@ -281,7 +281,7 @@ class ModelLoader {
std::map<ggml_type, uint32_t> get_diffusion_model_wtype_stat();
std::map<ggml_type, uint32_t> get_vae_wtype_stat();
String2TensorStorage& get_tensor_storage_map() { return tensor_storage_map; }
void set_wtype_override(ggml_type wtype, std::string prefix = "");
void set_wtype_override(ggml_type wtype, std::string tensor_type_rules = "");
bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, int n_threads = 0);
bool load_tensors(std::map<std::string, struct ggml_tensor*>& tensors,
std::set<std::string> ignore_tensors = {},
7 changes: 5 additions & 2 deletions stable-diffusion.cpp
@@ -286,8 +286,9 @@ class StableDiffusionGGML {
ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
if (wtype != GGML_TYPE_COUNT) {
model_loader.set_wtype_override(wtype);
std::string tensor_type_rules = SAFE_STR(sd_ctx_params->tensor_type_rules);
if (wtype != GGML_TYPE_COUNT || tensor_type_rules.size() > 0) {
model_loader.set_wtype_override(wtype, tensor_type_rules);
}

std::map<ggml_type, uint32_t> wtype_stat = model_loader.get_wtype_stat();
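The widened condition means the override now also runs when only rules are supplied and no global weight type is set. A hedged illustration of the two call shapes this enables; the types and rule contents are made up, and model_loader is the ModelLoader instance used above:

// Sketch only, not part of the change. With a global weight type, every
// convertible tensor is switched to that type:
model_loader.set_wtype_override(GGML_TYPE_Q8_0);

// With no global type but a rule string, only tensors whose names match a
// rule are overridden; for the rest, dst_type stays GGML_TYPE_COUNT inside
// set_wtype_override() and the tensor keeps its on-disk type:
model_loader.set_wtype_override(GGML_TYPE_COUNT, "attn=q8_0,ffn=f16");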
@@ -1893,6 +1894,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
"lora_model_dir: %s\n"
"embedding_dir: %s\n"
"photo_maker_path: %s\n"
"tensor_type_rules: %s\n"
"vae_decode_only: %s\n"
"free_params_immediately: %s\n"
"n_threads: %d\n"
@@ -1922,6 +1924,7 @@ char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) {
SAFE_STR(sd_ctx_params->lora_model_dir),
SAFE_STR(sd_ctx_params->embedding_dir),
SAFE_STR(sd_ctx_params->photo_maker_path),
SAFE_STR(sd_ctx_params->tensor_type_rules),
BOOL_STR(sd_ctx_params->vae_decode_only),
BOOL_STR(sd_ctx_params->free_params_immediately),
sd_ctx_params->n_threads,
1 change: 1 addition & 0 deletions stable-diffusion.h
@@ -151,6 +151,7 @@ typedef struct {
const char* lora_model_dir;
const char* embedding_dir;
const char* photo_maker_path;
const char* tensor_type_rules;
bool vae_decode_only;
bool free_params_immediately;
int n_threads;
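For library users, the new field sits alongside the other context options. A minimal sketch of filling it in, assuming only the fields visible in this diff; the rule contents and thread count are illustrative:

#include "stable-diffusion.h"

// Illustrative only: request q8_0 attention weights and f16 feed-forward
// weights at load time via the new field. Each rule is "regex=type",
// comma-separated, matched against full tensor names.
static sd_ctx_params_t make_ctx_params() {
    sd_ctx_params_t p = {};                     // zero-init, then fill in what we need
    p.tensor_type_rules = "attn=q8_0,ffn=f16";  // field added by this change
    p.vae_decode_only   = false;
    p.n_threads         = 8;
    // ... remaining paths and flags set as before, then the struct is handed
    // to context creation as usual.
    return p;
}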