@@ -95,93 +95,77 @@ llama_pos llama_kv_cache_unified_iswa::seq_pos_max(llama_seq_id seq_id) const {
9595 return kv_swa->seq_pos_max (seq_id);
9696}
9797
98- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch (llama_batch_allocr & balloc , uint32_t n_ubatch, bool embd_all ) {
99- GGML_UNUSED (embd_all );
98+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_batch (const llama_batch & batch , uint32_t n_ubatch, bool embd_pooled, bool logits_all ) {
99+ GGML_UNUSED (embd_pooled );
100100
101101 // first try simple split
102102 do {
103- balloc. split_reset ( );
103+ auto sbatch = llama_sbatch (batch, hparams. n_embd , true , logits_all );
104104
105105 std::vector<llama_ubatch> ubatches;
106- while (true ) {
107- auto ubatch = balloc.split_simple (n_ubatch);
108106
109- if (ubatch.n_tokens == 0 ) {
110- break ;
111- }
107+ while (sbatch.n_tokens > 0 ) {
108+ auto ubatch = sbatch.split_simple (n_ubatch);
112109
113- ubatches.push_back (std::move ( ubatch)); // NOLINT
110+ ubatches.push_back (ubatch);
114111 }
115112
116- if (balloc. get_n_used () < balloc. get_n_tokens ()) {
117- // failed to find a suitable split
113+ auto heads_base = kv_base-> prepare (ubatches);
114+ if (heads_base. empty ()) {
118115 break ;
119116 }
120117
121- auto sinfos_base = kv_base ->prepare (ubatches);
122- if (sinfos_base .empty ()) {
118+ auto heads_swa = kv_swa ->prepare (ubatches);
119+ if (heads_swa .empty ()) {
123120 break ;
124121 }
125122
126- auto sinfos_swa = kv_swa->prepare (ubatches);
127- if (sinfos_swa.empty ()) {
128- break ;
129- }
123+ assert (heads_base.size () == heads_swa.size ());
130124
131- assert (sinfos_base.size () == sinfos_swa.size ());
132-
133- return std::make_unique<llama_kv_cache_unified_iswa_context>(
134- this , std::move (sinfos_base), std::move (sinfos_swa), std::move (ubatches));
125+ return std::make_unique<llama_kv_cache_unified_iswa_state>(
126+ this , std::move (sbatch), std::move (heads_base), std::move (heads_swa), std::move (ubatches));
135127 } while (false );
136128
137129 // if it fails, try equal split
138130 do {
139- balloc. split_reset ( );
131+ auto sbatch = llama_sbatch (batch, hparams. n_embd , false , logits_all );
140132
141133 std::vector<llama_ubatch> ubatches;
142- while (true ) {
143- auto ubatch = balloc.split_equal (n_ubatch, false );
144134
145- if (ubatch.n_tokens == 0 ) {
146- break ;
147- }
135+ while (sbatch.n_tokens > 0 ) {
136+ auto ubatch = sbatch.split_equal (n_ubatch);
148137
149- ubatches.push_back (std::move ( ubatch)); // NOLINT
138+ ubatches.push_back (ubatch);
150139 }
151140
152- if (balloc. get_n_used () < balloc. get_n_tokens ()) {
153- // failed to find a suitable split
141+ auto heads_base = kv_base-> prepare (ubatches);
142+ if (heads_base. empty ()) {
154143 break ;
155144 }
156145
157- auto sinfos_base = kv_base ->prepare (ubatches);
158- if (sinfos_base .empty ()) {
146+ auto heads_swa = kv_swa ->prepare (ubatches);
147+ if (heads_swa .empty ()) {
159148 break ;
160149 }
161150
162- auto sinfos_swa = kv_swa->prepare (ubatches);
163- if (sinfos_swa.empty ()) {
164- break ;
165- }
166-
167- assert (sinfos_base.size () == sinfos_swa.size ());
151+ assert (heads_base.size () == heads_swa.size ());
168152
169- return std::make_unique<llama_kv_cache_unified_iswa_context >(
170- this , std::move (sinfos_base ), std::move (sinfos_swa ), std::move (ubatches));
153+ return std::make_unique<llama_kv_cache_unified_iswa_state >(
154+ this , std::move (sbatch ), std::move (heads_base), std::move (heads_swa ), std::move (ubatches));
171155 } while (false );
172156
173157 // TODO: if we fail again, we should attempt different splitting strategies
174158 // but to do that properly, we first have to refactor the batches to be more flexible
175159
176- return std::make_unique<llama_kv_cache_unified_iswa_context >(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
160+ return std::make_unique<llama_kv_cache_unified_iswa_state >(LLAMA_MEMORY_STATUS_FAILED_PREPARE);
177161}
178162
179- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_full () {
180- return std::make_unique<llama_kv_cache_unified_iswa_context >(this );
163+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_full () {
164+ return std::make_unique<llama_kv_cache_unified_iswa_state >(this );
181165}
182166
183- llama_memory_context_ptr llama_kv_cache_unified_iswa::init_update (llama_context * lctx, bool optimize) {
184- return std::make_unique<llama_kv_cache_unified_iswa_context >(this , lctx, optimize);
167+ llama_memory_state_ptr llama_kv_cache_unified_iswa::init_update (llama_context * lctx, bool optimize) {
168+ return std::make_unique<llama_kv_cache_unified_iswa_state >(this , lctx, optimize);
185169}
186170
187171bool llama_kv_cache_unified_iswa::get_can_shift () const {
@@ -207,46 +191,52 @@ llama_kv_cache_unified * llama_kv_cache_unified_iswa::get_swa() const {
207191}
208192
209193//
210- // llama_kv_cache_unified_iswa_context
194+ // llama_kv_cache_unified_iswa_state
211195//
212196
213- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context (llama_memory_status status) : status(status) {}
197+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (llama_memory_status status) : status(status) {}
198+
199+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
200+ llama_kv_cache_unified_iswa * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
201+ state_base = kv->get_base ()->init_full ();
202+ state_swa = kv->get_swa ()->init_full ();
214203
215- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context (
216- llama_kv_cache_unified_iswa * kv) :
217- ctx_base(kv->get_base ()->init_full()),
218- ctx_swa (kv->get_swa ()->init_full()),
219- status(llama_memory_status_combine(ctx_base->get_status (), ctx_swa->get_status())) {
204+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
220205}
221206
222- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context (
207+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
223208 llama_kv_cache_unified_iswa * kv,
224209 llama_context * lctx,
225- bool optimize) :
226- ctx_base(kv->get_base ()->init_update(lctx, optimize)),
227- ctx_swa (kv->get_swa ()->init_update(lctx, optimize)),
228- status(llama_memory_status_combine(ctx_base->get_status (), ctx_swa->get_status())) {
210+ bool optimize) : status(LLAMA_MEMORY_STATUS_SUCCESS) {
211+ state_base = kv->get_base ()->init_update (lctx, optimize);
212+ state_swa = kv->get_swa ()->init_update (lctx, optimize);
213+
214+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
229215}
230216
231- llama_kv_cache_unified_iswa_context::llama_kv_cache_unified_iswa_context (
217+ llama_kv_cache_unified_iswa_state::llama_kv_cache_unified_iswa_state (
232218 llama_kv_cache_unified_iswa * kv,
233- slot_info_vec_t sinfos_base,
234- slot_info_vec_t sinfos_swa,
235- std::vector<llama_ubatch> ubatches) :
236- ubatches(std::move(ubatches)),
219+ llama_sbatch sbatch,
220+ std::vector<uint32_t > heads_base,
221+ std::vector<uint32_t > heads_swa,
222+ std::vector<llama_ubatch> ubatches)
223+ : status(LLAMA_MEMORY_STATUS_SUCCESS),
224+ sbatch(std::move(sbatch)),
225+ ubatches(std::move(ubatches)) {
237226 // note: here we copy the ubatches. not sure if this is ideal
238- ctx_base(new llama_kv_cache_unified_context(kv->get_base (), std::move(sinfos_base), this->ubatches)),
239- ctx_swa (new llama_kv_cache_unified_context(kv->get_swa (), std::move(sinfos_swa), this->ubatches)),
240- status(llama_memory_status_combine(ctx_base->get_status (), ctx_swa->get_status())) {
227+ state_base.reset (new llama_kv_cache_unified_state (kv->get_base (), {}, std::move (heads_base), this ->ubatches ));
228+ state_swa .reset (new llama_kv_cache_unified_state (kv->get_swa (), {}, std::move (heads_swa), this ->ubatches ));
229+
230+ status = llama_memory_status_combine (state_base->get_status (), state_swa->get_status ());
241231}
242232
243- llama_kv_cache_unified_iswa_context :: ~llama_kv_cache_unified_iswa_context () = default ;
233+ llama_kv_cache_unified_iswa_state :: ~llama_kv_cache_unified_iswa_state () = default ;
244234
245- bool llama_kv_cache_unified_iswa_context ::next () {
235+ bool llama_kv_cache_unified_iswa_state ::next () {
246236 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
247237
248- ctx_base ->next ();
249- ctx_swa ->next ();
238+ state_base ->next ();
239+ state_swa ->next ();
250240
251241 if (++i_next >= ubatches.size ()) {
252242 return false ;
@@ -255,35 +245,41 @@ bool llama_kv_cache_unified_iswa_context::next() {
255245 return true ;
256246}
257247
258- bool llama_kv_cache_unified_iswa_context ::apply () {
259- assert (! llama_memory_status_is_fail ( status) );
248+ bool llama_kv_cache_unified_iswa_state ::apply () {
249+ assert (status == LLAMA_MEMORY_STATUS_SUCCESS );
260250
261251 bool res = true ;
262252
263- res = res & ctx_base ->apply ();
264- res = res & ctx_swa ->apply ();
253+ res = res & state_base ->apply ();
254+ res = res & state_swa ->apply ();
265255
266256 return res;
267257}
268258
269- llama_memory_status llama_kv_cache_unified_iswa_context::get_status () const {
259+ std::vector<int64_t > & llama_kv_cache_unified_iswa_state::out_ids () {
260+ assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
261+
262+ return sbatch.out_ids ;
263+ }
264+
265+ llama_memory_status llama_kv_cache_unified_iswa_state::get_status () const {
270266 return status;
271267}
272268
273- const llama_ubatch & llama_kv_cache_unified_iswa_context ::get_ubatch () const {
269+ const llama_ubatch & llama_kv_cache_unified_iswa_state ::get_ubatch () const {
274270 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
275271
276272 return ubatches[i_next];
277273}
278274
279- const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context ::get_base () const {
275+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state ::get_base () const {
280276 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
281277
282- return static_cast <const llama_kv_cache_unified_context *>(ctx_base .get ());
278+ return static_cast <const llama_kv_cache_unified_state *>(state_base .get ());
283279}
284280
285- const llama_kv_cache_unified_context * llama_kv_cache_unified_iswa_context ::get_swa () const {
281+ const llama_kv_cache_unified_state * llama_kv_cache_unified_iswa_state ::get_swa () const {
286282 assert (status == LLAMA_MEMORY_STATUS_SUCCESS);
287283
288- return static_cast <const llama_kv_cache_unified_context *>(ctx_swa .get ());
284+ return static_cast <const llama_kv_cache_unified_state *>(state_swa .get ());
289285}
0 commit comments