@@ -314,20 +314,24 @@ llama_memory_state_ptr llama_kv_cache_unified::init_batch(
         bool logits_all) {
     GGML_UNUSED(embd_pooled);
 
-    auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
+    do {
+        auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all);
 
-    std::vector<llama_ubatch> ubatches;
-    while (sbatch.n_tokens > 0) {
-        ubatches.push_back(sbatch.split_simple(n_ubatch));
-    }
+        std::vector<llama_ubatch> ubatches;
+        while (sbatch.n_tokens > 0) {
+            ubatches.push_back(sbatch.split_simple(n_ubatch));
+        }
 
-    auto heads = prepare(ubatches);
-    if (heads.empty()) {
-        return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
-    }
+        auto heads = prepare(ubatches);
+        if (heads.empty()) {
+            break;
+        }
 
-    return std::make_unique<llama_kv_cache_unified_state>(
-        this, std::move(sbatch), std::move(heads), std::move(ubatches));
+        return std::make_unique<llama_kv_cache_unified_state>(
+            this, std::move(sbatch), std::move(heads), std::move(ubatches));
+    } while (false);
+
+    return std::make_unique<llama_kv_cache_unified_state>(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
 }
 
 llama_memory_state_ptr llama_kv_cache_unified::init_full() {
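
Aside: the rewritten init_batch() above funnels every failure path to a single shared FAILED_PREPARE return via the do { ... } while (false) idiom, where a break acts as a local jump to the error path. A minimal self-contained sketch of the pattern (illustrative names only, not llama.cpp API):

    #include <optional>

    // Hedged sketch of the do/while(false) early-exit idiom:
    // every failed step "break"s out to the shared failure return at the end.
    std::optional<int> try_pipeline(int input) {
        do {
            if (input < 0) {
                break; // step 1 failed -> fall through to the failure path
            }
            const int prepared = input * 2;
            if (prepared == 0) {
                break; // step 2 failed
            }
            return prepared; // success: return directly from inside the block
        } while (false);

        // single, shared failure path (analogous to LLAMA_MEMORY_STATUS_FAILED_PREPARE)
        return std::nullopt;
    }
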
@@ -521,7 +525,6 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
     }
 
     if (debug > 0) {
-        LLAMA_LOG_CONT("\n");
         LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa);
 
         if ((debug == 2 && n_swa > 0) || debug > 2) {
@@ -530,7 +533,13 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
                 if (cells.is_empty(i)) {
                     ss += '.';
                 } else {
-                    ss += std::to_string(cells.seq_get(i));
+                    assert(cells.seq_count(i) >= 1);
+
+                    if (cells.seq_count(i) == 1) {
+                        ss += std::to_string(cells.seq_get(i));
+                    } else {
+                        ss += 'M';
+                    }
                 }
                 if (i%256 == 255) {
                     ss += " *";
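
The extended debug dump now distinguishes cells referenced by more than one sequence: '.' for an empty cell, the sequence id for a singly-owned cell, 'M' for a shared one, with a '*' marker every 256 cells. A toy standalone rendering of the same scheme (the cell data here is made up):

    #include <cstdio>
    #include <string>
    #include <vector>

    // Toy version of the cell map printed by find_slot() when debug > 0
    // (hypothetical stand-in for llama_kv_cells; the values are invented).
    int main() {
        // per-cell sequence counts: 0 = empty, 1 = single owner, >1 = shared
        const std::vector<int> seq_count = { 0,  1, 1, -0 + 2, 0,  3, 1};
        const std::vector<int> seq_id    = {-1,  0, 1,     -1, -1, -1, 2};

        std::string ss;
        for (size_t i = 0; i < seq_count.size(); ++i) {
            if (seq_count[i] == 0) {
                ss += '.';                       // empty cell
            } else if (seq_count[i] == 1) {
                ss += std::to_string(seq_id[i]); // the one owning sequence
            } else {
                ss += 'M';                       // shared by multiple sequences
            }
            if (i % 256 == 255) {
                ss += " *";                      // marker every 256 cells
            }
        }
        printf("%s\n", ss.c_str()); // prints: .01M.M2
    }
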
@@ -636,29 +645,39 @@ int32_t llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) const {
 }
 
 void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch & ubatch) {
+    if (debug > 0) {
+        LLAMA_LOG_DEBUG("%s: ubatch info:\n", __func__);
+        LLAMA_LOG_DEBUG("%s:   n_tokens = %d, equal_seqs = %d\n", __func__, ubatch.n_tokens, ubatch.equal_seqs);
+        LLAMA_LOG_DEBUG("%s:   n_seq_tokens = %d, n_seqs = %d\n", __func__, ubatch.n_seq_tokens, ubatch.n_seqs);
+    }
+
     // keep track of the max sequence position that we would overwrite with this ubatch
     // for non-SWA cache, this would be always empty
     llama_seq_id seq_pos_max_rm[LLAMA_MAX_PARALLEL_SEQUENCES];
     for (int s = 0; s < LLAMA_MAX_PARALLEL_SEQUENCES; ++s) {
         seq_pos_max_rm[s] = -1;
     }
 
-    for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
-        if (!cells.is_empty(head_cur + i)) {
-            assert(cells.seq_count(head_cur + i) == 1);
+    for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
+        for (uint32_t j = 0; j < ubatch.n_seq_tokens; ++j) {
+            const uint32_t idx = s*ubatch.n_seq_tokens + j;
 
-            const llama_seq_id seq_id = cells.seq_get(head_cur + i);
-            const llama_pos    pos    = cells.pos_get(head_cur + i);
+            if (!cells.is_empty(head_cur + idx)) {
+                assert(cells.seq_count(head_cur + idx) == 1);
 
-            seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+                const llama_seq_id seq_id = cells.seq_get(head_cur + idx);
+                const llama_pos    pos    = cells.pos_get(head_cur + idx);
 
-            cells.rm(head_cur + i);
-        }
+                seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos);
+
+                cells.rm(head_cur + idx);
+            }
 
-        cells.pos_set(head_cur + i, ubatch.pos[i]);
+            cells.pos_set(head_cur + idx, ubatch.pos[idx]);
 
-        for (int32_t j = 0; j < ubatch.n_seq_id[i]; j++) {
-            cells.seq_add(head_cur + i, ubatch.seq_id[i][j]);
+            for (int32_t i = 0; i < ubatch.n_seq_id[s]; i++) {
+                cells.seq_add(head_cur + idx, ubatch.seq_id[s][i]);
+            }
         }
     }
 
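The key change in apply_ubatch() is iterating per sequence and per sequence-token instead of per flat token: per-token data (ubatch.pos, the target cell) is addressed by the flat index idx = s*n_seq_tokens + j, while per-sequence metadata (n_seq_id, seq_id) is indexed by s alone; the old code indexed seq_id by the flat token index. A small sketch of the mapping under the equal-length-sequences layout (sizes are illustrative):

    #include <cassert>
    #include <cstdint>
    #include <cstdio>

    // Sketch of the (s, j) -> flat index mapping used by apply_ubatch(),
    // assuming the equal-length layout: token j of sequence s sits at
    // idx = s*n_seq_tokens + j, and is written into cell head_cur + idx.
    int main() {
        const uint32_t n_seqs       = 3;  // illustrative sizes
        const uint32_t n_seq_tokens = 4;
        const uint32_t head_cur     = 10; // first cell of the slot

        for (uint32_t s = 0; s < n_seqs; ++s) {
            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
                const uint32_t idx = s*n_seq_tokens + j;
                assert(idx < n_seqs*n_seq_tokens);
                // pos[idx] goes into cell head_cur + idx, tagged with the
                // sequence ids stored for sequence s
                printf("seq %u, token %u -> flat idx %u, cell %u\n", s, j, idx, head_cur + idx);
            }
        }
    }
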
@@ -677,7 +696,6 @@ void llama_kv_cache_unified::apply_ubatch(uint32_t head_cur, const llama_ubatch
             seq_rm(s, cells.seq_pos_min(s), seq_pos_max_rm[s] + 1);
         }
     }
-
     // move the head at the end of the slot
     head = head_cur + ubatch.n_tokens;
 }
@@ -774,14 +792,14 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 }
 
 void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const {
-    const int64_t n_tokens     = ubatch->n_tokens;
-    const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-    const int64_t n_seqs       = ubatch->n_seqs;
+    const uint32_t n_tokens     = ubatch->n_tokens;
+    const uint32_t n_seq_tokens = ubatch->n_seq_tokens;
+    const uint32_t n_seqs       = ubatch->n_seqs;
 
     GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
     float * data = (float *) dst->data;
 
-    const auto n_kv = dst->ne[0];
+    const int64_t n_kv = dst->ne[0];
 
     // Use only the previous KV cells of the correct sequence for each token of the ubatch.
     // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
@@ -795,12 +813,14 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
     //   xxxxx-----
     //   xxxxx-----
     // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
-    for (int h = 0; h < 1; ++h) {
-        for (int s = 0; s < n_seqs; ++s) {
+    for (uint32_t h = 0; h < 1; ++h) {
+        for (uint32_t s = 0; s < n_seqs; ++s) {
             const llama_seq_id seq_id = ubatch->seq_id[s][0];
 
-            for (int j = 0; j < n_seq_tokens; ++j) {
-                const llama_pos p1 = ubatch->pos[s*n_seq_tokens + j];
+            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
+                const uint32_t idx = s*n_seq_tokens + j;
+
+                const llama_pos p1 = ubatch->pos[idx];
 
                 for (uint32_t i = 0; i < n_kv; ++i) {
                     float f = 0.0f;
@@ -830,16 +850,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub
                         f = -INFINITY;
                     }
 
-                    data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                    data[h*(n_kv*n_tokens) + idx*n_kv + i] = f;
                 }
             }
         }
 
         // mask padded tokens
        if (data) {
-            for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                for (uint32_t j = 0; j < n_kv; ++j) {
-                    data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+            for (uint32_t j = n_tokens; j < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++j) {
+                for (uint32_t i = 0; i < n_kv; ++i) {
+                    data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY;
                 }
             }
         }
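
With the flat index, the mask write collapses s*(n_kv*n_seq_tokens) + j*n_kv to idx*n_kv: one row per ubatch token, one column per KV cell, and the padding loop masks whole rows past n_tokens. A self-contained toy version of the same row-major layout (sizes and the padded size are made up; the real code pads to GGML_KQ_MASK_PAD):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Toy causal KQ mask with the row-major layout used above:
    // row idx = ubatch token, column i = KV cell, padded rows fully masked.
    int main() {
        const uint32_t n_tokens = 3, n_kv = 4, n_tokens_padded = 4;
        std::vector<float> data(n_tokens_padded*n_kv, 0.0f);

        for (uint32_t idx = 0; idx < n_tokens; ++idx) {
            const int64_t p1 = idx + 1; // position of token idx (toy values)
            for (uint32_t i = 0; i < n_kv; ++i) {
                const int64_t p0 = i;   // position stored in KV cell i
                // causal: a token may only attend to cells at positions <= its own
                data[idx*n_kv + i] = p0 <= p1 ? 0.0f : -INFINITY;
            }
        }
        // mask padded rows (idx in [n_tokens, n_tokens_padded))
        for (uint32_t idx = n_tokens; idx < n_tokens_padded; ++idx) {
            for (uint32_t i = 0; i < n_kv; ++i) {
                data[idx*n_kv + i] = -INFINITY;
            }
        }

        // print 'x' for attended cells and '-' for masked ones
        for (uint32_t idx = 0; idx < n_tokens_padded; ++idx) {
            for (uint32_t i = 0; i < n_kv; ++i) {
                printf("%c", std::isinf(data[idx*n_kv + i]) ? '-' : 'x');
            }
            printf("\n");
        }
    }
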
@@ -1490,9 +1510,11 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
         seq_rm(dest_seq_id, -1, -1);
 
         llama_sbatch sbatch;
-        llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        llama_ubatch ubatch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
 
-        batch.n_tokens = cell_count;
+        ubatch.n_tokens     = cell_count;
+        ubatch.n_seq_tokens = cell_count;
+        ubatch.n_seqs       = 1;
 
         for (uint32_t i = 0; i < cell_count; ++i) {
             llama_pos pos;
@@ -1512,27 +1534,27 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
                 io.read_to(&seq_id, sizeof(seq_id));
             }
 
-            batch.pos[i]      = pos;
-            batch.n_seq_id[i] = n_seq_id;
-            batch.seq_id[i]   = &dest_seq_id;
+            ubatch.pos[i]      = pos;
+            ubatch.n_seq_id[i] = n_seq_id;
+            ubatch.seq_id[i]   = &dest_seq_id;
         }
 
-        const auto head_cur = find_slot(batch);
+        const auto head_cur = find_slot(ubatch);
         if (head_cur < 0) {
             LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
             return false;
         }
 
-        apply_ubatch(head_cur, batch);
+        apply_ubatch(head_cur, ubatch);
 
         // keep the head at the old position because we will read the KV data into it in state_read_data()
         head = head_cur;
 
         // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values)
         // Assume that this is one contiguous block of cells
         GGML_ASSERT(head_cur + cell_count <= cells.size());
-        GGML_ASSERT(cells.pos_get(head_cur)                  == batch.pos[0]);
-        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == batch.pos[cell_count - 1]);
+        GGML_ASSERT(cells.pos_get(head_cur)                  == ubatch.pos[0]);
+        GGML_ASSERT(cells.pos_get(head_cur + cell_count - 1) == ubatch.pos[cell_count - 1]);
         GGML_ASSERT(cells.seq_has(head_cur, dest_seq_id));
         GGML_ASSERT(cells.seq_has(head_cur + cell_count - 1, dest_seq_id));
     } else {
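
The restored ubatch in state_read_meta() is a single sequence spanning all cell_count tokens, so setting n_seq_tokens = cell_count and n_seqs = 1 keeps apply_ubatch()'s new (s, j) loops consistent: the flat index degenerates to j. A quick sketch of that invariant (values illustrative):

    #include <cassert>
    #include <cstdint>

    // Sketch of the invariant behind the restored ubatch: one sequence
    // holding all cell_count tokens, so n_tokens == n_seqs * n_seq_tokens
    // and apply_ubatch()'s (s, j) loops degenerate to a single flat pass.
    int main() {
        const uint32_t cell_count = 8; // illustrative

        const uint32_t n_tokens     = cell_count;
        const uint32_t n_seq_tokens = cell_count;
        const uint32_t n_seqs       = 1;

        assert(n_tokens == n_seqs*n_seq_tokens);

        for (uint32_t s = 0; s < n_seqs; ++s) {
            for (uint32_t j = 0; j < n_seq_tokens; ++j) {
                const uint32_t idx = s*n_seq_tokens + j;
                assert(idx == j); // with n_seqs == 1 the flat index is just j
            }
        }
    }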