
Commit 18212b0

cont : fix
1 parent bbcda78 commit 18212b0
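
In short (a reading of the diff, not the author's own description): the previous-state capture of head and rs_z is moved out of can_reuse() and into the llm_graph_input_rs constructor, where both values are stored as const members. The reuse checks in llm_graph_input_rs::can_reuse() and llm_graph_input_mem_hybrid::can_reuse() then compare these recorded values against the incoming memory context; the hybrid input reuses the same members through inp_rs instead of keeping its own copies.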

File tree: 2 files changed (+15, -11 lines)

src/llama-graph.cpp

Lines changed: 10 additions & 10 deletions
@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
     }
 }
 
+llm_graph_input_rs::llm_graph_input_rs(const llama_memory_recurrent_context * mctx) :
+    mctx(mctx),
+    head(mctx->get_head()),
+    rs_z(mctx->get_rs_z()) {
+}
+
 void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
     GGML_UNUSED(ubatch);
 
@@ -254,9 +260,6 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
 bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     const auto * mctx = static_cast<const llama_memory_recurrent_context *>(params.mctx);
 
-    const auto prev_head = this->mctx->get_head();
-    const auto prev_rs_z = this->mctx->get_rs_z();
-
     this->mctx = mctx;
 
     bool res = true;
@@ -266,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= prev_head == mctx->get_head();
-    res &= prev_rs_z == mctx->get_rs_z();
+    res &= this->head == mctx->get_head();
+    res &= this->rs_z == mctx->get_rs_z();
 
     return res;
 }
@@ -478,9 +481,6 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx);
 
-    const auto prev_head = this->mctx->get_recr()->get_head();
-    const auto prev_rs_z = this->mctx->get_recr()->get_rs_z();
-
     this->mctx = mctx;
 
     bool res = true;
@@ -496,8 +496,8 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
-    res &= prev_head == mctx->get_recr()->get_head();
-    res &= prev_rs_z == mctx->get_recr()->get_rs_z();
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
 
     return res;
 }
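
To make the reuse-check pattern above concrete, here is a minimal, self-contained C++ sketch. The types mem_recurrent_ctx and graph_input_rs are simplified stubs invented for illustration, not the real llama.cpp classes; only the pattern itself (snapshot head and rs_z at construction, compare the snapshot in can_reuse()) mirrors the commit.

// A minimal sketch of the snapshot-then-compare reuse check.
// mem_recurrent_ctx and graph_input_rs are illustrative stand-ins for
// llama_memory_recurrent_context and llm_graph_input_rs, not the real types.
#include <cstdint>
#include <cstdio>

struct mem_recurrent_ctx {
    uint32_t head; // first recurrent-state cell used by the current ubatch
    int32_t  rs_z; // index of the state slot that gets zeroed
    uint32_t get_head() const { return head; }
    int32_t  get_rs_z() const { return rs_z; }
};

class graph_input_rs {
public:
    explicit graph_input_rs(const mem_recurrent_ctx * mctx) :
        mctx(mctx),
        head(mctx->get_head()), // snapshot taken once, when the graph is built
        rs_z(mctx->get_rs_z()) {
    }

    bool can_reuse(const mem_recurrent_ctx * mctx_new) {
        mctx = mctx_new;

        bool res = true;

        // the cached graph stays valid only if the new context places the
        // states exactly where they were when the graph was built
        res &= head == mctx_new->get_head();
        res &= rs_z == mctx_new->get_rs_z();

        return res;
    }

private:
    const mem_recurrent_ctx * mctx;

    // need to match for valid graph reuse
    const uint32_t head;
    const int32_t  rs_z;
};

int main() {
    mem_recurrent_ctx c0 = { /*head =*/ 0, /*rs_z =*/ -1 };
    graph_input_rs inp(&c0);

    mem_recurrent_ctx c1 = { 0, -1 }; // same placement -> reuse is fine
    mem_recurrent_ctx c2 = { 4, -1 }; // head has moved -> must rebuild
    printf("reuse c1: %d\n", inp.can_reuse(&c1)); // prints 1
    printf("reuse c2: %d\n", inp.can_reuse(&c2)); // prints 0
    return 0;
}

Compared with the previous approach of reading prev_head and prev_rs_z from this->mctx inside can_reuse(), the snapshot is fixed at graph-build time, so it cannot drift as this->mctx gets reassigned on each reuse check.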

src/llama-graph.h

Lines changed: 5 additions & 1 deletion
@@ -219,7 +219,7 @@ class llm_graph_input_cls : public llm_graph_input_i {
 
 class llm_graph_input_rs : public llm_graph_input_i {
 public:
-    llm_graph_input_rs(const llama_memory_recurrent_context * mctx) : mctx(mctx) {}
+    llm_graph_input_rs(const llama_memory_recurrent_context * mctx);
     virtual ~llm_graph_input_rs() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -234,6 +234,10 @@ class llm_graph_input_rs : public llm_graph_input_i {
     ggml_tensor * s_copy_extra; // I32 [n_rs - n_seqs]
 
     const llama_memory_recurrent_context * mctx;
+
+    // need to match for valid graph reuse
+    const uint32_t head;
+    const int32_t rs_z;
 };
 
 class llm_graph_input_cross_embd : public llm_graph_input_i {
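
Two details of the header change are worth noting. First, head and rs_z are const, so the snapshot taken at construction cannot change over the object's lifetime, while the mctx pointer itself stays mutable because can_reuse() reassigns it. Second, the constructor is now declared out-of-line, presumably because its definition (in llama-graph.cpp) calls getters of llama_memory_recurrent_context and therefore needs the complete type rather than a forward declaration.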
