Commit ac52060

rebasing and adding back test_spd_inference

Signed-off-by: eplatero <[email protected]>
1 parent 7b967e7 commit ac52060

1 file changed: +349 -0 lines changed

# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

from typing import List, Optional

import numpy as np
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM as AutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession


def run_prefill_on_draft_and_target(
    tlm_session: QAICInferenceSession,
    dlm_session: QAICInferenceSession,
    prompt: dict,
    prompt_len: int,
    ctx_len: int,
    prefill_batch_size: int,
    decode_batch_size: int,
    slot_idx: int,
):
    tlm_decode_start_input = dict()
    dlm_decode_start_input = dict()
    inputs = prompt
    input_len = prompt.input_ids.shape[1]
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len = num_chunks * prompt_len  # round input_len up to a multiple of prompt_len
    assert input_len <= ctx_len, "padded input_len should not exceed ctx_len"
    # the prompt tokens are already padded to input_len by the caller (tokenizer padding="max_length")
    # TODO: store the attention mask and position ids for each batch element so that we can
    # access them at decode time
    inputs["attention_mask"] = np.concatenate(
        [inputs["attention_mask"].astype(bool) for j in range(decode_batch_size)], 0
    )
    inputs["position_ids"] = (np.cumsum(inputs["attention_mask"][0:1], 1) - 1) * inputs["attention_mask"][0:1]

    # FIXME: the code formatter recommends "not" in place of the "== False" check below,
    # but "not" does not work element-wise on a numpy array
    inputs["position_ids"][inputs["attention_mask"][0:1] == False] = -1  # noqa: E712
    cache_index = np.array([[0]], np.int64)
    batch_index = np.array([[slot_idx]], np.int64)
    inputs["batch_index"] = batch_index

    # Run chunked prefill
    for i in range(num_chunks):
        chunk_inputs = inputs.copy()
        chunk_inputs["input_ids"] = inputs["input_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]
        chunk_inputs["position_ids"] = inputs["position_ids"][:, cache_index[0, 0] : cache_index[0, 0] + prompt_len]

        chunk_inputs.pop("attention_mask")
        tlm_outputs = tlm_session.run(chunk_inputs)
        dlm_outputs = dlm_session.run(chunk_inputs)
        cache_index += prompt_len

    tlm_logits = tlm_outputs["logits"]
    dlm_logits = dlm_outputs["logits"]

    if len(tlm_logits.shape) == 2:
        tlm_logits = np.expand_dims(tlm_logits, 1)
    if len(dlm_logits.shape) == 2:
        dlm_logits = np.expand_dims(dlm_logits, 1)

    tlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)
    tlm_decode_start_input_id = tlm_logits.argmax(2)
    dlm_decode_start_input_id = dlm_logits.argmax(2)
    dlm_decode_start_pos_id = inputs["attention_mask"][0:1].sum(1, keepdims=True)

    inputs.pop("attention_mask")

    tlm_decode_start_input = {
        "logits": tlm_logits,
        "input_ids": tlm_decode_start_input_id,
        "position_ids": tlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": tlm_decode_start_pos_id[0, 0],
    }
    dlm_decode_start_input = {
        "logits": dlm_logits,
        "input_ids": dlm_decode_start_input_id,
        "position_ids": dlm_decode_start_pos_id,
        "batch_index": batch_index,
        "input_len": dlm_decode_start_pos_id[0, 0],
    }

    return tlm_decode_start_input, dlm_decode_start_input

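# Illustrative sketch, not part of the original commit: the chunked prefill above walks the
# padded prompt in `prompt_len`-sized windows driven by `cache_index`. The plain-numpy loop
# below reproduces only that slicing arithmetic (no QAIC session), assuming a 64-token padded
# prompt and prompt_len=32.
_demo_prompt_len = 32
_demo_input_ids = np.arange(64)[None, :]  # (1, 64) padded prompt ids
_demo_cache_index = np.array([[0]], np.int64)
_demo_chunks = []
for _ in range(-(_demo_input_ids.shape[1] // -_demo_prompt_len)):  # ceil divide without float
    _start = _demo_cache_index[0, 0]
    _demo_chunks.append(_demo_input_ids[:, _start : _start + _demo_prompt_len])
    _demo_cache_index += _demo_prompt_len
assert [c.shape[1] for c in _demo_chunks] == [32, 32]
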
def get_padded_input_len(input_len: int, prompt_len: int, ctx_len: int):
    """Return the padded input length (a multiple of `prompt_len`).

    Args:
        input_len (int): prompt length
        prompt_len (int): prefill sequence length
        ctx_len (int): context length

    Returns:
        input_len_padded (int): padded input length
    """
    num_chunks = -(input_len // -prompt_len)  # ceil divide without float
    input_len_padded = num_chunks * prompt_len  # round input_len up to a multiple of prompt_len
    assert input_len_padded <= ctx_len, "input_len rounded to the nearest prompt_len multiple should not exceed ctx_len"
    return input_len_padded

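# Illustrative usage, not part of the original commit: with prompt_len=32 and ctx_len=128
# (the values used in the invocation at the bottom of this file), a 45-token prompt is
# padded up to two full prefill chunks, while an exact multiple is left unchanged.
assert get_padded_input_len(input_len=45, prompt_len=32, ctx_len=128) == 64
assert get_padded_input_len(input_len=32, prompt_len=32, ctx_len=128) == 32
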
def populate_inputs(source, dest, index=None):
    for k, v in dest.items():
        if k == "batch_index":
            continue
        if index is None:
            # during decode
            dest[k] = source[k]
        else:
            # during prefill with bs=1
            dest[k][index] = source[k]

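# Illustrative sketch, not part of the original commit: during prefill, each batch element's
# outputs are scattered into row `index` of the decode-time batch, while the "batch_index"
# key of the destination is left untouched.
_demo_dest = {"input_ids": np.zeros((2, 1), np.int64), "batch_index": np.arange(2).reshape(2, 1)}
_demo_source = {"input_ids": np.array([[7]], np.int64)}
populate_inputs(_demo_source, _demo_dest, index=1)
assert _demo_dest["input_ids"].tolist() == [[0], [7]]
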
def split_dlm_bonus_token_inputs(dlm_decode_inputs):
    bonus_token_inputs = dict()
    bonus_token_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:, 0:1]
    bonus_token_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, 0:1]
    dlm_decode_inputs["input_ids"] = dlm_decode_inputs["input_ids"][:, 1:]
    dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, 1:]
    return bonus_token_inputs, dlm_decode_inputs

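# Illustrative sketch, not part of the original commit: a (batch, 2) DLM decode input is split
# into a standalone seqlen-1 call for the first column and a seqlen-1 call for the remainder,
# which is how the bonus-token workaround in the decode loop below consumes it.
_demo_dlm_inputs = {
    "input_ids": np.array([[11, 12]], np.int64),
    "position_ids": np.array([[5, 6]], np.int64),
}
_demo_first, _demo_rest = split_dlm_bonus_token_inputs(_demo_dlm_inputs)
assert _demo_first["input_ids"].tolist() == [[11]] and _demo_first["position_ids"].tolist() == [[5]]
assert _demo_rest["input_ids"].tolist() == [[12]] and _demo_rest["position_ids"].tolist() == [[6]]
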
def test_spec_decode_inference(
    prompt: List[str],
    device_group: List[int],
    num_speculative_tokens: int,
    prompt_len: int,
    ctx_len: int,
    prefill_bsz: int,
    draft_model_name: str,
    target_model_name: str,
    full_batch_size: Optional[int] = None,
):
    # assumes dlm and tlm are compiled to the same prompt-chunk-size, context length, and full_batch_size/batch_size
    # get vocab size
    tokenizer = AutoTokenizer.from_pretrained(target_model_name)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    vocab_size = len(tokenizer)

    # export and compile tlm and dlm
    target_model = AutoModelForCausalLM.from_pretrained(target_model_name, continuous_batching=True, is_tlm=True)
    draft_model = AutoModelForCausalLM.from_pretrained(draft_model_name, continuous_batching=True)

    num_devices = len(device_group)
    target_model_qpc_path: str = target_model.compile(
        num_cores=11,
        num_devices=num_devices,
        prefill_seq_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6_matmul=True,
        aic_enable_depth_first=True,
        full_batch_size=full_batch_size,
        num_speculative_tokens=num_speculative_tokens,
    )
    draft_model_qpc_path: str = draft_model.compile(
        is_dlm=False,
        num_cores=5,
        prefill_seq_len=prompt_len,
        ctx_len=ctx_len,
        mxfp6_matmul=True,
        aic_enable_depth_first=True,
        full_batch_size=full_batch_size,
    )

    # init qaic sessions
    target_model_session = QAICInferenceSession(target_model_qpc_path, device_ids=[2])
    draft_model_session = QAICInferenceSession(draft_model_qpc_path, device_ids=[3])

    # skip inputs/outputs buffers
    target_model_session.skip_buffers(set([x for x in target_model_session.input_names if x.startswith("past_")]))
    target_model_session.skip_buffers(
        set([x for x in target_model_session.output_names if x.endswith("_RetainedState")])
    )
    draft_model_session.skip_buffers(set([x for x in draft_model_session.input_names if x.startswith("past_")]))
    draft_model_session.skip_buffers(
        set([x for x in draft_model_session.output_names if x.endswith("_RetainedState")])
    )

    is_cb = full_batch_size is not None
    if not is_cb:
        prompts = prompt * prefill_bsz
        decode_batch_size = prefill_bsz
    else:
        prompts = prompt
        decode_batch_size = full_batch_size
    # tokenize the prompts
    prompts_tokenized: List[dict] = []
    for p in prompts:
        input_len: int = tokenizer(p, return_tensors="np", padding=True).input_ids.shape[1]
        input_len_padded: int = get_padded_input_len(input_len, prompt_len, ctx_len)
        p_tok: dict = tokenizer(p, return_tensors="np", padding="max_length", max_length=input_len_padded)
        prompts_tokenized.append(p_tok)
    # create caches to hold generated ids and input prompt lengths
    generated_ids = [[] for i in range(decode_batch_size)]
    input_lengths = [0] * decode_batch_size
    # run prefill on both draft and target models
    dlm_decode_inputs = dict()
    dlm_decode_inputs["position_ids"] = np.zeros((decode_batch_size, 1), np.int64)
    dlm_decode_inputs["input_ids"] = np.full((decode_batch_size, 1), tokenizer.pad_token_id)
    dlm_decode_inputs["batch_index"] = np.reshape(
        np.array(np.arange(decode_batch_size), np.int64), (decode_batch_size, 1)
    )
    # mock input key "logits" to store the first batch of output logits
    dlm_decode_inputs["logits"] = np.full((decode_batch_size, 1, vocab_size), 0)
    tlm_precode_inputs = dict(dlm_decode_inputs)
    is_prefill = True
    generation_done = False
    max_gen_len = [ctx_len] * decode_batch_size
    num_logits_to_keep = num_speculative_tokens + 1
    all_accept = np.full((decode_batch_size, num_speculative_tokens), False, dtype=bool)
    tlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
    dlm_prefill_logits_ph = np.zeros((prefill_bsz, 1, vocab_size), dtype=np.float32)
    decode_logits_ph = np.zeros((decode_batch_size, 1, vocab_size), dtype=np.float32)
    precode_logits_ph = np.zeros((decode_batch_size, num_logits_to_keep, vocab_size), dtype=np.float32)

    target_model_session.set_buffers({"logits": tlm_prefill_logits_ph})
    draft_model_session.set_buffers({"logits": dlm_prefill_logits_ph})
    for bi in range(decode_batch_size):
        # assumes that the prefill queue will always be popped from the front
        tlm_prefill_output, dlm_prefill_output = run_prefill_on_draft_and_target(
            tlm_session=target_model_session,
            dlm_session=draft_model_session,
            prompt=prompts_tokenized[bi],
            prompt_len=prompt_len,
            ctx_len=ctx_len,
            prefill_batch_size=prefill_bsz,
            decode_batch_size=decode_batch_size,
            slot_idx=bi,
        )
        # this way, we directly get the updated full-batch input dict to run decode
        populate_inputs(dlm_prefill_output, dlm_decode_inputs, bi)
        populate_inputs(tlm_prefill_output, tlm_precode_inputs, bi)
        input_lengths[bi] = tlm_prefill_output["input_len"]
        max_gen_len[bi] -= input_lengths[bi]

    target_model_session.set_buffers({"logits": precode_logits_ph})
    draft_model_session.set_buffers({"logits": decode_logits_ph})
    dlm_run_bonus_token = False
    while not generation_done:
        # compute the processed context length before each iteration to prepare the position id inputs
        processed_context = [len(generated_ids[j]) + input_lengths[j] for j in range(decode_batch_size)]
        # generate proposals from the draft model
        if is_prefill:
            draft_logits = [dlm_decode_inputs.pop("logits")]
            target_logits = [tlm_precode_inputs.pop("logits")]
        else:
            if np.any(all_accept):
                input_ids = []
                position_ids = []
                dlm_run_bonus_token = True
                for bi in range(decode_batch_size):
                    if all_accept[bi]:
                        # both the last DLM token and the bonus TLM token are passed as input to the DLM
                        input_ids.append([generated_ids[bi][-2], generated_ids[bi][-1]])
                        position_ids.append([processed_context[bi] - 2, processed_context[bi] - 1])
                    else:
                        # only the corrected token from the TLM of the previous iteration, plus the pad_token as a dummy
                        input_ids.append([generated_ids[bi][-1], tokenizer.pad_token_id])
                        position_ids.append([processed_context[bi] - 1, -1])
                dlm_decode_inputs["input_ids"] = np.array(input_ids)
                dlm_decode_inputs["position_ids"] = np.array(position_ids)
            else:
                dlm_decode_inputs["input_ids"] = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(
                    (decode_batch_size, 1)
                )
                dlm_decode_inputs["position_ids"] = np.array(
                    [(pc - 1) for pc in processed_context], dtype=np.int64
                ).reshape((decode_batch_size, 1))
        # prepare the inputs for DLM speculation
        # TODO: if even one batch element has all_accept, we have to use the seqlen=2 specialization,
        # hence the dummy -1 position id for the other sequences.
        # dlm_decode_inputs["position_ids"] = len(generated_ids per batch)
        # dlm_decode_inputs["input_ids"] = (last generated DLM token) + last true token from the TLM
        for k_ in range(num_speculative_tokens):
            if dlm_run_bonus_token:
                # run decode one extra time in the first speculative iteration as a workaround
                # to avoid the incorrect precode with the 3-specialized multi-batch DLM
                bonus_token_inputs, dlm_decode_inputs = split_dlm_bonus_token_inputs(dlm_decode_inputs)
                dlm_outputs = draft_model_session.run(bonus_token_inputs)
                dlm_run_bonus_token = False
            dlm_outputs = draft_model_session.run(dlm_decode_inputs)
            draft_logits.append(dlm_outputs["logits"])
            dlm_decode_inputs["input_ids"] = dlm_outputs["logits"].argmax(-1)
            dlm_decode_inputs["position_ids"] = dlm_decode_inputs["position_ids"][:, -1:] + 1

        draft_logits = np.array(draft_logits).squeeze(2).transpose((1, 0, 2))
        # greedy sampling from the draft model
        draft_tokens = draft_logits.argmax(-1)

        # construct precode inputs
        tlm_precode_inputs["input_ids"] = draft_tokens
        if not is_prefill:
            last_genid = np.array([gid[-1] for gid in generated_ids], dtype=np.int64).reshape(decode_batch_size, 1)
            tlm_precode_inputs["input_ids"] = np.concatenate((last_genid, tlm_precode_inputs["input_ids"]), axis=1)
            # in the general precode case, the first token of the input sequence is the last generated TLM token (kv cache backfill)
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc - 1, stop=pc + num_speculative_tokens) for pc in processed_context], dtype=np.int64
            )
        else:
            # for the first precode only, all positions fed in are new
            tlm_precode_inputs["position_ids"] = np.array(
                [np.arange(start=pc, stop=pc + num_speculative_tokens + 1) for pc in processed_context], dtype=np.int64
            )

        # run precode on the TLM to score the proposed tokens
        tlm_outputs = target_model_session.run(tlm_precode_inputs)
        target_precode_logits = tlm_outputs["logits"]
        if is_prefill:
            # stack the prefill output logits and the precode logits into a single tensor
            target_logits = np.concatenate((target_logits[0], target_precode_logits), axis=1)
        else:
            target_logits = target_precode_logits
        # greedy sampling from the target model
        target_tokens = target_logits.argmax(-1)
        # exact matching between draft and target tokens
        matching = draft_tokens == target_tokens[:, :-1]
        num_tokens_selected = np.argmin(matching, axis=1)
        all_accept = matching[np.arange(decode_batch_size), num_tokens_selected]
        num_tokens_selected = np.where(all_accept, matching.shape[1], num_tokens_selected)

        # append selected tokens to generated_ids
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            num_tokens_to_append = min(num_tokens_selected[bi], max_gen_len[bi] - len(generated_ids[bi]))
            generated_ids[bi] += list(draft_tokens[bi, :num_tokens_to_append])
        # append the bonus/corrected token where applicable
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) >= max_gen_len[bi]:
                continue
            if all_accept[bi]:
                # bonus token
                generated_ids[bi].append(target_tokens[bi, -1])
            else:
                # corrected token
                generated_ids[bi].append(target_tokens[bi, num_tokens_selected[bi]])
        generation_done = True
        for bi in range(decode_batch_size):
            if len(generated_ids[bi]) < max_gen_len[bi]:
                generation_done = False
        is_prefill = False
        draft_logits = []
        target_logits = []
    print("max generation len = ", max_gen_len)
    print("actual generation len = ", [len(gid) for gid in generated_ids])
    print(tokenizer.batch_decode(generated_ids))

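# Illustrative sketch, not part of the original commit: the greedy acceptance rule used in the
# decode loop above, on made-up draft/target tokens with num_speculative_tokens=3. The target
# scores num_speculative_tokens + 1 positions; a draft token is kept only while it exactly
# matches the target token at the same position, and a fully matching row earns a bonus token.
_demo_draft = np.array([[4, 5, 6], [4, 9, 6]])
_demo_target = np.array([[4, 5, 6, 7], [4, 5, 6, 7]])
_demo_matching = _demo_draft == _demo_target[:, :-1]
_demo_selected = np.argmin(_demo_matching, axis=1)  # index of the first mismatch (0 if the row fully matches)
_demo_all_accept = _demo_matching[np.arange(2), _demo_selected]
_demo_selected = np.where(_demo_all_accept, _demo_matching.shape[1], _demo_selected)
assert _demo_selected.tolist() == [3, 1]  # row 0 accepts all 3 drafts, row 1 only the first
assert _demo_all_accept.tolist() == [True, False]
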
test_spec_decode_inference(
    ["My name is", "Hello", "Hi", "My name is"],
    [0],
    5,
    32,
    128,
    1,
    "JackFram/llama-68m",
    "JackFram/llama-68m",
    4,
)
