@@ -2839,8 +2839,8 @@ static void llm_load_tensors(
         auto & layer = model.layers[i];

         layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
-        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split);
-        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
+        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);

         layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);

@@ -5368,7 +5368,7 @@ static struct ggml_cgraph * llm_build_mpt(
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
     const int64_t n_head      = hparams.n_head;
-    const int64_t n_head_kv   = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA
+    const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_embd_head = hparams.n_embd_head();
     const int64_t n_embd_gqa  = hparams.n_embd_gqa();

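For context (not part of the commit): the new wqkv shape follows from how llama.cpp derives n_embd_gqa as n_embd_head * n_head_kv. The Q projection keeps the full n_embd width while K and V shrink to n_embd_gqa each, so the fused QKV weight has n_embd + 2*n_embd_gqa output columns; with n_head_kv == n_head this collapses back to the previous 3*n_embd layout. A minimal standalone sketch of that arithmetic, using hypothetical example head counts rather than values from any real checkpoint:

// sketch.cpp - arithmetic behind the new wqkv shape (illustrative values only)
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical example hyperparameters, not taken from the commit.
    const int64_t n_embd      = 4096;                    // model width
    const int64_t n_head      = 32;                      // query heads
    const int64_t n_head_kv   = 8;                       // key/value heads (== n_head when there is no MQA/GQA)
    const int64_t n_embd_head = n_embd / n_head;         // width of one head
    const int64_t n_embd_gqa  = n_embd_head * n_head_kv; // combined K (or V) projection width

    // Q keeps the full n_embd width; K and V are n_embd_gqa each,
    // so the fused QKV weight has n_embd + 2*n_embd_gqa output columns.
    const int64_t wqkv_cols = n_embd + 2 * n_embd_gqa;

    // When n_head_kv == n_head, n_embd_gqa == n_embd and this reduces to the old 3*n_embd layout.
    std::printf("wqkv columns: %lld (vs. 3*n_embd = %lld)\n",
                (long long) wqkv_cols, (long long)(3 * n_embd));
    return 0;
}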