@@ -4745,16 +4745,6 @@ static void llm_load_hparams(
47454745
47464746 // non-transformer models do not have attention heads
47474747 if (hparams.n_head() > 0) {
4748- // sanity check for n_rot (optional)
4749- hparams.n_rot = hparams.n_embd / hparams.n_head();
4750-
4751- ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4752-
4753- if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4754- if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
4755- throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
4756- }
4757- }
47584748 // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
47594749 // gpt-j n_rot = rotary_dim
47604750
@@ -4763,6 +4753,17 @@ static void llm_load_hparams(
47634753
47644754 hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
47654755 ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
4756+
4757+ // sanity check for n_rot (optional)
4758+ hparams.n_rot = hparams.n_embd_head_k;
4759+
4760+ ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
4761+
4762+ if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
4763+ if (hparams.n_rot != hparams.n_embd_head_k) {
4764+ throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
4765+ }
4766+ }
47664767 } else {
47674768 hparams.n_rot = 0;
47684769 hparams.n_embd_head_k = 0;
@@ -11633,7 +11634,7 @@ struct llm_build_context {
1163311634
1163411635 Qcur = ggml_rope_ext(
1163511636 ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11636- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11637+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1163711638 ext_factor, attn_factor, beta_fast, beta_slow);
1163811639 cb(Qcur, "Qcur", il);
1163911640
@@ -11642,7 +11643,7 @@ struct llm_build_context {
1164211643
1164311644 Kcur = ggml_rope_ext(
1164411645 ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11645- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11646+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1164611647 ext_factor, attn_factor, beta_fast, beta_slow);
1164711648 cb(Kcur, "Kcur", il);
1164811649
@@ -11746,7 +11747,7 @@ struct llm_build_context {
1174611747
1174711748 Qcur = ggml_rope_ext(
1174811749 ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
11749- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11750+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1175011751 ext_factor, attn_factor, beta_fast, beta_slow);
1175111752 cb(Qcur, "Qcur", il);
1175211753
@@ -11755,7 +11756,7 @@ struct llm_build_context {
1175511756
1175611757 Kcur = ggml_rope_ext(
1175711758 ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
11758- n_embd_head_k , rope_type, n_ctx_orig, freq_base, freq_scale,
11759+ n_rot , rope_type, n_ctx_orig, freq_base, freq_scale,
1175911760 ext_factor, attn_factor, beta_fast, beta_slow);
1176011761 cb(Kcur, "Kcur", il);
1176111762
0 commit comments