@@ -120,7 +120,6 @@ int main(int argc, char ** argv) {
         }
     }
 
-
     // Tokenize the prompt
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
@@ -139,18 +138,6 @@ int main(int argc, char ** argv) {
         LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
     }
 
-    const int n_input = inp.size();
-
-    const auto t_enc_start = ggml_time_us();
-
-    // eval the prompt
-    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
-
-    // note: keep the last token separate!
-    llama_token id_last = inp.back();
-
-    int n_past = inp.size() - 1;
-
     // how many tokens to draft each time
     int n_draft = params.n_draft;
 
@@ -161,9 +148,25 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const int n_input = inp.size();
+
+    const auto t_enc_start = ggml_time_us();
+
     // target model sampling context
     struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
 
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    int n_past = inp.size() - 1;
+
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
@@ -174,6 +177,13 @@ int main(int argc, char ** argv) {
     struct common_speculative * spec = common_speculative_init(params_spec);
 
     // feed the prompt to the speculator
+    //
+    // this has to be kept synchronized with the target context
+    //
+    // TODO: simplify this by moving the context management logic in the common_speculative instance
+    //       for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
+    //       speculator will automatically compute any new tokens that are not present in its context
+    //
     common_speculative_set_prompt(spec, inp.data(), n_input - 1);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
@@ -188,23 +198,41 @@ int main(int argc, char ** argv) {
         common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
 
         // optionally, append draft tokens to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
         common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
 
-        // evaluate the target model on the drafted tokens
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
             // LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
 
             llama_decode(ctx_tgt, batch_tgt);
         }
 
-        // process the full target batch and return the accepted token based on the target sampler
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
         const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
 
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
         n_past    += ids.size();
         n_drafted += batch_tgt.n_tokens - 1;
         n_accept  += ids.size() - 1;
 
         // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
         {
             llama_token id;
             std::string token_str;
@@ -232,7 +260,7 @@ int main(int argc, char ** argv) {
                 break;
             }
 
-            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+            LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
 
             {
                 LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@@ -241,6 +269,7 @@ int main(int argc, char ** argv) {
                 llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
             }
 
+            // remember the last accepted token for the next iteration
            id_last = id;
         }
     }
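
For reference, the decoding loop that this patch sets up boils down to roughly the condensed sketch below. It only uses the helpers visible in the hunks above; the batch-clearing call, the include names, the wrapper function, and the simplified termination and cleanup are assumptions on my part, and the eos/n_predict bookkeeping, logging, and timing from the real example are omitted.

// Condensed sketch of the loop introduced above. Assumes llama.cpp's common helpers from this
// patch (common.h, sampling.h, speculative.h) and an already initialized target context.
// The wrapper function name and the simplified termination logic are illustrative only.
#include "common.h"
#include "sampling.h"
#include "speculative.h"

#include <vector>

static void run_speculative_loop(llama_context * ctx_tgt, llama_model * model_tgt,
                                 common_params & params, std::vector<llama_token> & inp) {
    const int n_input = inp.size();

    // evaluate the prompt, keeping the last token out of the batch
    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));

    llama_token id_last = inp.back();
    int         n_past  = n_input - 1;

    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);

    struct common_speculative_params params_spec;
    params_spec.n_draft = params.n_draft;

    struct common_speculative * spec = common_speculative_init(params_spec);

    // the speculator has to stay in sync with the target context
    common_speculative_set_prompt(spec, inp.data(), n_input - 1);

    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

    while (true) {
        common_batch_clear(batch_tgt); // done in the surrounding code, outside the shown hunks

        // target batch = [id_last, draft0, draft1, ..., draftN-1]
        common_batch_add(batch_tgt, id_last, n_past, { 0 }, true);
        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);

        llama_decode(ctx_tgt, batch_tgt);

        // the drafted prefix that the target sampler agrees with, plus one freshly sampled token
        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);

        n_past += ids.size();

        // drop any drafted tokens that were not accepted
        llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);

        id_last = ids.back();

        if (llama_token_is_eog(model_tgt, id_last)) {
            break;
        }
    }

    llama_batch_free(batch_tgt);
}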