@@ -263,6 +263,9 @@ bool llm_graph_input_rs::can_reuse(const llm_graph_params & params) {
     res &= s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= s_copy_extra->ne[0] == mctx->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= head == mctx->get_head();
+    res &= rs_z == mctx->get_rs_z();
+
     return res;
 }
 
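For the two new comparisons to compile, `llm_graph_input_rs` presumably gains matching members that snapshot the recurrent-memory state at graph-build time; that header change is not part of this excerpt. A minimal sketch of the struct side, with field types assumed from the `get_head()`/`get_rs_z()` call sites:

```cpp
#include <cstdint>

// Hypothetical excerpt of llm_graph_input_rs (llm-graph.h); the real struct
// has more members. head and rs_z are snapshots taken when the graph inputs
// are built, so can_reuse() can detect that the live memory context moved on.
struct llm_graph_input_rs_excerpt {
    // ... existing members: s_copy, s_copy_main, s_copy_extra, mctx ...

    uint32_t head = 0;  // value of mctx->get_head() at build time (assumed type)
    int32_t  rs_z = -1; // value of mctx->get_rs_z() at build time (assumed type)
};
```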
@@ -487,6 +490,9 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
     res &= inp_rs->s_copy_main->ne[0]  == params.ubatch.n_seqs;
     res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs;
 
+    res &= inp_rs->head == mctx->get_recr()->get_head();
+    res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z();
+
     return res;
 }
 
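Both `can_reuse()` overrides enforce the same invariant, once directly and once through the hybrid context's recurrent half. One way to read the shared check (a hypothetical helper, assuming the llama.cpp headers; it is not part of the patch):

```cpp
// The snapshot stored in the input must match the recurrent context that
// will actually execute the graph. llm_graph_input_rs::can_reuse() compares
// against its own mctx; llm_graph_input_mem_hybrid::can_reuse() compares
// against mctx->get_recr().
static bool rs_snapshot_matches(const llm_graph_input_rs & inp,
                                const llama_memory_recurrent_context & mctx) {
    return inp.head == mctx.get_head() &&
           inp.rs_z == mctx.get_rs_z();
}
```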
@@ -1827,6 +1833,9 @@ static std::unique_ptr<llm_graph_input_rs> build_rs_inp_impl(
     inp->s_copy_main  = ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0);
     inp->s_copy_extra = ggml_view_1d(ctx0, inp->s_copy, n_rs - n_seqs, n_seqs * inp->s_copy->nb[0]);
 
+    inp->head = mctx_cur->get_head();
+    inp->rs_z = mctx_cur->get_rs_z();
+
     return inp;
 }
 
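Recording the snapshot in `build_rs_inp_impl()` is the other half of the change: `head` and `rs_z` feed into the views and offsets baked into the graph at build time, so a cached graph only stays valid while the live context still reports the same values. The pattern in miniature, under hypothetical names:

```cpp
#include <cstdint>

// Build-time snapshot, reuse-time comparison: all names here are
// illustrative, not from the patch.
struct cached_graph {
    uint32_t head_at_build; // state that determined view offsets at build time
    int32_t  rs_z_at_build;

    bool can_reuse(uint32_t head_now, int32_t rs_z_now) const {
        // Tensor shapes may still match, but if either value moved, the
        // baked-in offsets are stale and the graph must be rebuilt.
        return head_at_build == head_now && rs_z_at_build == rs_z_now;
    }
};
```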
@@ -1895,7 +1904,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
 llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const {
     const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx);
 
-    auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
+    auto inp_rs   = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr());
     auto inp_attn = build_attn_inp_kv_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn());
 
     auto inp = std::make_unique<llm_graph_input_mem_hybrid>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur);