@@ -1000,12 +1000,13 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers");

-        const auto n_kv = inp->mctx->get_attn()->get_n_kv();
+        const auto n_kv   = inp->mctx->get_attn()->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

         inp->self_k_idxs = mctx_cur->get_attn()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_attn()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
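The key change in this hunk is the shape of self_kq_mask: instead of one mask covering the whole ubatch, the token dimension is divided by n_seqs and the fourth dimension carries one mask slice per unique sequence. A minimal standalone sketch of that shape arithmetic follows; the sizes are hypothetical and pad() is only a stand-in for ggml's GGML_PAD rounding, so this is illustration, not code from the patch.

    // Sketch only: hypothetical sizes; pad() stands in for GGML_PAD (round up to a multiple of n).
    #include <cstdio>

    static int pad(int x, int n) { return (x + n - 1) / n * n; } // round x up to a multiple of n

    int main() {
        const int n_kv     = 256; // assumed KV cache size
        const int n_tokens = 32;  // assumed tokens in the current ubatch
        const int n_seqs   = 4;   // assumed unique sequences when n_seq_virt > 1
        const int mask_pad = 64;  // stand-in for GGML_KQ_MASK_PAD

        // previous layout: one mask for the whole ubatch
        std::printf("before: [%d, %d, 1, 1]\n", n_kv, pad(n_tokens, mask_pad));
        // new layout: the token dimension is split per sequence, dim 3 indexes the sequence
        std::printf("after : [%d, %d, 1, %d]\n", n_kv, pad(n_tokens/n_seqs, mask_pad), n_seqs);
        return 0;
    }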
@@ -1032,6 +1033,10 @@ ggml_tensor * llm_graph_context::build_attn_mha(
               float   kq_scale) const {
     const bool v_trans = v->nb[1] > v->nb[2];

+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
+    q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_seqs, n_seqs);
+
     q = ggml_permute(ctx0, q, 0, 2, 1, 3);
     k = ggml_permute(ctx0, k, 0, 2, 1, 3);
     v = ggml_permute(ctx0, v, 0, 2, 1, 3);
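The added ggml_reshape_4d is what lets the existing permutes produce a per-sequence batch layout for Q. A rough shape trace, using invented head/token counts purely for illustration, could look like this:

    // Sketch only: hypothetical dimensions; shows the shape bookkeeping, not the actual graph code.
    #include <cstdio>

    int main() {
        const int n_embd_head = 128; // assumed head dimension
        const int n_head      = 32;  // assumed number of heads
        const int n_tokens    = 32;  // assumed tokens in the ubatch
        const int n_seqs      = 4;   // assumed unique sequences (n_seq_virt > 1)

        // q as built by the callers:         [n_embd_head, n_head, n_tokens]
        // after the added ggml_reshape_4d:   [n_embd_head, n_head, n_tokens/n_seqs, n_seqs]
        // after ggml_permute(q, 0, 2, 1, 3): [n_embd_head, n_tokens/n_seqs, n_head, n_seqs]
        std::printf("Q per-sequence slab: [%d, %d, %d, %d]\n",
                    n_embd_head, n_tokens/n_seqs, n_head, n_seqs);

        // the KQ matmul treats the trailing two dimensions as batch dimensions,
        // so each virtual sequence is attended independently against its own KV slice
        return 0;
    }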
@@ -1080,7 +1085,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 #endif
         }

-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1125,7 +1130,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens*n_seqs);

        if (!cparams.offload_kqv) {
            // all nodes between the KV store and the attention output are run on the CPU
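In both branches the attention output is flattened back to 2D with n_tokens*n_seqs columns rather than n_tokens. That is consistent with n_tokens inside build_attn_mha being read from q->ne[1] after the permute, i.e. the per-sequence token count (an assumption worth checking against the surrounding code); the multiplication simply restores the full ubatch size. A trivial sanity check of that bookkeeping:

    // Sketch only: hypothetical counts; checks that per-sequence tokens times sequences
    // gives back the full ubatch token count seen by the rest of the graph.
    #include <cassert>

    int main() {
        const int n_tokens_total = 32;                      // assumed tokens in the ubatch
        const int n_seqs         = 4;                       // assumed unique sequences
        const int n_tokens       = n_tokens_total / n_seqs; // per-sequence count after the reshape

        assert(n_tokens * n_seqs == n_tokens_total);        // flattening restores the full ubatch
        return 0;
    }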
@@ -1202,12 +1207,13 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified()
     {
         GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");

-        const auto n_kv = mctx_cur->get_n_kv();
+        const auto n_kv   = mctx_cur->get_n_kv();
+        const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;

         inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1449,13 +1455,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif

     auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);

+    const auto n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1;
+
     {
         const auto n_kv = mctx_cur->get_base()->get_n_kv();

         inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask);

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
@@ -1469,7 +1477,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
         inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
         inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);

-        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
+        inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_seqs, GGML_KQ_MASK_PAD), 1, n_seqs);
         ggml_set_input(inp->self_kq_mask_swa);

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
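All of these hunks gate the new behavior on the same expression, n_seqs = cparams.n_seq_virt > 1 ? ubatch.n_seqs_unq : 1, so when virtual sequences are disabled the shapes collapse back to the previous single-mask layout. A small sketch with hypothetical values, just to show the collapse:

    // Sketch only: hypothetical values; demonstrates that n_seq_virt == 1 reproduces the old layout.
    #include <cstdio>

    int main() {
        const int n_seq_virt = 1;  // assumed: virtual sequences disabled
        const int n_seqs_unq = 4;  // assumed unique sequences in the ubatch
        const int n_tokens   = 32; // assumed tokens in the ubatch

        const int n_seqs = n_seq_virt > 1 ? n_seqs_unq : 1;

        // with n_seqs == 1 the mask shape [n_kv, PAD(n_tokens/n_seqs), 1, n_seqs]
        // is identical to the previous [n_kv, PAD(n_tokens), 1, 1]
        std::printf("n_seqs = %d, tokens per mask slice = %d\n", n_seqs, n_tokens/n_seqs);
        return 0;
    }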