@@ -100,7 +100,7 @@ const char* unused_tensors[] = {
100100 " model_ema.diffusion_model" ,
101101 " embedding_manager" ,
102102 " denoiser.sigmas" ,
103- " text_encoders.t5xxl.transformer.encoder.embed_tokens.weight" , // only used during training
103+ " text_encoders.t5xxl.transformer.encoder.embed_tokens.weight" , // only used during training
104104};
105105
106106bool is_unused_tensor (std::string name) {
@@ -1169,7 +1169,6 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
             n_dims = 1;
         }

-
         TensorStorage tensor_storage(prefix + name, type, ne, n_dims, file_index, ST_HEADER_SIZE_LEN + header_size_ + begin);
         tensor_storage.reverse_ne();

@@ -1914,7 +1913,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
     };
     int tensor_count = 0;
     int64_t t1       = ggml_time_ms();
-    bool partial = false;
+    bool partial     = false;
     for (auto& tensor_storage : processed_tensor_storages) {
         if (tensor_storage.file_index != file_index) {
             ++tensor_count;
@@ -1997,9 +1996,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
             }
         }
         size_t tensor_max = processed_tensor_storages.size();
-        int64_t t2 = ggml_time_ms();
+        int64_t t2        = ggml_time_ms();
         pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f);
-        t1 = t2;
+        t1      = t2;
         partial = tensor_count != tensor_max;
     }

@@ -2088,6 +2087,41 @@ bool ModelLoader::load_tensors(std::map<std::string, struct ggml_tensor*>& tenso
     return true;
 }

+std::vector<std::pair<std::string, ggml_type>> parse_tensor_type_rules(const std::string& tensor_type_rules) {
+    std::vector<std::pair<std::string, ggml_type>> result;
+    for (const auto& item : splitString(tensor_type_rules, ',')) {
+        if (item.size() == 0)
+            continue;
+        std::string::size_type pos = item.find('=');
+        if (pos == std::string::npos) {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+            continue;
+        }
+        std::string tensor_pattern = item.substr(0, pos);
+        std::string type_name      = item.substr(pos + 1);
+
+        ggml_type tensor_type = GGML_TYPE_COUNT;
+
+        if (type_name == "f32") {
+            tensor_type = GGML_TYPE_F32;
+        } else {
+            for (size_t i = 0; i < SD_TYPE_COUNT; i++) {
+                auto trait = ggml_get_type_traits((ggml_type)i);
+                if (trait->to_float && trait->type_size && type_name == trait->type_name) {
+                    tensor_type = (ggml_type)i;
+                }
+            }
+        }
+
+        if (tensor_type != GGML_TYPE_COUNT) {
+            result.emplace_back(tensor_pattern, tensor_type);
+        } else {
+            LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str());
+        }
+    }
+    return result;
+}
+
 bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) {
     const std::string& name = tensor_storage.name;
     if (type != GGML_TYPE_COUNT) {
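
A quick illustration of the rule format the new parse_tensor_type_rules accepts (the rule string below is hypothetical; the type names come from ggml's type traits): each comma-separated entry is "<regex>=<type>", and malformed or unknown entries are skipped with a warning rather than aborting the conversion.

    // Sketch of expected behavior against the patched model.cpp above:
    auto rules = parse_tensor_type_rules("attn=q8_0,ffn_\\w+=f16,oops,norm=q9_9");
    // -> { {"attn", GGML_TYPE_Q8_0}, {"ffn_\\w+", GGML_TYPE_F16} }
    // "oops" (no '=') and "norm=q9_9" (unknown type name) each trigger
    // LOG_WARN("ignoring invalid quant override ...") and are dropped.
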
@@ -2119,7 +2153,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
     return false;
 }

-bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) {
+bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) {
     auto backend    = ggml_backend_cpu_init();
     size_t mem_size = 1 * 1024 * 1024;  // for padding
     mem_size += tensor_storages.size() * ggml_tensor_overhead();
@@ -2129,12 +2163,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type

     gguf_context* gguf_ctx = gguf_init_empty();

+    auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str);
+
     auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool {
         const std::string& name = tensor_storage.name;
+        ggml_type tensor_type   = tensor_storage.type;
+        ggml_type dst_type      = type;

-        ggml_type tensor_type = tensor_storage.type;
-        if (tensor_should_be_converted(tensor_storage, type)) {
-            tensor_type = type;
+        for (const auto& tensor_type_rule : tensor_type_rules) {
+            std::regex pattern(tensor_type_rule.first);
+            if (std::regex_search(name, pattern)) {
+                dst_type = tensor_type_rule.second;
+                break;
+            }
+        }
+
+        if (tensor_should_be_converted(tensor_storage, dst_type)) {
+            tensor_type = dst_type;
         }

         ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne);
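
Note the matching semantics introduced above: each rule's pattern is applied with std::regex_search, so it matches anywhere in the tensor name, and the first matching rule wins because of the break. The override is still gated by tensor_should_be_converted, so tensors the loader insists on keeping at a given precision are unaffected. A minimal standalone sketch of the same first-match-wins lookup (pick_dst_type is a hypothetical helper, not part of the commit):

    #include <regex>
    #include <string>
    #include <utility>
    #include <vector>
    #include "ggml.h"

    // Returns the per-tensor override if any rule matches, else the global type.
    static ggml_type pick_dst_type(const std::string& name, ggml_type global_type,
                                   const std::vector<std::pair<std::string, ggml_type>>& rules) {
        for (const auto& rule : rules) {
            if (std::regex_search(name, std::regex(rule.first))) {
                return rule.second;  // first match wins, as in on_new_tensor_cb
            }
        }
        return global_type;
    }
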
@@ -2193,7 +2238,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type)
     return mem_size;
 }

-bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) {
+bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) {
     ModelLoader model_loader;

     if (!model_loader.init_from_file(input_path)) {
@@ -2207,6 +2252,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa
             return false;
         }
     }
-    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type);
+    bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules);
     return success;
 }
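
End-to-end, the new argument reaches the public convert() entry point. A hedged usage sketch (the paths and rule string are made up; SD_TYPE_Q4_0 is one of the sd_type_t values mirroring ggml_type):

    // Quantize the whole model to q4_0, but keep attention weights at q8_0
    // and normalization weights at full precision.
    bool ok = convert("sd_v1-5.safetensors", "", "sd_v1-5-q4_0.gguf",
                      SD_TYPE_Q4_0, "attn=q8_0,norm=f32");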