@@ -4,21 +4,29 @@
 #include "common.h"
 #include "sampling.h"
 
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct common_speculative {
     struct common_speculative_params params;
 
-    llama_batch batch_dft;
+    llama_batch batch;
 
+    struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_tokens prompt_last;
+    llama_tokens prompt;
 };
 
-struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params    = */ params,
-        /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
-        /* .smpl      = */ nullptr,
+        /* .params = */ params,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .prompt = */ {},
     };
 
     // TODO: optimize or pass from outside?
@@ -36,7 +44,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_params params) {
             COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #else
     {
@@ -49,46 +57,104 @@ struct common_speculative * common_speculative_init(struct common_speculative_params params) {
             COMMON_SAMPLER_TYPE_TOP_K,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #endif
 
-    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
-
     return result;
 }
 
 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch_dft);
+    llama_batch_free(spec->batch);
 
     delete spec;
 }
 
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    // note: store the vocab type as the enum, not bool - otherwise two different
+    //       non-zero vocab types would incorrectly compare as equal
+    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
-        const llama_tokens & prompt,
+        const llama_tokens & prompt_tgt,
         llama_token id_last,
         llama_token n_past_tgt) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
 
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
 
-    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
 
-    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+    for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
-        while (i_start + cur < (int) prompt.size() &&
-               i + cur < (int) spec->prompt_last.size() &&
-               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
             cur++;
         }
 
-        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -97,59 +163,59 @@ void common_speculative_add_draft(
     LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
 
     if (reuse_n == 0) {
-        llama_kv_cache_clear(spec->params.ctx_dft);
+        llama_kv_cache_clear(ctx);
 
-        spec->prompt_last.clear();
+        prompt.clear();
     } else {
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
-        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }
 
-    common_batch_clear(spec->batch_dft);
+    common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
-        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
-        spec->prompt_last.push_back(prompt[i]);
+        prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt.size() - i_start;
+    const llama_pos n_past = prompt_tgt.size() - i_start;
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    if (spec->batch_dft.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+    if (batch.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
     }
 
-    common_batch_clear(spec->batch_dft);
-    common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);
 
-    spec->prompt_last.push_back(id_last);
+    prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    llama_decode(ctx, batch);
 
-    common_sampler_reset(spec->smpl);
+    common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
-        common_batch_clear(spec->batch_dft);
+        common_batch_clear(batch);
 
-        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true);
 
-        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+        const auto * cur_p = common_sampler_get_candidates(smpl);
 
         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+                k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
         }
 
         // add drafted token for each sequence
@@ -160,20 +226,20 @@ void common_speculative_add_draft(
             break;
         }
 
-        common_sampler_accept(spec->smpl, id, true);
+        common_sampler_accept(smpl, id, true);
 
         common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
         }
 
-        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
 
-        spec->prompt_last.push_back(id);
+        prompt.push_back(id);
     }
 
     // don't waste time on small batches
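
For reference, a minimal sketch of how a caller might drive the refactored API. This is not part of the patch: the speculative.h include, the run_draft_step helper, and the parameter values are illustrative assumptions; only the four common_speculative_* calls and their signatures come from the diff above.

#include "common.h"
#include "sampling.h"
#include "speculative.h" // assumed name of the header declaring this API

// Hypothetical single drafting step: the draft context is now passed to
// common_speculative_init() directly instead of riding inside the params struct.
static void run_draft_step(
        struct llama_context * ctx_tgt,
        struct llama_context * ctx_dft,
        struct llama_batch   & batch_tgt,
        const llama_tokens   & prompt_tgt,
        llama_token id_last,
        llama_token n_past_tgt) {
    // refuse to speculate if the draft vocab diverges too much from the target
    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        return;
    }

    struct common_speculative_params params;
    params.n_draft = 16;  // illustrative: max tokens to draft per step
    params.n_reuse = 256; // illustrative: min prefix match for draft KV-cache reuse

    struct common_speculative * spec = common_speculative_init(params, ctx_dft);

    // appends up to n_draft sampled draft tokens to batch_tgt for the target to verify
    common_speculative_add_draft(spec, batch_tgt, prompt_tgt, id_last, n_past_tgt);

    common_speculative_free(spec);
}

In real use the common_speculative instance would be created once and kept alive across decoding steps, so the prompt/KV-cache reuse path in common_speculative_add_draft can actually trigger.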