@@ -12,6 +12,8 @@ struct LoraModel : public GGMLRunner {
     ModelLoader model_loader;
     bool load_failed = false;
     bool applied = false;
+    std::vector<int> zero_index_vec = {0};
+    ggml_tensor* zero_index = NULL;
 
     LoraModel(ggml_backend_t backend,
               ggml_type wtype,
@@ -68,9 +70,19 @@ struct LoraModel : public GGMLRunner {
         return true;
     }
 
+    ggml_tensor* to_f32(ggml_context* ctx, ggml_tensor* a) {
+        auto out = ggml_reshape_1d(ctx, a, ggml_nelements(a));
+        out = ggml_get_rows(ctx, out, zero_index);
+        out = ggml_reshape(ctx, out, a);
+        return out;
+    }
+
     struct ggml_cgraph* build_lora_graph(std::map<std::string, struct ggml_tensor*> model_tensors) {
         struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false);
 
+        zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1);
+        set_backend_tensor_data(zero_index, zero_index_vec.data());
+
         std::set<std::string> applied_lora_tensors;
         for (auto it : model_tensors) {
             std::string k_tensor = it.first;
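
Note on the to_f32 helper added above: ggml's graph API has no dedicated dequantize op, so the patch flattens the tensor into a single logical row and gathers that row with ggml_get_rows, which always emits F32 output, then reshapes back. The standalone sketch below is hypothetical and not part of this commit; it reproduces the trick on the CPU backend, with F16 standing in for a quantized type so the toy sizes need not respect quantization block boundaries, and it assumes a ggml revision where ggml_graph_compute_with_ctx is exported by ggml.h.

// Hypothetical standalone sketch of the dequantize-via-get_rows trick.
#include "ggml.h"
#include <cstdint>
#include <cstdio>

int main() {
    struct ggml_init_params params = {16 * 1024 * 1024, NULL, false};
    struct ggml_context* ctx = ggml_init(params);

    // Source tensor; F16 stands in for a quantized weight here.
    struct ggml_tensor* w = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, 4, 2);
    for (int i = 0; i < 8; i++) {
        ((ggml_fp16_t*)w->data)[i] = ggml_fp32_to_fp16((float)i);
    }

    // Single-element I32 index tensor holding row id 0.
    struct ggml_tensor* zero = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1);
    ((int32_t*)zero->data)[0] = 0;

    // Flatten to one row, gather it (ggml_get_rows emits F32), restore shape.
    struct ggml_tensor* flat = ggml_reshape_1d(ctx, w, ggml_nelements(w));
    struct ggml_tensor* f32  = ggml_get_rows(ctx, flat, zero);
    f32 = ggml_reshape(ctx, f32, w);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, f32);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("f32[5] = %.1f\n", ((float*)f32->data)[5]);  // expect 5.0
    ggml_free(ctx);
    return 0;
}
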
@@ -141,15 +153,16 @@ struct LoraModel : public GGMLRunner {
             GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight));
             updown = ggml_scale_inplace(compute_ctx, updown, scale_value);
             ggml_tensor* final_weight;
-            // if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
-            //     final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, weight->n_dims, weight->ne);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, weight, final_weight);
-            //     final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
-            //     final_weight = ggml_cpy_inplace(compute_ctx, final_weight, weight);
-            // } else {
-            //     final_weight = ggml_add_inplace(compute_ctx, weight, updown);
-            // }
-            final_weight = ggml_add_inplace(compute_ctx, weight, updown);  // apply directly
+            if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) {
+                // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne);
+                // final_weight = ggml_cpy(compute_ctx, weight, final_weight);
+                final_weight = to_f32(compute_ctx, weight);
+                final_weight = ggml_add_inplace(compute_ctx, final_weight, updown);
+                final_weight = ggml_cpy(compute_ctx, final_weight, weight);
+            } else {
+                final_weight = ggml_add_inplace(compute_ctx, weight, updown);
+            }
+            // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly
             ggml_build_forward_expand(gf, final_weight);
         }
 
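
The new quantized path in the last hunk leans on a second ggml property: ggml_cpy into a destination of a different type converts on store, so copying the F32 sum back into weight re-quantizes it in place while preserving the original tensor's type and layout. One quantization round-trip per LoRA application is the cost this approach accepts in exchange for never keeping an F32 master copy of the weight. A minimal sketch of that behavior, with hypothetical identifiers, the CPU backend, and one Q8_0 block of 32 values:

// Hypothetical demo: ggml_cpy quantizes on store when dst is quantized.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {16 * 1024 * 1024, NULL, false};
    struct ggml_context* ctx = ggml_init(params);

    // One Q8_0 block covers 32 elements, so 32 keeps the shapes legal.
    struct ggml_tensor* sum = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 32);
    struct ggml_tensor* wq  = ggml_new_tensor_1d(ctx, GGML_TYPE_Q8_0, 32);
    for (int i = 0; i < 32; i++) {
        ((float*)sum->data)[i] = 0.5f * i;  // pretend this is weight + updown
    }

    // ggml_cpy converts F32 -> Q8_0 while writing into wq's buffer.
    struct ggml_tensor* out = ggml_cpy(ctx, sum, wq);

    struct ggml_cgraph* gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 1);

    printf("stored %zu bytes of Q8_0 data\n", ggml_nbytes(wq));
    ggml_free(ctx);
    return 0;
}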