@@ -51,7 +51,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
5151
5252 std::string trigger_word = " img" ; // should be user settable
5353 std::string embd_dir;
54- int32_t num_custom_embeddings = 0 ;
54+ int32_t num_custom_embeddings = 0 ;
55+ int32_t num_custom_embeddings_2 = 0 ;
5556 std::vector<uint8_t > token_embed_custom;
5657 std::vector<std::string> readed_embeddings;
5758
@@ -131,28 +132,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
131132 params.no_alloc = false ;
132133 struct ggml_context * embd_ctx = ggml_init (params);
133134 struct ggml_tensor * embd = NULL ;
134- int64_t hidden_size = text_model-> model . hidden_size ;
135+ struct ggml_tensor * embd2 = NULL ;
135136 auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) {
136- if (tensor_storage.ne [0 ] != hidden_size) {
137- LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], hidden_size);
138- return false ;
137+ if (tensor_storage.ne [0 ] != text_model->model .hidden_size ) {
138+ if (text_model2) {
139+ if (tensor_storage.ne [0 ] == text_model2->model .hidden_size ) {
140+ embd2 = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , text_model2->model .hidden_size , tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
141+ *dst_tensor = embd2;
142+ } else {
143+ LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i or %i" , tensor_storage.ne [0 ], text_model->model .hidden_size , text_model2->model .hidden_size );
144+ return false ;
145+ }
146+ } else {
147+ LOG_DEBUG (" embedding wrong hidden size, got %i, expected %i" , tensor_storage.ne [0 ], text_model->model .hidden_size );
148+ return false ;
149+ }
150+ } else {
151+ embd = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , text_model->model .hidden_size , tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
152+ *dst_tensor = embd;
139153 }
140- embd = ggml_new_tensor_2d (embd_ctx, tensor_storage.type , hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne [1 ] : 1 );
141- *dst_tensor = embd;
142154 return true ;
143155 };
144156 model_loader.load_tensors (on_load, NULL );
145157 readed_embeddings.push_back (embd_name);
146- token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd));
147- memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings * hidden_size * ggml_type_size (embd->type )),
148- embd->data ,
149- ggml_nbytes (embd));
150- for (int i = 0 ; i < embd->ne [1 ]; i++) {
151- bpe_tokens.push_back (text_model->model .vocab_size + num_custom_embeddings);
152- // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
153- num_custom_embeddings++;
158+ if (embd) {
159+ int64_t hidden_size = text_model->model .hidden_size ;
160+ token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd));
161+ memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings * hidden_size * ggml_type_size (embd->type )),
162+ embd->data ,
163+ ggml_nbytes (embd));
164+ for (int i = 0 ; i < embd->ne [1 ]; i++) {
165+ bpe_tokens.push_back (text_model->model .vocab_size + num_custom_embeddings);
166+ // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
167+ num_custom_embeddings++;
168+ }
169+ LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i" , embd_name.c_str (), num_custom_embeddings);
170+ }
171+ if (embd2) {
172+ int64_t hidden_size = text_model2->model .hidden_size ;
173+ token_embed_custom.resize (token_embed_custom.size () + ggml_nbytes (embd2));
174+ memcpy ((void *)(token_embed_custom.data () + num_custom_embeddings_2 * hidden_size * ggml_type_size (embd2->type )),
175+ embd2->data ,
176+ ggml_nbytes (embd2));
177+ for (int i = 0 ; i < embd2->ne [1 ]; i++) {
178+ bpe_tokens.push_back (text_model2->model .vocab_size + num_custom_embeddings_2);
179+ // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings);
180+ num_custom_embeddings_2++;
181+ }
182+ LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i (text model 2)" , embd_name.c_str (), num_custom_embeddings_2);
154183 }
155- LOG_DEBUG (" embedding '%s' applied, custom embeddings: %i" , embd_name.c_str (), num_custom_embeddings);
156184 return true ;
157185 }
158186
0 commit comments