 #include "common.h"
 #include "sampling.h"

+#include <cstring>
+
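+// tolerances for the vocab compatibility check below: the draft and target vocabs
+// may differ slightly in size (e.g. due to padding tokens), and the per-token text
+// comparison starts past the first few IDs, which are typically model-specific special tokens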
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct common_speculative {
     struct common_speculative_params params;

-    llama_batch batch_dft;
+    llama_batch batch;

+    struct llama_context * ctx;
     struct common_sampler * smpl;

-    llama_tokens prompt_last;
+    llama_tokens prompt;
 };

-struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params    = */ params,
-        /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
-        /* .smpl      = */ nullptr,
+        /* .params = */ params,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .prompt = */ {},
     };

     // TODO: optimize or pass from outside?
@@ -36,7 +46,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_INFILL,
         };

-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #else
     {
@@ -49,46 +59,104 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_TOP_K,
         };

-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #endif

-    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
-
     return result;
 }

 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);

-    llama_batch_free(spec->batch_dft);
+    llama_batch_free(spec->batch);

     delete spec;
 }

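+// verify that the draft model is compatible enough with the target model for speculation:
+// same vocab type, matching BOS/EOS handling, vocab sizes within the allowed difference,
+// and identical token texts over the shared ID range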
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    const enum llama_vocab_type vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const enum llama_vocab_type vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
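+        // compare vocab sizes, then the token texts over the ID range both vocabs share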
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                        "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
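+// sample up to params.n_draft tokens from the draft model and append them to batch_tgt,
+// reusing as much of the previous draft context (and its KV cache) as possible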
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
-        const llama_tokens & prompt,
+        const llama_tokens & prompt_tgt,
         llama_token id_last,
         llama_token n_past_tgt) {
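+    // convenience aliases for the speculative state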
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;

     int reuse_i = 0;
     int reuse_n = 0;

-    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;

-    const int i_start = std::max<int>(0, (int) prompt.size() - n_ctx);
+    const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);

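+    // find the longest run of the cached draft prompt that matches the tail of the
+    // target prompt - this part of the draft KV cache can be reused as-is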
-    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+    for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
-        while (i_start + cur < (int) prompt.size() &&
-               i + cur < (int) spec->prompt_last.size() &&
-               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
             cur++;
         }

-        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= spec->params.n_reuse || (int) prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -97,59 +165,59 @@ void common_speculative_add_draft(
     LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);

     if (reuse_n == 0) {
-        llama_kv_cache_clear(spec->params.ctx_dft);
+        llama_kv_cache_clear(ctx);

-        spec->prompt_last.clear();
+        prompt.clear();
     } else {
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
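+        // drop everything outside the reused span from the draft KV cache and
+        // shift the span to position 0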
+        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);

-        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
-        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }

-    common_batch_clear(spec->batch_dft);
+    common_batch_clear(batch);

-    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
-        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
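+    // queue the non-reused tail of the target prompt for evaluation by the draft model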
+    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);

-        spec->prompt_last.push_back(prompt[i]);
+        prompt.push_back(prompt_tgt[i]);
     }

-    const llama_pos n_past = prompt.size() - i_start;
+    const llama_pos n_past = prompt_tgt.size() - i_start;

     LOG_DBG("%s: n_past = %d\n", __func__, n_past);

-    if (spec->batch_dft.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+    if (batch.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());

-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
     }

-    common_batch_clear(spec->batch_dft);
-    common_batch_add (spec->batch_dft, id_last, n_past, { 0 }, true);
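+    // evaluate the last accepted target token to obtain the draft model's distribution
+    // for the first drafted token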
+    common_batch_clear(batch);
+    common_batch_add (batch, id_last, n_past, { 0 }, true);

-    spec->prompt_last.push_back(id_last);
+    prompt.push_back(id_last);

-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());

-    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    llama_decode(ctx, batch);

-    common_sampler_reset(spec->smpl);
+    common_sampler_reset(smpl);

     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
-        common_batch_clear(spec->batch_dft);
+        common_batch_clear(batch);

-        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true);

-        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+        const auto * cur_p = common_sampler_get_candidates(smpl);

         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
         }

         // add drafted token for each sequence
@@ -160,20 +228,20 @@ void common_speculative_add_draft(
             break;
         }

-        common_sampler_accept(spec->smpl, id, true);
+        common_sampler_accept(smpl, id, true);

         common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);

         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
         }

-        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);

         // evaluate the drafted tokens on the draft model
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);

-        spec->prompt_last.push_back(id);
+        prompt.push_back(id);
     }

     // don't waste time on small batches