# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
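"""Speculative decoding example built on QEfficient: a draft language model (DLM) greedily
proposes tokens and a target language model (TLM) verifies them in a single precode pass,
using QAICInferenceSession with continuous batching across prompts."""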

from typing import List, Optional

import numpy as np
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession


def run_prefill_on_draft_and_target(
    tlm_session: QAICInferenceSession,
    dlm_session: QAICInferenceSession,
    prompt: dict,
    prompt_len: int,
    ctx_len: int,
    prefill_batch_size: int,
    decode_batch_size: int,
    slot_idx: int,
):
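    """Run chunked prefill of a single padded prompt on both the target (TLM) and draft (DLM)
    sessions, filling the KV cache of continuous-batching slot `slot_idx`, and return the
    decode-start inputs (first greedy token, start position id, logits) for each model."""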
    inputs = prompt
    input_len = prompt.input_ids.shape[1]
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len = num_chunks * prompt_len  # round input_len up to a multiple of prompt_len
    assert input_len <= ctx_len, "input_len should be at most ctx_len"
    # the prompt tokens are already padded to input_len by the caller's tokenizer call
    # TODO: store the attention mask and position ids for each batch element so that
    # they can be accessed at decode time
    inputs["attention_mask"] = np.concatenate(
        [inputs["attention_mask"].astype(bool) for _ in range(decode_batch_size)], 0
    )
    inputs["position_ids"] = (np.cumsum(inputs["attention_mask"][0:1], 1) - 1) * inputs["attention_mask"][0:1]
    # mark padded positions with -1
    inputs["position_ids"][~inputs["attention_mask"][0:1]] = -1
    cache_index = np.array([[0]], np.int64)
    batch_index = np.array([[slot_idx]], np.int64)
    inputs["batch_index"] = batch_index

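    # The models are compiled with a fixed prefill sequence length (prompt_len), so the padded
    # prompt is fed in num_chunks slices; both sessions consume identical chunks so their KV
    # caches for this slot stay aligned.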
    # Run chunked prefill
    for _ in range(num_chunks):
        chunk_inputs = inputs.copy()
        chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]
        chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]
        chunk_inputs.pop("attention_mask")
        tlm_outputs = tlm_session.run(chunk_inputs)
        dlm_outputs = dlm_session.run(chunk_inputs)
        cache_index += prompt_len

    tlm_logits = tlm_outputs["logits"]
    dlm_logits = dlm_outputs["logits"]

    # prefill emits one logit vector per sequence; restore the sequence dim if it was squeezed
    if len(tlm_logits.shape) == 2:
        tlm_logits = np.expand_dims(tlm_logits, 1)
    if len(dlm_logits.shape) == 2:
        dlm_logits = np.expand_dims(dlm_logits, 1)

    # the first decode position equals the true (unpadded) prompt length
    tlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)
    tlm_decode_start_input_id = tlm_logits.argmax(2)
    dlm_decode_start_input_id = dlm_logits.argmax(2)
    dlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)

    inputs.pop("attention_mask")

    tlm_decode_start_input = {
        "logits": tlm_logits,
        "input_ids": tlm_decode_start_input_id,
        "position_ids": tlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": tlm_decode_start_pos_id[0, 0],
    }
    dlm_decode_start_input = {
        "logits": dlm_logits,
        "input_ids": dlm_decode_start_input_id,
        "position_ids": dlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": dlm_decode_start_pos_id[0, 0],
    }

    return tlm_decode_start_input, dlm_decode_start_input


def get_padded_input_len(input_len: int, prompt_len: int, ctx_len: int):
    """Return the padded input length (a multiple of `prompt_len`).

    Args:
        input_len (int): prompt length
        prompt_len (int): prefill sequence length
        ctx_len (int): context length

    Returns:
        input_len_padded (int): padded input length
    """
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len_padded = num_chunks * prompt_len  # convert input_len to a multiple of prompt_len
    assert input_len_padded <= ctx_len, "input_len rounded to nearest prompt_len multiple should be less than ctx_len"
    return input_len_padded


def populate_inputs(source, dest, index=None):
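    """Copy decode inputs from `source` into `dest`, skipping "batch_index": for the whole
    batch when `index` is None (decode time), or into a single batch slot `index`
    (after a bs=1 prefill)."""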
    for k in dest:
        if k == "batch_index":
            continue
        if index is None:
            # during decode
            dest[k] = source[k]
        else:
            # during prefill with bs=1
            dest[k][index] = source[k]


def split_dlm_bonus_token_inputs(dlm_decode_inputs):
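    """Split off the first (bonus-token) column of the DLM decode inputs so it can be run as
    a separate single-token call; returns (bonus_token_inputs, remaining dlm_decode_inputs)."""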
    bonus_token_inputs = dict()
    bonus_token_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:, 0:1]
    bonus_token_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, 0:1]
    dlm_decode_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:, 1:]
    dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, 1:]
    return bonus_token_inputs, dlm_decode_inputs


def test_spec_decode_inference(
    prompt: List[str],
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    draft_model_name: str,
    target_model_name: str,
    full_batch_size: Optional[int] = None,
):
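    """End-to-end speculative decoding test: compile target (TLM) and draft (DLM) models,
    prefill each prompt into its continuous-batching slot, then alternate DLM speculation and
    a single TLM precode pass per iteration, accepting the longest matching prefix plus a
    bonus/corrected token, until every sequence reaches its generation budget."""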
    # assumes the DLM and TLM are compiled with the same prompt chunk size, context length,
    # and full_batch_size/batch_size
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    vocab_size = len(tokenizer)

    # export and compile the TLM and DLM
    target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=True, is_tlm=True)
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=True)

    num_devices = len(device_group)
    target_model_qpc_path: str = target_model.compile(
        num_cores=11,
        num_devices=num_devices,
        prefill_seq_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6_matmul=True,
        aic_enable_depth_first=True,
        full_batch_size=full_batch_size,
        num_speculative_tokens=num_speculative_tokens,
    )
    draft_model_qpc_path: str = draft_model.compile(
        is_dlm=False,
        num_cores=5,
        prefill_seq_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6_matmul=True,
        aic_enable_depth_first=True,
        full_batch_size=full_batch_size,
    )

    # init QAIC inference sessions
    # NOTE: device ids are hard-coded here and are independent of `device_group`
    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=[2])
    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=[3])

    # skip inputs/outputs buffers
    target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
    target_model_session.skip_buffers(
        set([x for x in target_model_session.output_names if x.endswith("_RetainedState")])
    )
    draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")]))
    draft_model_session.skip_buffers(set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")]))

    is_cb = full_batch_size is not None
    if not is_cb:
        prompts = prompt * prefill_bsz
        decode_batch_size = prefill_bsz
    else:
        prompts = prompt
        decode_batch_size = full_batch_size
    # tokenize the prompts
    prompts_tokenized: List[dict] = []
    for p in prompts:
        input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1]
        input_len_padded: int = get_padded_input_len(input_len, prompt_len, ctx_len)
        p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded)
        prompts_tokenized.append(p_tok)
    # create caches to hold generated ids and input prompt lengths
    generated_ids = [[] for _ in range(decode_batch_size)]
    input_lengths = [0] * decode_batch_size
    # run prefill on both draft and target models
    dlm_decode_inputs = dict()
    dlm_decode_inputs["position_ids"] = np.zeros((decode_batch_size, 1), np.int64)
    dlm_decode_inputs["input_ids"] = np.full((decode_batch_size, 1), tokenizer.pad_token_id)
    dlm_decode_inputs["batch_index"] = np.reshape(
        np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1)
    )
    # mock input key "logits" to store the first batch of output logits
    dlm_decode_inputs["logits"] = np.full((decode_batch_size, 1, vocab_size), 0)
    tlm_precode_inputs = dict(dlm_decode_inputs)
    is_prefill = True
    generation_done = False
    max_gen_len = [ctx_len] * decode_batch_size
    num_logits_to_keep = num_speculative_tokens + 1
    all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool)
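    # pre-allocated logits buffers ("_ph" = placeholder) bound to the sessions via set_buffers
    # below; prefill/decode produce 1 logit per sequence while precode produces
    # num_logits_to_keep logits per sequence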
    tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
    dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
    decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32)
    precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32)

    target_model_session.set_buffers({"logits": tlm_prefill_logits_ph})
    draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph})
    for bi in range(decode_batch_size):
        # assumes that the prefill queue will always be popped from the front
        tlm_prefill_output, dlm_prefill_output = run_prefill_on_draft_and_target(
            tlm_session=target_model_session,
            dlm_session=draft_model_session,
            prompt=prompts_tokenized[bi],
            prompt_len=prompt_len,
            ctx_len=ctx_len,
            prefill_batch_size=prefill_bsz,
            decode_batch_size=decode_batch_size,
            slot_idx=bi,
        )
        # this way, we directly get the updated full-batch input dict to run decode
        populate_inputs(dlm_prefill_output, dlm_decode_inputs, bi)
        populate_inputs(tlm_prefill_output, tlm_precode_inputs, bi)
        input_lengths[bi] = tlm_prefill_output["input_len"]
        max_gen_len[bi] -= input_lengths[bi]

    # precode and decode produce differently shaped logits, so rebind the output buffers
    target_model_session.set_buffers({"logits": precode_logits_ph})
    draft_model_session.set_buffers({"logits": decode_logits_ph})
    dlm_run_bonus_token = False
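    # Main generation loop: each iteration the DLM greedily proposes num_speculative_tokens
    # tokens, the TLM scores them in one precode pass, the longest prefix matching the TLM's
    # own greedy choices is accepted, and one extra TLM token is appended (a bonus token on
    # full acceptance, otherwise the corrected token).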
    while not generation_done:
        # compute the processed context length before each iteration to prepare the position-id inputs
        processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)]
        # generate proposals from the draft model
        if is_prefill:
            draft_logits = [dlm_decode_inputs.pop("logits")]
            target_logits = [tlm_precode_inputs.pop("logits")]
        else:
            # NOTE: if even one batch element has all_accept, the seqlen=2 specialization must be
            # used, hence the dummy -1 position id for the other sequences
            if np.any(all_accept):
                input_ids = []
                position_ids = []
                dlm_run_bonus_token = True
                for bi in range(decode_batch_size):
                    if all_accept[bi]:
                        # both the last DLM token and the bonus TLM token are passed to the DLM
                        input_ids.append([generated_ids[bi][-2], generated_ids[bi][-1]])
                        position_ids.append([processed_context[bi] - 2, processed_context[bi] - 1])
                    else:
                        # only the corrected TLM token from the previous iteration, plus the pad token as a dummy
                        input_ids.append([generated_ids[bi][-1], tokenizer.pad_token_id])
                        position_ids.append([processed_context[bi] - 1, -1])
                dlm_decode_inputs["input_ids"] = np.array(input_ids)
                dlm_decode_inputs["position_ids"] = np.array(position_ids)
            else:
                dlm_decode_inputs["input_ids"] = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(
                    (decode_batch_size, 1)
                )
                dlm_decode_inputs["position_ids"] = np.array(
                    [(pc - 1) for pc in processed_context], dtype=np.int64
                ).reshape((decode_batch_size, 1))
        for _ in range(num_speculative_tokens):
            if dlm_run_bonus_token:
                # run decode one extra time in the first speculative iteration:
                # workaround to avoid incorrect precode with the 3-specialized multi-batch DLM
                bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs)
                dlm_outputs = draft_model_session.run(bonus_token_inputs)
                dlm_run_bonus_token = False
            dlm_outputs = draft_model_session.run(dlm_decode_inputs)
            draft_logits.append(dlm_outputs["logits"])
            dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1)
            dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, -1:] + 1

        # stack the per-step draft logits into (decode_batch_size, num_proposals, vocab_size)
        draft_logits = np.array(draft_logits).squeeze(2).transpose((1, 0, 2))
        # greedy sampling from the draft model
        draft_tokens = draft_logits.argmax(-1)

        # construct precode inputs
        tlm_precode_inputs["input_ids"] = draft_tokens
        if not is_prefill:
            # for general precode, the first input token is the last generated TLM token (KV-cache backfill)
            last_genid = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(decode_batch_size, 1)
            tlm_precode_inputs["input_ids"] = np.concatenate((last_genid, tlm_precode_inputs["input_ids"]), axis=1)
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc - 1, stop=pc + num_speculative_tokens) for pc in processed_context], dtype=np.int64
            )
        else:
            # for the first precode right after prefill, all positions are new
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc, stop=pc + num_speculative_tokens + 1) for pc in processed_context], dtype=np.int64
            )

        # run precode on the TLM to score the proposed tokens
        tlm_outputs = target_model_session.run(tlm_precode_inputs)
        target_precode_logits = tlm_outputs["logits"]
        if is_prefill:
            # stack the prefill output logits and the precode logits into a single tensor
            target_logits = np.concatenate((target_logits[0], target_precode_logits), axis=1)
        else:
            target_logits = target_precode_logits
        # greedy sampling from the target model
        target_tokens = target_logits.argmax(-1)
        # exact matching between draft and target tokens
        matching = draft_tokens == target_tokens[:, :-1]
        # np.argmin returns 0 both for a mismatch at index 0 and for an all-True row,
        # so check the value at that index to detect full acceptance
        num_tokens_selected = np.argmin(matching, axis=1)
        all_accept = matching[np.arange(decode_batch_size), num_tokens_selected]
        num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected)

        # append accepted draft tokens to generated_ids, clipped to each sequence's budget
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            num_tokens_to_append = min(num_tokens_selected[bi], max_gen_len[bi] - len(generated_ids[bi]))
            generated_ids[bi] += list(draft_tokens[bi, :num_tokens_to_append])
        # append the bonus/corrected token where applicable
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            if all_accept[bi]:
                # bonus token
                generated_ids[bi].append(target_tokens[bi, -1])
            else:
                # corrected token
                generated_ids[bi].append(target_tokens[bi, num_tokens_selected[bi]])
        # generation is done once every sequence has reached its budget
        generation_done = True
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) < max_gen_len[bi]:
                generation_done = False
        is_prefill = False
        draft_logits = []
        target_logits = []
    print("max generation len = ", max_gen_len)
    print("actual generation len = ", [len(gid) for gid in generated_ids])
    print(tokenizer.batch_decode(generated_ids))


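# Example run: four short prompts with continuous batching (full_batch_size=4), 5 speculative
# tokens per step, prompt_len=32, ctx_len=128, and JackFram/llama-68m as both draft and target.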
test_spec_decode_inference(
    prompt=["My name is", "Hello", "Hi", "My name is"],
    device_group=[0],
    num_speculative_tokens=5,
    prompt_len=32,
    ctx_len=128,
    prefill_bsz=1,
    draft_model_name="JackFram/llama-68m",
    target_model_name="JackFram/llama-68m",
    full_batch_size=4,
)