
Commit 16bcc12

kv-cache : pad the cache size to 256 for performance (#17046)
* kv-cache : pad the size of the small SWA cache for performance
* context : pad the total context to 256
* cont : future-proof the swa pad
* server : adjust test params to new logic
1 parent 9eb9a13 commit 16bcc12

File tree: 4 files changed, +14 −7 lines changed


include/llama.h

Lines changed: 1 addition & 0 deletions
@@ -463,6 +463,7 @@ extern "C" {
 
     // NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
     // In some cases the requested values via llama_context_params may differ from the actual values used by the context
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
     LLAMA_API uint32_t llama_n_ctx     (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch   (const struct llama_context * ctx);
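
For illustration, a minimal sketch of the query pattern this note recommends, assuming a llama_model has already been loaded; the context-creation call reflects the current C API (exact names can differ between versions), and the requested size of 1000 is only a placeholder:

// sketch: request a context size that is not a multiple of 256 and read
// back the values the context actually uses after padding
llama_context_params cparams = llama_context_default_params();
cparams.n_ctx = 1000; // requested size; padded up to 1024 by this change

// `model` is assumed to be a valid llama_model * loaded elsewhere
llama_context * ctx = llama_init_from_model(model, cparams);

const uint32_t n_ctx     = llama_n_ctx(ctx);     // actual total context size
const uint32_t n_ctx_seq = llama_n_ctx_seq(ctx); // actual per-sequence context size
const uint32_t n_batch   = llama_n_batch(ctx);   // actual logical batch size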

src/llama-context.cpp

Lines changed: 4 additions & 0 deletions
@@ -114,10 +114,14 @@ llama_context::llama_context(
         }
     }
 
+    // ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
+    cparams.n_ctx = GGML_PAD(cparams.n_ctx, 256);
+
     if (cparams.kv_unified) {
         cparams.n_ctx_seq = cparams.n_ctx;
     } else {
         cparams.n_ctx_seq = cparams.n_ctx / cparams.n_seq_max;
+        cparams.n_ctx_seq = GGML_PAD(cparams.n_ctx_seq, 256);
 
         if (cparams.n_ctx_seq == 0) {
             throw std::runtime_error("n_ctx_seq == 0");
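
For reference, GGML_PAD rounds its first argument up to the next multiple of the second, so the effect of the two added lines can be sketched as a standalone program; pad_to below is a stand-in for the macro (not the actual ggml definition) and the numbers are only an example:

#include <cstdint>
#include <cstdio>

// stand-in for GGML_PAD: round x up to the next multiple of n
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    uint32_t n_ctx = 1000;          // requested total context (example value)
    const uint32_t n_seq_max = 4;   // parallel sequences (example value)

    n_ctx = pad_to(n_ctx, 256);     // 1000 -> 1024

    uint32_t n_ctx_seq = n_ctx / n_seq_max; // 1024 / 4 = 256
    n_ctx_seq = pad_to(n_ctx_seq, 256);     // already a multiple of 256

    printf("n_ctx = %u, n_ctx_seq = %u\n", (unsigned) n_ctx, (unsigned) n_ctx_seq); // 1024, 256
    return 0;
}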

src/llama-kv-cache-iswa.cpp

Lines changed: 3 additions & 1 deletion
@@ -45,7 +45,9 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
+    // note: the SWA cache is always padded to 256 for performance
+    //       https://github.com/ggml-org/llama.cpp/issues/17037
+    uint32_t size_swa = GGML_PAD(std::min(size_base, hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch), 256);
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
     if (swa_full) {
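
The behavioral difference is that the old code padded only the window term (to n_pad) before taking the minimum, while the new code pads the result of the minimum to 256. A small sketch with made-up numbers (not taken from any particular model) compares the two formulas; pad_to again stands in for GGML_PAD:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// stand-in for GGML_PAD: round x up to the next multiple of n
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    // illustrative values only
    const uint32_t size_base = 4096; // base (non-SWA) cache size
    const uint32_t n_swa     = 1100; // sliding window size
    const uint32_t n_seq_max = 1;    // sequences sharing the SWA cache
    const uint32_t n_ubatch  = 512;  // micro-batch size
    const uint32_t n_pad     = 32;   // old padding granularity

    const uint32_t before = std::min(size_base, pad_to(n_swa*n_seq_max + n_ubatch, n_pad));
    const uint32_t after  = pad_to(std::min(size_base, n_swa*n_seq_max + n_ubatch), 256);

    printf("before = %u, after = %u\n", (unsigned) before, (unsigned) after); // before = 1632, after = 1792
    return 0;
}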

tools/server/tests/unit/test_speculative.py

Lines changed: 6 additions & 6 deletions
@@ -77,10 +77,10 @@ def test_different_draft_min_draft_max():
 
 def test_slot_ctx_not_exceeded():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
         "speculative.p_min": 0.0,

@@ -91,19 +91,19 @@ def test_slot_ctx_not_exceeded():
 
 def test_with_ctx_shift():
     global server
-    server.n_ctx = 64
+    server.n_ctx = 256
     server.enable_ctx_shift = True
     server.start()
     res = server.make_request("POST", "/completion", data={
-        "prompt": "Hello " * 56,
+        "prompt": "Hello " * 248,
         "temperature": 0.0,
         "top_k": 1,
-        "n_predict": 64,
+        "n_predict": 256,
         "speculative.p_min": 0.0,
     })
     assert res.status_code == 200
     assert len(res.body["content"]) > 0
-    assert res.body["tokens_predicted"] == 64
+    assert res.body["tokens_predicted"] == 256
     assert res.body["truncated"] == True