@@ -101,8 +101,8 @@ llama_context::llama_context(
 
     cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
 
-    cparams.op_offload   = params.op_offload;
-    cparams.attn_streams = params.attn_streams;
+    cparams.op_offload = params.op_offload;
+    cparams.kv_unified = params.kv_unified;
 
     const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
 
@@ -113,7 +113,7 @@ llama_context::llama_context(
     LLAMA_LOG_INFO("%s: n_ubatch      = %u\n",   __func__, cparams.n_ubatch);
     LLAMA_LOG_INFO("%s: causal_attn   = %d\n",   __func__, cparams.causal_attn);
     LLAMA_LOG_INFO("%s: flash_attn    = %d\n",   __func__, cparams.flash_attn);
-    LLAMA_LOG_INFO("%s: attn_streams  = %s\n",   __func__, cparams.attn_streams ? "true" : "false");
+    LLAMA_LOG_INFO("%s: kv_unified    = %s\n",   __func__, cparams.kv_unified ? "true" : "false");
     LLAMA_LOG_INFO("%s: freq_base     = %.1f\n", __func__, cparams.rope_freq_base);
     LLAMA_LOG_INFO("%s: freq_scale    = %g\n",   __func__, cparams.rope_freq_scale);
 
@@ -269,7 +269,7 @@ llama_context::llama_context(
 
     // reserve worst-case graph
     if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs   = cparams.attn_streams ? cparams.n_seq_max : 1;
+        const uint32_t n_seqs   = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -481,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
            throw std::runtime_error("failed to initialize memory context");
        }
 
-        const uint32_t n_seqs   = cparams.attn_streams ? cparams.n_seq_max : 1;
+        const uint32_t n_seqs   = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
         auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -740,7 +740,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     const int64_t n_embd = hparams.n_embd;
 
     // note: during encode, we always pass the full sequence starting from pos = 0
-    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
+    if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -907,7 +907,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, output_all)) {
+    if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -2036,7 +2036,7 @@ void llama_context::opt_epoch_iter(
             batch.logits [pos_batch] = true;
         }
 
-        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.attn_streams ? cparams.n_seq_max : LLAMA_MAX_SEQ, true)) {
+        if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
             LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
             return;
         }
@@ -2195,7 +2195,7 @@ llama_context_params llama_context_default_params() {
         /*.no_perf      =*/ true,
         /*.op_offload   =*/ true,
         /*.swa_full     =*/ true,
-        /*.attn_streams =*/ false,
+        /*.kv_unified   =*/ true,
     };
 
     return result;
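
For orientation, a minimal caller-side sketch (not part of this commit) of what the rename means for users of the public API: the old attn_streams flag is replaced by kv_unified with the opposite sense, so per-sequence KV streams are now the kv_unified = false case, and the default flips to a unified cache. The entry points used below (llama_context_default_params, llama_init_from_model) are assumed from llama.h; "model" is a previously loaded llama_model *.

    #include "llama.h"

    // sketch: build a context with per-sequence KV streams after the rename
    static llama_context * make_ctx(llama_model * model) {
        llama_context_params params = llama_context_default_params();

        params.n_seq_max  = 4;     // four independent sequences
        params.kv_unified = false; // split KV cache, i.e. what the old
                                   // attn_streams = true used to request

        // with kv_unified = true (the new default) the worst-case graph is
        // reserved with n_seqs = 1; with false it is reserved with n_seq_max
        return llama_init_from_model(model, params);
    }

This matches the flipped ternaries in the diff: everywhere the code previously read cparams.attn_streams ? cparams.n_seq_max : X, it now reads cparams.kv_unified ? X : cparams.n_seq_max.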