@@ -235,6 +235,12 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
235235 }
236236}
237237
238+ llm_graph_input_rs::llm_graph_input_rs (const llama_memory_recurrent_context * mctx) : // binds the input to a recurrent memory context and snapshots two of its values
239+ mctx(mctx), // in these initializers `mctx` names the constructor parameter, not the member being initialized
240+ head(mctx->get_head ()), // value of get_head() at construction time; can_reuse() compares it with the incoming context's get_head()
241+ rs_z(mctx->get_rs_z ()) { // value of get_rs_z() at construction time; can_reuse() compares it with the incoming context's get_rs_z()
242+ } // note: mctx is dereferenced unconditionally above, so callers must pass a non-null context
243+
238244void llm_graph_input_rs::set_input (const llama_ubatch * ubatch) {
239245 GGML_UNUSED (ubatch);
240246
@@ -254,9 +260,6 @@ void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) {
254260bool llm_graph_input_rs::can_reuse (const llm_graph_params & params) {
255261 const auto * mctx = static_cast <const llama_memory_recurrent_context *>(params.mctx );
256262
257- const auto prev_head = this ->mctx ->get_head ();
258- const auto prev_rs_z = this ->mctx ->get_rs_z ();
259-
260263 this ->mctx = mctx;
261264
262265 bool res = true ;
@@ -266,8 +269,8 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
266269 res &= s_copy_main->ne [0 ] == params.ubatch .n_seqs ;
267270 res &= s_copy_extra->ne [0 ] == mctx->get_n_rs () - params.ubatch .n_seqs ;
268271
269- res &= prev_head == mctx->get_head ();
270- res &= prev_rs_z == mctx->get_rs_z ();
272+ res &= this -> head == mctx->get_head ();
273+ res &= this -> rs_z == mctx->get_rs_z ();
271274
272275 return res;
273276}
@@ -478,9 +481,6 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
478481bool llm_graph_input_mem_hybrid::can_reuse (const llm_graph_params & params) {
479482 const auto * mctx = static_cast <const llama_memory_hybrid_context *>(params.mctx );
480483
481- const auto prev_head = this ->mctx ->get_recr ()->get_head ();
482- const auto prev_rs_z = this ->mctx ->get_recr ()->get_rs_z ();
483-
484484 this ->mctx = mctx;
485485
486486 bool res = true ;
@@ -496,8 +496,8 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
496496 res &= inp_rs->s_copy_main ->ne [0 ] == params.ubatch .n_seqs ;
497497 res &= inp_rs->s_copy_extra ->ne [0 ] == mctx->get_recr ()->get_n_rs () - params.ubatch .n_seqs ;
498498
499- res &= prev_head == mctx->get_recr ()->get_head ();
500- res &= prev_rs_z == mctx->get_recr ()->get_rs_z ();
499+ res &= inp_rs-> head == mctx->get_recr ()->get_head ();
500+ res &= inp_rs-> rs_z == mctx->get_recr ()->get_rs_z ();
501501
502502 return res;
503503}
0 commit comments