@@ -8436,74 +8436,33 @@ static enum ggml_status llama_graph_compute(
     return status;
 }
 
-// decode a batch of tokens by evaluating the transformer
-// in case of unsuccessful decoding (error or warning),
-//   the kv_cache state will be returned to its original state
-//   (for non-recurrent models) or cleaned (for recurrent models)
-//
-//   - lctx:      llama context
-//   - batch:     batch to evaluate
-//
-// return 0 on success
-// return positive int on warning
-// return negative int on error
-//
-static int llama_decode_impl(
-        llama_context & lctx,
-          llama_batch   inp_batch) {
-
-    lctx.is_encoding = false;
-
-    if (inp_batch.n_tokens == 0) {
-        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
-        return -1;
-    }
-
-    // temporary allocate memory for the input batch if needed
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
-
-    const llama_batch & batch = batch_allocr.batch;
-    const uint32_t n_tokens_all = batch.n_tokens;
-
+static int llama_prepare_sbatch(
+        llama_context & lctx,
+    const llama_batch & batch,
+             uint32_t & n_outputs) {
     const auto & model   = lctx.model;
-    const auto & vocab   = model.vocab;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
+    const uint32_t n_tokens_all = batch.n_tokens;
+    const int64_t  n_embd       = hparams.n_embd;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
 
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
     if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= model.vocab.n_tokens()) {
+            if (batch.token[i] < 0 || uint32_t(batch.token[i]) >= model.vocab.n_tokens()) {
                 LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
     }
-
     GGML_ASSERT(n_tokens_all <= cparams.n_batch);
-
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
-    }
     lctx.n_queued_tokens += n_tokens_all;
-
-    auto & kv_self = lctx.kv_self;
-    llama_kv_slot_restorer kv_slot_restorer(kv_self);
-
-    const int64_t n_embd  = hparams.n_embd;
-    const int64_t n_vocab = vocab.n_tokens();
-
-    uint32_t n_outputs      = 0;
-    uint32_t n_outputs_prev = 0;
-
-    const auto n_ubatch = cparams.n_ubatch;
-
-    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
-    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
-
     lctx.embd_seq.clear();
 
     // count outputs
@@ -8519,7 +8478,7 @@ static int llama_decode_impl(
     }
 
     lctx.sbatch.from_batch(batch, n_embd,
-        /* simple_split */ !kv_self.recurrent,
+        /* simple_split */ !lctx.kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
     // reserve output buffer
@@ -8528,70 +8487,148 @@ static int llama_decode_impl(
         return -2;
     };
 
-    while (lctx.sbatch.n_tokens > 0) {
-        llama_ubatch ubatch;
-        if (kv_self.recurrent) {
-            if (embd_pooled) {
-                // Pooled embeddings cannot be split across ubatches (yet)
-                ubatch = lctx.sbatch.split_seq(n_ubatch);
-            } else {
-                // recurrent model architectures are easier to implement
-                // with equal-length sequences
-                ubatch = lctx.sbatch.split_equal(n_ubatch);
-            }
+    return 0;
+}
+
+static int llama_prepare_ubatch(
+                 llama_context & lctx,
+        llama_kv_slot_restorer & kv_slot_restorer,
+                  llama_ubatch & ubatch,
+                const uint32_t   n_outputs,
+                const uint32_t   n_tokens_all) {
+    GGML_ASSERT(lctx.sbatch.n_tokens > 0);
+
+    auto & kv_self = lctx.kv_self;
+    const auto & cparams = lctx.cparams;
+    const auto & hparams = lctx.model.hparams;
+
+    // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+    const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+    if (lctx.kv_self.recurrent) {
+        if (embd_pooled) {
+            // Pooled embeddings cannot be split across ubatches (yet)
+            ubatch = lctx.sbatch.split_seq(cparams.n_ubatch);
         } else {
-            ubatch = lctx.sbatch.split_simple(n_ubatch);
+            // recurrent model architectures are easier to implement
+            // with equal-length sequences
+            ubatch = lctx.sbatch.split_equal(cparams.n_ubatch);
         }
-        const uint32_t n_tokens = ubatch.n_tokens;
+    } else {
+        ubatch = lctx.sbatch.split_simple(cparams.n_ubatch);
+    }
 
-        // count the outputs in this u_batch
-        {
-            int32_t n_outputs_new = 0;
+    // count the outputs in this u_batch
+    {
+        int32_t n_outputs_new = 0;
 
-            if (n_outputs == n_tokens_all) {
-                n_outputs_new = n_tokens;
-            } else {
-                GGML_ASSERT(ubatch.output);
-                for (uint32_t i = 0; i < n_tokens; i++) {
-                    n_outputs_new += (int32_t) (ubatch.output[i] != 0);
-                }
+        if (n_outputs == n_tokens_all) {
+            n_outputs_new = ubatch.n_tokens;
+        } else {
+            GGML_ASSERT(ubatch.output);
+            for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
+                n_outputs_new += int32_t(ubatch.output[i] != 0);
             }
+        }
+
+        // needs to happen before the graph is built
+        lctx.n_outputs = n_outputs_new;
+    }
+
+    // non-causal masks do not use the KV cache
+    if (hparams.causal_attn) {
+        llama_kv_cache_update(&lctx);
 
-            // needs to happen before the graph is built
-            lctx.n_outputs = n_outputs_new;
+        // if we have enough unused cells before the current head ->
+        //   better to start searching from the beginning of the cache, hoping to fill it
+        if (kv_self.head > kv_self.used + 2*ubatch.n_tokens) {
+            kv_self.head = 0;
         }
 
-        int n_threads = n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
-        ggml_threadpool_t threadpool = n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+        const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
+        if (!slot) {
+            return 1;
+        }
+        kv_slot_restorer.save(slot);
 
-        GGML_ASSERT(n_threads > 0);
+        if (!kv_self.recurrent) {
+            // a heuristic, to avoid attending the full cache if it is not yet utilized
+            // after enough generations, the benefit from this heuristic disappears
+            // if we start defragmenting the cache, the benefit from this will be more important
+            const uint32_t pad = llama_kv_cache_get_padding(cparams);
+            kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
+            //kv_self.n = llama_kv_cache_cell_max(kv_self);
+        }
+    }
 
-        // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
-            llama_kv_cache_update(&lctx);
+    return 0;
+}
 
-            // if we have enough unused cells before the current head ->
-            //   better to start searching from the beginning of the cache, hoping to fill it
-            if (kv_self.head > kv_self.used + 2*n_tokens) {
-                kv_self.head = 0;
-            }
+// decode a batch of tokens by evaluating the transformer
+// in case of unsuccessful decoding (error or warning),
+//   the kv_cache state will be returned to its original state
+//   (for non-recurrent models) or cleaned (for recurrent models)
+//
+//   - lctx:      llama context
+//   - inp_batch: batch to evaluate
+//
+// return 0 on success
+// return positive int on warning
+// return negative int on error
+//
+static int llama_decode_impl(
+        llama_context & lctx,
+          llama_batch   inp_batch) {
 
-            const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
-            if (!slot) {
-                return 1;
-            }
-            kv_slot_restorer.save(slot);
+    lctx.is_encoding = false;
 
-            if (!kv_self.recurrent) {
-                // a heuristic, to avoid attending the full cache if it is not yet utilized
-                // after enough generations, the benefit from this heuristic disappears
-                // if we start defragmenting the cache, the benefit from this will be more important
-                const uint32_t pad = llama_kv_cache_get_padding(cparams);
-                kv_self.n = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self), pad)));
-                //kv_self.n = llama_kv_cache_cell_max(kv_self);
+    if (inp_batch.n_tokens == 0) {
+        LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
+        return -1;
+    }
+
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : lctx.kv_self.max_pos() + 1);
+    const llama_batch & batch = batch_allocr.batch;
+
+    const auto & model   = lctx.model;
+    const auto & vocab   = model.vocab;
+    const auto & hparams = model.hparams;
+    const auto & cparams = lctx.cparams;
+
+    if (lctx.t_compute_start_us == 0) {
+        lctx.t_compute_start_us = ggml_time_us();
+    }
+    auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer(kv_self);
+
+    const int64_t n_embd  = hparams.n_embd;
+    const int64_t n_vocab = vocab.n_tokens();
+
+    uint32_t n_outputs      = 0;
+    uint32_t n_outputs_prev = 0;
+
+    {
+        const int ret = llama_prepare_sbatch(lctx, batch, n_outputs);
+        if (ret != 0) {
+            return ret;
+        }
+    }
+
+    while (lctx.sbatch.n_tokens > 0) {
+        llama_ubatch ubatch;
+        {
+            const int ret = llama_prepare_ubatch(lctx, kv_slot_restorer, ubatch, n_outputs, batch.n_tokens);
+            if (ret != 0) {
+                return ret;
             }
         }
 
+        const int n_threads = ubatch.n_tokens == 1 ? cparams.n_threads : cparams.n_threads_batch;
+        ggml_threadpool_t threadpool = ubatch.n_tokens == 1 ? lctx.threadpool : lctx.threadpool_batch;
+
+        GGML_ASSERT(n_threads > 0);
+
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
         ggml_backend_sched_reset(lctx.sched.get());
@@ -8644,7 +8681,7 @@ static int llama_decode_impl(
 
         // update the kv ring buffer
         {
-            kv_self.head += n_tokens;
+            kv_self.head += ubatch.n_tokens;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
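
For readers skimming the diff: the refactor splits the old monolithic `llama_decode_impl` into `llama_prepare_sbatch` (run once per batch) and `llama_prepare_ubatch` (run once per micro-batch), both reusing the existing return-code convention (0 = success, positive = warning, negative = error), and the decode loop simply forwards any non-zero code. The sketch below only illustrates that control-flow shape with made-up stand-in types and helpers (`Ctx`, `Batch`, `prepare_sbatch`, `prepare_ubatch`); it is not the llama.cpp API.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-ins for llama.cpp's real types; names are illustrative only.
struct Batch {
    std::vector<int32_t> tokens;
};

struct Ctx {
    std::size_t n_remaining = 0; // tokens left to schedule from the current batch
    uint32_t    n_ubatch    = 4; // micro-batch size
};

// One-time per-batch validation/setup (analogous in shape to llama_prepare_sbatch):
// returns 0 on success, a negative value on error.
static int prepare_sbatch(Ctx & ctx, const Batch & batch) {
    if (batch.tokens.empty()) {
        std::fprintf(stderr, "empty batch\n");
        return -1;
    }
    ctx.n_remaining = batch.tokens.size();
    return 0;
}

// Per-micro-batch setup (analogous in shape to llama_prepare_ubatch):
// returns 0 on success, a positive value on a recoverable warning
// (in the real code: no free KV-cache slot for this ubatch).
static int prepare_ubatch(Ctx & ctx, uint32_t & n_tokens) {
    n_tokens = static_cast<uint32_t>(std::min<std::size_t>(ctx.n_remaining, ctx.n_ubatch));
    if (n_tokens == 0) {
        return 1;
    }
    ctx.n_remaining -= n_tokens;
    return 0;
}

// Driver mirroring the refactored control flow: prepare the sbatch once,
// then prepare and process one ubatch per loop iteration, propagating
// non-zero return codes (warning or error) to the caller unchanged.
static int decode(Ctx & ctx, const Batch & batch) {
    if (const int ret = prepare_sbatch(ctx, batch); ret != 0) {
        return ret;
    }
    while (ctx.n_remaining > 0) {
        uint32_t n_tokens = 0;
        if (const int ret = prepare_ubatch(ctx, n_tokens); ret != 0) {
            return ret;
        }
        // ... build and compute the graph for this micro-batch here ...
        std::printf("processed ubatch of %u tokens\n", n_tokens);
    }
    return 0;
}

int main() {
    Ctx ctx;
    const Batch batch{{1, 2, 3, 4, 5, 6, 7}};
    return decode(ctx, batch) < 0 ? 1 : 0;
}
```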