@@ -2251,6 +2251,8 @@ struct server_context {
22512251
22522252 id = common_sampler_sample (slot.smpl , ctx, slot.i_batch - i);
22532253
2254+ slot.i_batch = -1 ;
2255+
22542256 common_sampler_accept (slot.smpl , id, true );
22552257
22562258 slot.n_decoded += 1 ;
@@ -2277,73 +2279,64 @@ struct server_context {
22772279 slot.print_timings ();
22782280 send_final_response (slot);
22792281 metrics.on_prediction (slot);
2282+ continue ;
22802283 }
22812284 }
22822285
2283- slot.i_batch = -1 ;
2284-
2285- if (slot.ctx_dft ) {
2286- struct common_speculative_params params_spec;
2287- params_spec.n_draft = params.n_draft ;
2288- params_spec.n_reuse = 256 ;
2289- params_spec.p_min = 0 .9f ;
2290-
2291- llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
2286+ // check if the slot supports speculative decoding
2287+ if (!slot.ctx_dft ) {
2288+ continue ;
2289+ }
22922290
2293- if (draft.size () > params.n_draft_min ) {
2294- common_batch_clear (slot.batch_spec );
2295- common_batch_add (slot.batch_spec , id, slot.n_past ++, { slot.id }, true );
2291+ // TODO: configurable through requests
2292+ struct common_speculative_params params_spec;
2293+ params_spec.n_draft = params.n_draft ;
2294+ params_spec.n_reuse = 256 ;
2295+ params_spec.p_min = 0 .9f ;
22962296
2297- for (size_t i = 0 ; i < draft.size (); ++i) {
2298- common_batch_add (slot.batch_spec , draft[i], slot.n_past + i, { slot.id }, true );
2299- }
2297+ llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
23002298
2301- llama_decode (ctx, slot. batch_spec );
2302-
2303- const auto ids = common_sampler_sample_n (slot. smpl , ctx, draft);
2299+ if (params. n_draft_min > ( int ) draft. size ()) {
2300+ continue ;
2301+ }
23042302
2305- slot.n_past += ids.size () - 1 ;
2303+ // construct the speculation batch
2304+ common_batch_clear (slot.batch_spec );
2305+ common_batch_add (slot.batch_spec , id, slot.n_past , { slot.id }, true );
23062306
2307- slot.cache_tokens .push_back (id);
2307+ for (size_t i = 0 ; i < draft.size (); ++i) {
2308+ common_batch_add (slot.batch_spec , draft[i], slot.n_past + 1 + i, { slot.id }, true );
2309+ }
23082310
2309- for (size_t i = 0 ; i < ids.size (); ++i) {
2310- completion_token_output result;
2311+ llama_decode (ctx, slot.batch_spec );
23112312
2312- id = ids[i];
2313+ // the accepted tokens from the speculation
2314+ const auto ids = common_sampler_sample_n (slot.smpl , ctx, draft);
23132315
2314- common_sampler_accept (slot.smpl , id, true );
2316+ slot.n_past += ids.size ();
2317+ slot.n_decoded += ids.size ();
23152318
2316- slot.n_decoded += 1 ;
2317- if (slot.n_decoded == 1 ) {
2318- slot.t_start_generation = ggml_time_us ();
2319- slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt ) / 1e3 ;
2320- metrics.on_prompt_eval (slot);
2321- }
2319+ slot.cache_tokens .push_back (id);
2320+ slot.cache_tokens .insert (slot.cache_tokens .end (), ids.begin (), ids.end () - 1 );
23222321
2323- result. tok = id ;
2322+ llama_kv_cache_seq_rm (ctx, slot. id , slot. n_past , - 1 ) ;
23242323
2325- const auto * cur_p = common_sampler_get_candidates (slot.smpl );
2324+ for (size_t i = 0 ; i < ids.size (); ++i) {
2325+ completion_token_output result;
23262326
2327- for (size_t i = 0 ; i < (size_t ) slot.sparams .n_probs ; ++i) {
2328- result.probs .push_back ({
2329- cur_p->data [i].id ,
2330- i >= cur_p->size ? 0 .0f : cur_p->data [i].p ,
2331- });
2332- }
2327+ id = ids[i];
23332328
2334- if (!process_token (result, slot)) {
2335- // release slot because of stop condition
2336- slot.release ();
2337- slot.print_timings ();
2338- send_final_response (slot);
2339- metrics.on_prediction (slot);
2340- break ;
2341- }
2342- }
2329+ common_sampler_accept (slot.smpl , id, true );
23432330
2344- llama_kv_cache_seq_rm (ctx, slot. id , slot. n_past , - 1 ) ;
2331+ result. tok = id ;
23452332
2346- slot.cache_tokens .insert (slot.cache_tokens .end (), ids.begin (), ids.end () - 1 );
2333+ if (!process_token (result, slot)) {
2334+ // release slot because of stop condition
2335+ slot.release ();
2336+ slot.print_timings ();
2337+ send_final_response (slot);
2338+ metrics.on_prediction (slot);
2339+ break ;
23472340 }
23482341 }
23492342 }
0 commit comments