
Commit c610b6c

kv-cache : fix SWA checks + disable cacheless iSWA (#15811)
ggml-ci
1 parent 5d6688d commit c610b6c

File tree

9 files changed, +29 -11 lines changed


src/llama-graph.cpp

Lines changed: 7 additions & 3 deletions
@@ -297,6 +297,9 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
 
     float * data = (float *) kq_mask->data;
 
+    // [TAG_NO_CACHE_ISWA]
+    GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "TODO: implement");
+
     for (int h = 0; h < 1; ++h) {
         for (int i1 = 0; i1 < n_tokens; ++i1) {
             const llama_seq_id s1 = ubatch->seq_id[i1][0];
@@ -315,9 +318,10 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {
                     continue; // skip future tokens for causal attention
                 }
 
-                if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
-                    continue; // skip masked tokens for SWA
-                }
+                // TODO: this does not take into account that some layers are SWA and others are not (i.e. iSWA) [TAG_NO_CACHE_ISWA]
+                //if (hparams.is_masked_swa(ubatch->pos[i0], ubatch->pos[i1])) {
+                //    continue; // skip masked tokens for SWA
+                //}
 
                 // TODO: reimplement this like in llama_kv_cache_unified
                 if (hparams.use_alibi) {

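Note on the disabled check above: the cacheless attention input builds a single kq_mask shared by all layers, so a model-wide is_masked_swa() decision cannot be correct for an iSWA model in which only some layers use a sliding window. As a rough worked example (assuming the standard rule that a pair is masked when p1 - p0 >= n_swa, which is not itself shown in this diff): with n_swa = 4, p0 = 0 and p1 = 10, a SWA layer must mask the pair (10 - 0 >= 4), while a full-attention layer of the same model must keep it visible. Hence the GGML_ASSERT above until per-layer handling is implemented.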
src/llama-hparams.cpp

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ uint32_t llama_hparams::n_layer_kv() const {
     return res;
 }
 
-bool llama_hparams::is_masked_swa(llama_pos p0, llama_pos p1) const {
+bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
     assert(p0 >= 0 && p1 >= 0);
 
     switch (swa_type) {

src/llama-hparams.h

Lines changed: 4 additions & 1 deletion
@@ -229,7 +229,10 @@ struct llama_hparams {
     // number of layers for which has_kv() returns true
     uint32_t n_layer_kv() const;
 
-    bool is_masked_swa(llama_pos p0, llama_pos p1) const;
+    // note that this function uses different SWA parameters from those in the hparams
+    // TODO: think of a better place for this function
+    // TODO: pack the SWA params in a struct?
+    static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1);
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

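For reference, the body of the switch in llama-hparams.cpp is not part of this diff. The following is a minimal sketch of what the per-type check looks like, assuming the standard rule (mask when p1 - p0 >= n_swa) and the chunked rule (mask everything before the chunk that contains p1); the helper name is hypothetical, the llama_pos/llama_swa_type types come from the llama.cpp headers, and the real enum may contain more variants:

// Sketch only - the actual implementation lives in llama-hparams.cpp and is not shown in this diff.
// Assumes n_swa > 0 for the windowed types.
#include <cassert>
#include <cstdint>

static bool is_masked_swa_sketch(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) {
    assert(p0 >= 0 && p1 >= 0);

    switch (swa_type) {
        case LLAMA_SWA_TYPE_NONE:
            break; // no masking for this cache
        case LLAMA_SWA_TYPE_STANDARD:
            if (p1 - p0 >= (llama_pos) n_swa) {
                return true; // key position is outside the query's sliding window
            }
            break;
        case LLAMA_SWA_TYPE_CHUNKED:
            if (p0 < (p1 / (llama_pos) n_swa) * (llama_pos) n_swa) {
                return true; // key position is before the start of the query's chunk
            }
            break;
        default:
            break; // other variants (if any) are not covered by this sketch
    }

    return false;
}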
src/llama-kv-cache-iswa.cpp

Lines changed: 2 additions & 2 deletions
@@ -60,14 +60,14 @@ llama_kv_cache_iswa::llama_kv_cache_iswa(
     kv_base = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_base, n_seq_max, n_pad,
-            0, filter_base, reuse);
+            0, LLAMA_SWA_TYPE_NONE, filter_base, reuse);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache>(
             model, type_k, type_v,
             v_trans, offload, unified, size_swa, n_seq_max, n_pad,
-            hparams.n_swa, filter_swa, reuse);
+            hparams.n_swa, hparams.swa_type, filter_swa, reuse);
 }
 
 void llama_kv_cache_iswa::clear(bool data) {

src/llama-kv-cache.cpp

Lines changed: 3 additions & 2 deletions
@@ -27,10 +27,11 @@ llama_kv_cache::llama_kv_cache(
                 uint32_t n_seq_max,
                 uint32_t n_pad,
                 uint32_t n_swa,
+          llama_swa_type swa_type,
     const layer_filter_cb & filter,
     const layer_reuse_cb & reuse) :
     model(model), hparams(model.hparams), v_trans(v_trans),
-    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa) {
+    n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
     GGML_ASSERT(kv_size % n_pad == 0);
 
@@ -1392,7 +1393,7 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 }
 
 bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const {
-    return hparams.is_masked_swa(p0, p1);
+    return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1);
 }
 
 void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {

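Combined with the iSWA constructor change above, each cache now answers is_masked_swa() from its own stored n_swa/swa_type rather than re-reading the model hparams. A hypothetical illustration (positions and window size are invented, and accessor visibility may differ in the actual class):

// Hypothetical illustration only - values invented for the example.
//
//   kv_base: constructed with n_swa = 0 and LLAMA_SWA_TYPE_NONE  -> never masks
//   kv_swa : constructed with hparams.n_swa and hparams.swa_type -> masks by its window
//
//   kv_base->is_masked_swa(/*p0=*/0, /*p1=*/100);  // false, the base cache keeps full context
//   kv_swa ->is_masked_swa(/*p0=*/0, /*p1=*/100);  // true for a standard window with n_swa <= 100
//
// Previously both caches consulted the model hparams and could not disagree, which appears to be
// what the "fix SWA checks" part of this commit addresses.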
src/llama-kv-cache.h

Lines changed: 4 additions & 0 deletions
@@ -89,6 +89,7 @@ class llama_kv_cache : public llama_memory_i {
                 uint32_t n_seq_max,
                 uint32_t n_pad,
                 uint32_t n_swa,
+          llama_swa_type swa_type,
     const layer_filter_cb & filter,
     const layer_reuse_cb & reuse);
 
@@ -211,6 +212,9 @@ class llama_kv_cache : public llama_memory_i {
     // env: LLAMA_KV_CACHE_DEBUG
     int debug = 0;
 
+    // this is the SWA type of the cache - not to be confused with the model SWA type
+    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
+
     std::vector<ggml_context_ptr> ctxs;
     std::vector<ggml_backend_buffer_ptr> bufs;
 

src/llama-memory-hybrid.cpp

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,7 @@ llama_memory_hybrid::llama_memory_hybrid(
                 uint32_t kv_size,
                 uint32_t n_pad,
                 uint32_t n_swa,
+          llama_swa_type swa_type,
         /* recurrent */
                ggml_type type_r,
                ggml_type type_s,
@@ -40,6 +41,7 @@ llama_memory_hybrid::llama_memory_hybrid(
         n_seq_max,
         n_pad,
         n_swa,
+        swa_type,
         filter_attn == nullptr ?
             [&](int32_t il) { return !hparams.is_recurrent(il); }
             : filter_attn,

src/llama-memory-hybrid.h

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ class llama_memory_hybrid : public llama_memory_i {
                 uint32_t kv_size,
                 uint32_t n_pad,
                 uint32_t n_swa,
+          llama_swa_type swa_type,
         /* recurrent */
                ggml_type type_r,
                ggml_type type_s,

src/llama-model.cpp

Lines changed: 5 additions & 2 deletions
@@ -11084,7 +11084,8 @@ struct llm_build_gemma_embedding_iswa : public llm_graph_context {
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
 
-        auto * inp_attn = build_attn_inp_no_cache();
+        // TODO: support cacheless iSWA embeddings [TAG_NO_CACHE_ISWA]
+        auto * inp_attn = build_attn_inp_kv_iswa();
 
         ggml_tensor * inp_out_ids = build_inp_out_ids();
 
@@ -18632,7 +18633,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
-        case LLM_ARCH_GEMMA_EMBEDDING:
+        //case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
             {
@@ -18681,6 +18682,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     /* attn_kv_size */ cparams.n_ctx,
                     /* attn_n_pad */ padding,
                     /* attn_n_swa */ hparams.n_swa,
+                    /* attn_swa_type */ hparams.swa_type,
                     /* recurrent_type_k */ GGML_TYPE_F32,
                     /* recurrent_type_v */ GGML_TYPE_F32,
                     /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
@@ -18750,6 +18752,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
                     cparams.n_seq_max,
                     padding,
                     hparams.n_swa,
+                    hparams.swa_type,
                     nullptr,
                     nullptr);
             }
