@@ -4104,22 +4104,20 @@ static void llm_build_k_shift(
41044104 struct ggml_cgraph * graph,
41054105 llm_rope_type type,
41064106 int64_t n_ctx,
4107- int n_rot,
41084107 float freq_base,
41094108 float freq_scale,
41104109 const llm_build_cb & cb) {
41114110 const int64_t n_layer = hparams.n_layer ;
41124111 const int64_t n_head_kv = hparams.n_head_kv ;
41134112 const int64_t n_embd_head_k = hparams.n_embd_head_k ;
41144113 const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa ();
4114+ const int32_t n_rot = hparams.n_rot ;
41154115 const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx ;
41164116 const float ext_factor = cparams.yarn_ext_factor ;
41174117 const float attn_factor = cparams.yarn_attn_factor ;
41184118 const float beta_fast = cparams.yarn_beta_fast ;
41194119 const float beta_slow = cparams.yarn_beta_slow ;
41204120
4121- GGML_ASSERT (n_embd_head_k % n_rot == 0 );
4122-
41234121 struct ggml_tensor * K_shift = ggml_new_tensor_1d (ctx, GGML_TYPE_I32, n_ctx);
41244122 cb (K_shift, " K_shift" , -1 );
41254123
@@ -4523,7 +4521,7 @@ struct llm_build_context {
45234521
45244522 // shift the entire K-cache if needed
45254523 if (do_rope_shift) {
4526- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4524+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
45274525 }
45284526
45294527 for (int il = 0 ; il < n_layer; ++il) {
@@ -4561,14 +4559,14 @@ struct llm_build_context {
45614559
45624560 Qcur = ggml_rope_custom (
45634561 ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4564- n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4562+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
45654563 ext_factor, attn_factor, beta_fast, beta_slow
45664564 );
45674565 cb (Qcur, " Qcur" , il);
45684566
45694567 Kcur = ggml_rope_custom (
45704568 ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4571- n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4569+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
45724570 ext_factor, attn_factor, beta_fast, beta_slow
45734571 );
45744572 cb (Kcur, " Kcur" , il);
@@ -4691,6 +4689,7 @@ struct llm_build_context {
46914689
46924690 const int64_t n_embd_head = hparams.n_embd_head_v ;
46934691 GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4692+ GGML_ASSERT (n_embd_head == hparams.n_rot );
46944693
46954694 struct ggml_tensor * cur;
46964695 struct ggml_tensor * inpL;
@@ -4708,7 +4707,7 @@ struct llm_build_context {
47084707
47094708 // shift the entire K-cache if needed
47104709 if (do_rope_shift) {
4711- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4710+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
47124711 }
47134712
47144713 for (int il = 0 ; il < n_layer; ++il) {
@@ -4734,12 +4733,12 @@ struct llm_build_context {
47344733 case MODEL_7B:
47354734 Qcur = ggml_rope_custom (
47364735 ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
4737- n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4736+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
47384737 ext_factor, attn_factor, beta_fast, beta_slow
47394738 );
47404739 Kcur = ggml_rope_custom (
47414740 ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
4742- n_embd_head , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
4741+ hparams. n_rot , 0 , 0 , n_orig_ctx, freq_base, freq_scale,
47434742 ext_factor, attn_factor, beta_fast, beta_slow
47444743 );
47454744 break ;
@@ -4812,6 +4811,7 @@ struct llm_build_context {
48124811 const int64_t n_embd_head = hparams.n_embd_head_v ;
48134812 const int64_t n_embd_gqa = hparams.n_embd_v_gqa ();
48144813 GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
4814+ GGML_ASSERT (n_embd_head == hparams.n_rot );
48154815
48164816 struct ggml_tensor * cur;
48174817 struct ggml_tensor * inpL;
@@ -4829,7 +4829,7 @@ struct llm_build_context {
48294829
48304830 // shift the entire K-cache if needed
48314831 if (do_rope_shift) {
4832- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
4832+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
48334833 }
48344834
48354835 for (int il = 0 ; il < n_layer; ++il) {
@@ -4870,13 +4870,13 @@ struct llm_build_context {
48704870
48714871 // using mode = 2 for neox mode
48724872 Qcur = ggml_rope_custom (
4873- ctx0, Qcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
4873+ ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
48744874 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
48754875 );
48764876 cb (Qcur, " Qcur" , il);
48774877
48784878 Kcur = ggml_rope_custom (
4879- ctx0, Kcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
4879+ ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
48804880 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
48814881 );
48824882 cb (Kcur, " Kcur" , il);
@@ -5033,9 +5033,8 @@ struct llm_build_context {
50335033 struct ggml_cgraph * gf = ggml_new_graph_custom (ctx0, LLAMA_MAX_NODES, false );
50345034
50355035 const int64_t n_embd_head = hparams.n_embd_head_v ;
5036- GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5037-
5038- const int64_t n_rot = n_embd_head_k / 2 ;
5036+ GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5037+ GGML_ASSERT (n_embd_head/2 == hparams.n_rot );
50395038
50405039 struct ggml_tensor * cur;
50415040 struct ggml_tensor * inpL;
@@ -5052,7 +5051,7 @@ struct llm_build_context {
50525051 cb (KQ_mask, " KQ_mask" , -1 );
50535052
50545053 if (do_rope_shift) {
5055- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5054+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
50565055 }
50575056
50585057 for (int il = 0 ; il < n_layer; ++il) {
@@ -5112,15 +5111,15 @@ struct llm_build_context {
51125111
51135112 // RoPE the first n_rot of q/k, pass the other half, and concat.
51145113 struct ggml_tensor * qrot = ggml_view_3d (
5115- ctx0, tmpq, n_rot, n_head, n_tokens,
5114+ ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
51165115 ggml_element_size (tmpq) * n_embd_head,
51175116 ggml_element_size (tmpq) * n_embd_head * n_head,
51185117 0
51195118 );
51205119 cb (qrot, " qrot" , il);
51215120
51225121 struct ggml_tensor * krot = ggml_view_3d (
5123- ctx0, tmpk, n_rot, n_head, n_tokens,
5122+ ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
51245123 ggml_element_size (tmpk) * n_embd_head,
51255124 ggml_element_size (tmpk) * n_embd_head * n_head,
51265125 0
@@ -5129,29 +5128,29 @@ struct llm_build_context {
51295128
51305129 // get the second half of tmpq, i.e. tmpq[n_rot:, :, :]
51315130 struct ggml_tensor * qpass = ggml_view_3d (
5132- ctx0, tmpq, n_rot, n_head, n_tokens,
5131+ ctx0, tmpq, hparams. n_rot , n_head, n_tokens,
51335132 ggml_element_size (tmpq) * n_embd_head,
51345133 ggml_element_size (tmpq) * n_embd_head * n_head,
5135- ggml_element_size (tmpq) * n_rot
5134+ ggml_element_size (tmpq) * hparams. n_rot
51365135 );
51375136 cb (qpass, " qpass" , il);
51385137
51395138 struct ggml_tensor * kpass = ggml_view_3d (
5140- ctx0, tmpk, n_rot, n_head, n_tokens,
5139+ ctx0, tmpk, hparams. n_rot , n_head, n_tokens,
51415140 ggml_element_size (tmpk) * n_embd_head,
51425141 ggml_element_size (tmpk) * n_embd_head * n_head,
5143- ggml_element_size (tmpk) * n_rot
5142+ ggml_element_size (tmpk) * hparams. n_rot
51445143 );
51455144 cb (kpass, " kpass" , il);
51465145
51475146 struct ggml_tensor * qrotated = ggml_rope_custom (
5148- ctx0, qrot, inp_pos, n_rot, 2 , 0 , n_orig_ctx,
5147+ ctx0, qrot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
51495148 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
51505149 );
51515150 cb (qrotated, " qrotated" , il);
51525151
51535152 struct ggml_tensor * krotated = ggml_rope_custom (
5154- ctx0, krot, inp_pos, n_rot, 2 , 0 , n_orig_ctx,
5153+ ctx0, krot, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
51555154 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
51565155 );
51575156 cb (krotated, " krotated" , il);
@@ -5531,6 +5530,7 @@ struct llm_build_context {
55315530
55325531 const int64_t n_embd_head = hparams.n_embd_head_v ;
55335532 GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5533+ GGML_ASSERT (n_embd_head == hparams.n_rot );
55345534
55355535 struct ggml_tensor * cur;
55365536 struct ggml_tensor * inpL;
@@ -5548,7 +5548,7 @@ struct llm_build_context {
55485548
55495549 // shift the entire K-cache if needed
55505550 if (do_rope_shift) {
5551- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams. n_rot , freq_base, freq_scale, cb);
5551+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
55525552 }
55535553
55545554 for (int il = 0 ; il < n_layer; ++il) {
@@ -5661,7 +5661,7 @@ struct llm_build_context {
56615661
56625662 // shift the entire K-cache if needed
56635663 if (do_rope_shift) {
5664- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5664+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
56655665 }
56665666
56675667 for (int il = 0 ; il < n_layer; ++il) {
@@ -5693,13 +5693,13 @@ struct llm_build_context {
56935693
56945694 // using mode = 2 for neox mode
56955695 Qcur = ggml_rope_custom (
5696- ctx0, Qcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
5696+ ctx0, Qcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
56975697 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
56985698 );
56995699 cb (Qcur, " Qcur" , il);
57005700
57015701 Kcur = ggml_rope_custom (
5702- ctx0, Kcur, inp_pos, n_embd_head , 2 , 0 , n_orig_ctx,
5702+ ctx0, Kcur, inp_pos, hparams. n_rot , 2 , 0 , n_orig_ctx,
57035703 freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
57045704 );
57055705 cb (Kcur, " Kcur" , il);
@@ -5778,7 +5778,7 @@ struct llm_build_context {
57785778
57795779 // shift the entire K-cache if needed
57805780 if (do_rope_shift) {
5781- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5781+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
57825782 }
57835783
57845784 for (int il = 0 ; il < n_layer; ++il) {
@@ -5874,6 +5874,7 @@ struct llm_build_context {
58745874
58755875 const int64_t n_embd_head = hparams.n_embd_head_v ;
58765876 GGML_ASSERT (n_embd_head == hparams.n_embd_head_k );
5877+ GGML_ASSERT (n_embd_head == hparams.n_rot );
58775878
58785879 struct ggml_tensor * cur;
58795880 struct ggml_tensor * inpL;
@@ -5891,7 +5892,7 @@ struct llm_build_context {
58915892
58925893 // shift the entire K-cache if needed
58935894 if (do_rope_shift) {
5894- llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
5895+ llm_build_k_shift (ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
58955896 }
58965897
58975898 for (int il = 0 ; il < n_layer; ++il) {
@@ -5917,13 +5918,13 @@ struct llm_build_context {
59175918 cb (Vcur, " Vcur" , il);
59185919
59195920 Qcur = ggml_rope_custom (
5920- ctx0, ggml_reshape_3d (ctx0, Qcur, n_embd_head , n_head, n_tokens), inp_pos,
5921+ ctx0, ggml_reshape_3d (ctx0, Qcur, hparams. n_rot , n_head, n_tokens), inp_pos,
59215922 n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
59225923 ext_factor, attn_factor, beta_fast, beta_slow);
59235924 cb (Qcur, " Qcur" , il);
59245925
59255926 Kcur = ggml_rope_custom (
5926- ctx0, ggml_reshape_3d (ctx0, Kcur, n_embd_head , n_head_kv, n_tokens), inp_pos,
5927+ ctx0, ggml_reshape_3d (ctx0, Kcur, hparams. n_rot , n_head_kv, n_tokens), inp_pos,
59275928 n_embd_head, 2 , 0 , n_orig_ctx, freq_base, freq_scale,
59285929 ext_factor, attn_factor, beta_fast, beta_slow);
59295930 cb (Kcur, " Kcur" , il);
0 commit comments