@@ -33,9 +33,6 @@ llama_context::llama_context(
        throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ));
    }

-    const char * LLAMA_HT = getenv("LLAMA_HT");
-    cparams.kv_unified = (LLAMA_HT && atoi(LLAMA_HT) > 0) ? false : true;
-
    cparams.n_threads       = params.n_threads;
    cparams.n_threads_batch = params.n_threads_batch;
    cparams.yarn_ext_factor = params.yarn_ext_factor;
@@ -104,7 +101,8 @@ llama_context::llama_context(

    cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);

-    cparams.op_offload = params.op_offload;
+    cparams.op_offload   = params.op_offload;
+    cparams.attn_streams = params.attn_streams;

    const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;

@@ -115,6 +113,7 @@ llama_context::llama_context(
    LLAMA_LOG_INFO("%s: n_ubatch     = %u\n",   __func__, cparams.n_ubatch);
    LLAMA_LOG_INFO("%s: causal_attn  = %d\n",   __func__, cparams.causal_attn);
    LLAMA_LOG_INFO("%s: flash_attn   = %d\n",   __func__, cparams.flash_attn);
+    LLAMA_LOG_INFO("%s: attn_streams = %s\n",   __func__, cparams.attn_streams ? "true" : "false");
    LLAMA_LOG_INFO("%s: freq_base    = %.1f\n", __func__, cparams.rope_freq_base);
    LLAMA_LOG_INFO("%s: freq_scale   = %g\n",   __func__, cparams.rope_freq_scale);

@@ -270,7 +269,7 @@ llama_context::llama_context(

    // reserve worst-case graph
    if (!hparams.vocab_only && memory) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
@@ -314,6 +313,10 @@ llama_context::llama_context(

        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
        {
+            // TODO: not sure if the following graph would be the worst case for multi-stream KV caches:
+            //
+            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+            //
            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
            if (!gf) {
                throw std::runtime_error("failed to allocate compute pp buffers");
@@ -478,7 +481,7 @@ bool llama_context::kv_self_update(bool optimize) {
            throw std::runtime_error("failed to initialize memory context");
        }

-        const uint32_t n_seqs = cparams.n_seq_max;
+        const uint32_t n_seqs = cparams.attn_streams ? cparams.n_seq_max : 1;
        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -2192,6 +2195,7 @@ llama_context_params llama_context_default_params() {
        /*.no_perf      =*/ true,
        /*.op_offload   =*/ true,
        /*.swa_full     =*/ true,
+        /*.attn_streams =*/ false,
    };

    return result;
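For reference, a minimal usage sketch of the new flag from the API-consumer side, assuming the public llama.h entry points (llama_backend_init, llama_model_default_params, llama_model_load_from_file, llama_context_default_params, llama_init_from_model). Only the attn_streams and n_seq_max fields come from this diff; the model path and every other value is illustrative:

// Minimal sketch (not part of this change): enable the new attn_streams flag
// when creating a context. Only attn_streams/n_seq_max come from the diff above;
// the surrounding calls are the usual llama.h flow with default parameters.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams); // placeholder path
    if (!model) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_seq_max    = 4;    // multiple sequences are the intended use case
    cparams.attn_streams = true; // defaults to false, per llama_context_default_params()

    llama_context * ctx = llama_init_from_model(model, cparams);
    if (!ctx) {
        llama_model_free(model);
        return 1;
    }

    // ... decode batches across several sequence ids as usual ...

    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}

With attn_streams left at its default of false, the reservation logic above falls back to n_seqs = 1, i.e. a single unified KV stream.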