@@ -443,11 +443,38 @@ struct llm_graph_params {
443
443
// TODO: temporary
444
444
llm_graph_result_i * res;
445
445
446
- bool is_same (const llm_graph_params & other) const {
446
+ // return true if the "other" params would produce a graph with the same topology as the current params
447
+ // having the same topology allows us to reuse the graph in some cases
448
+ bool allow_reuse (const llm_graph_params & other) const {
449
+ // first check the ubatch
450
+ bool can_reuse_ubatch =
451
+ ubatch.equal_seqs == other.ubatch .equal_seqs &&
452
+ ubatch.n_tokens == other.ubatch .n_tokens &&
453
+ ubatch.n_seq_tokens == other.ubatch .n_seq_tokens &&
454
+ ubatch.n_seqs == other.ubatch .n_seqs &&
455
+ ubatch.n_seqs_unq == other.ubatch .n_seqs_unq &&
456
+ (
457
+ (!ubatch.token && !other.ubatch .token ) ||
458
+ (!ubatch.embd && !other.ubatch .embd )
459
+ );
460
+
461
+ // TODO: this won't work because seq_id_unq ptr can point to an old balloc that has
462
+ // been freed by this point. find a way to fix this
463
+ // for (uint32_t s = 0; s < n_seqs_unq; ++s) {
464
+ // can_reuse_ubatch &= seq_id_unq[s] == other.seq_id_unq[s];
465
+ // }
466
+
467
+ // for now, conservatively disallow reuse whenever the ubatch has equal_seqs set, until the issue above is resolved
468
+ // ref: https://github.com/ggml-org/llama.cpp/pull/14363
469
+ can_reuse_ubatch = can_reuse_ubatch && !ubatch.equal_seqs ;
470
+
471
+ if (!can_reuse_ubatch) {
472
+ return false ;
473
+ }
474
+
447
475
return
448
- hparams.is_same (other.hparams ) &&
449
- cparams.is_same (other.cparams ) &&
450
- ubatch .is_same (other.ubatch ) &&
476
+ cparams.embeddings == other.cparams .embeddings &&
477
+ cparams.causal_attn == other.cparams .causal_attn &&
451
478
arch == other.arch &&
452
479
gtype == other.gtype &&
453
480
cvec == other.cvec &&
@@ -510,7 +537,7 @@ class llm_graph_result : public llm_graph_result_i {
510
537
// contexts of the input tensors of the graph and we can reuse it for another computation
511
538
// return true if the graph was updated and can be reused
512
539
bool can_reuse (const llm_graph_params & params) override {
513
- if (!this ->params .is_same (params)) {
540
+ if (!this ->params .allow_reuse (params)) {
514
541
return false ;
515
542
}
516
543
0 commit comments