
Commit 34386ed

Finite lorax support (#153)
* Initial commit for finite loras implementation
  Signed-off-by: Jou-An Chen <[email protected]>
* Remove set delete adapter, add init assertion, update LinearMultiLoRA
  Signed-off-by: Jou-An Chen <[email protected]>
* Fix base model inference index INTMAX issue
  Signed-off-by: Jou-An Chen <[email protected]>
* Addressed review comments
  Signed-off-by: Jou-An Chen <[email protected]>
* Rebase on PR116 and make API changes
  Signed-off-by: Jou-An Chen <[email protected]>
* Enable init from QEffAutoPeftModelForCausalLM with finite_adapters flag
  Signed-off-by: Jou-An Chen <[email protected]>
* Address review comments
  Signed-off-by: Jou-An Chen <[email protected]>
* allow adapter_name passed as keyword argument, updated all finite lora tests to use single layer models
  Signed-off-by: Onkar Chougule <[email protected]>
* added pytest on_qaic marker for lora test using AI_100 device
  Signed-off-by: Onkar Chougule <[email protected]>

---------

Signed-off-by: Jou-An Chen <[email protected]>
Signed-off-by: Onkar Chougule <[email protected]>
Co-authored-by: Onkar Chougule <[email protected]>
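The commit wires a finite-adapters mode into QEffAutoPeftModelForCausalLM. A minimal usage sketch (not part of the commit itself): the repo ids below are placeholders, and only the finite_adapters / adapter_name kwargs and load_adapter call are taken from the diff; export, compile, and generation follow the usual QEfficient flow and are omitted here.

# Hedged sketch of the new finite-adapters mode; repo ids are placeholders.
from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM

model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "some-org/lora-adapter-a",   # placeholder PEFT adapter repo
    finite_adapters=True,        # build a QEffAutoLoraModelForCausalLM around the base model
    adapter_name="adapter_a",    # name used to identify the loaded adapter
)
# Register a second adapter on the same base model (signature as used in the diff).
model.load_adapter("some-org/lora-adapter-b", adapter_name="adapter_b")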
1 parent 40ec246 commit 34386ed

File tree: 12 files changed (+1064, -11 lines)


QEfficient/generation/text_generation_inference.py

Lines changed: 39 additions & 0 deletions
@@ -230,6 +230,7 @@ def cloud_ai_100_exec_kv(
     stream: bool = True,
     write_io_dir: Optional[str] = None,
     automation=False,
+    prompt_to_lora_id_mapping: Optional[List[int]] = None,
 ):
     """
     This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -277,6 +278,7 @@ def cloud_ai_100_exec_kv(
         stream=stream,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
+        prompt_to_lora_id_mapping=prompt_to_lora_id_mapping,
     )
     if full_batch_size is None:
         exec_info = [
@@ -313,6 +315,7 @@ def __init__(
         qpc_path: str,
         prompt: List[str],
         full_batch_size: Optional[int] = None,
+        prompt_to_lora_id_mapping: Optional[List[int]] = None,
         ctx_len: Optional[int] = None,
         generation_len: Optional[int] = None,
         device_id: Optional[List[int]] = None,
@@ -342,6 +345,16 @@ def __init__(
             full_batch_size if full_batch_size else self._fetch_full_batch_size()
         )  # Check and fetch full batch size if CB is enabled
 
+        if prompt_to_lora_id_mapping:
+            self.prompt_to_lora_id_mapping_prefill = deque(prompt_to_lora_id_mapping)
+            if self.full_batch_size:
+                self.prompt_to_lora_id_mapping_decode = prompt_to_lora_id_mapping
+            else:
+                self.prompt_to_lora_id_mapping_decode = deque(prompt_to_lora_id_mapping)
+        else:
+            self.prompt_to_lora_id_mapping_prefill = None
+            self.prompt_to_lora_id_mapping_decode = None
+
         self.set_tokenizer_params()  # set tokenizer params
 
         # Initialize the storage variables.
@@ -464,6 +477,16 @@ def prepare_decode_inputs(self):
         if self.batch_index is not None:
             decode_inputs["batch_index"] = self.batch_index
 
+        if self.prompt_to_lora_id_mapping_decode:
+            if self.full_batch_size:
+                first_batch_lora_ids = [self.prompt_to_lora_id_mapping_decode[i] for i in range(self.full_batch_size)]
+                decode_inputs["lora_ids"] = np.array(first_batch_lora_ids, dtype=np.int64).reshape(
+                    self.full_batch_size, 1
+                )
+            else:
+                batch_lora_ids = [self.prompt_to_lora_id_mapping_decode.popleft() for i in range(self.batch_size)]
+                decode_inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
+
         return decode_inputs
 
     def _update_decode_input(self, outputs, position_ids, generation_len, decode_batch_id=None):
@@ -552,6 +575,15 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
         if decode_batch_id is not None:
             inputs["batch_index"] = decode_batch_id
 
+        if self.prompt_to_lora_id_mapping_prefill:
+            if self.full_batch_size:
+                inputs["lora_ids"] = np.array(self.prompt_to_lora_id_mapping_prefill.popleft(), dtype=np.int64).reshape(
+                    1, 1
+                )
+            else:
+                batch_lora_ids = [self.prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
+                inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)
+
         for i in range(num_chunks):
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
@@ -628,6 +660,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
 
                         self.session.set_buffers({"logits": logits_out_placeholder})
                         decode_pause_time += perf_counter() - start
+
+                        if self.prompt_to_lora_id_mapping_decode:
+                            decode_inputs["lora_ids"][decode_batch_id] = self.prompt_to_lora_id_mapping_decode[
+                                batch_id_map[decode_batch_id]
+                            ]
+
                     else:
                         current_decode_ongoing[decode_batch_id] = False
                 else:
@@ -639,6 +677,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                     )
 
                     generated_id_current_index[decode_batch_id] += 1
+
         return decode_pause_time
 
     def run_decode(self, decode_inputs, generation_len):
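The hunks above thread a new prompt_to_lora_id_mapping argument (one LoRA id per prompt) from cloud_ai_100_exec_kv down into the prefill and decode loops, where it is fed to the model as a lora_ids input. A hedged sketch of calling the entry point with the new argument: only prompt_to_lora_id_mapping comes from this diff; the remaining parameter names (tokenizer, qpc_path, prompt, device_id) are assumed from the surrounding signatures, and the model id and QPC path are placeholders.

# Hedged sketch; paths, model id, and most parameter names are assumptions.
from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder base model

exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="qpcs/mistral-lora",                 # placeholder compiled-QPC directory
    prompt=["Summarise this note", "Translate to French", "Chat normally"],
    device_id=[0],
    prompt_to_lora_id_mapping=[0, 1, 2],          # one LoRA id per prompt, in prompt order
)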

QEfficient/peft/auto.py

Lines changed: 24 additions & 2 deletions
@@ -12,7 +12,7 @@
 
 import numpy as np
 import torch
-from peft import AutoPeftModelForCausalLM, PeftModelForCausalLM, load_peft_weights
+from peft import AutoPeftModelForCausalLM, PeftConfig, PeftModelForCausalLM, load_peft_weights
 from torch import nn
 from transformers import GenerationConfig, StoppingCriteria, StoppingCriteriaList
 from transformers.generation.streamers import BaseStreamer
@@ -21,6 +21,7 @@
 from QEfficient.base.onnx_transforms import FP16ClipTransform, OnnxTransform, SplitTensorsTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM
 from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform
 from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform
 from QEfficient.transformers.pytorch_transforms import CustomOpsTransform, KVCacheTransform
@@ -38,6 +39,7 @@ class QEffAutoPeftModelForCausalLM(QEFFBaseModel):
 
     Args:
         :model (nn.Module): PyTorch model
+        :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification.
 
     .. code-block:: python
 
@@ -80,6 +82,9 @@ def __init__(self, model: nn.Module):
             for adapter_name in model.peft_config
         }
 
+    def __repr__(self) -> str:
+        return self.__class__.__name__ + "\n" + self.model.__repr__()
+
     @property
     def model_name(self) -> str:
         mname = self.model.get_base_model().__class__.__name__ + "-lora"
@@ -145,14 +150,31 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
         """
         Args:
             :pretrained_name_or_path (str): Model card name from huggingface or local path to model directory.
+            :finite_adapters (bool): set True to enable finite adapter mode with QEffAutoLoraModelForCausalLM class. Please refer to QEffAutoLoraModelForCausalLM for API specification.
+            :adapter_name (str): Name used to identify loaded adapter.
             :args, kwargs: Additional arguments to pass to peft.AutoPeftModelForCausalLM.
         """
         if kwargs.get("full_batch_size"):
             raise NotImplementedError("Continuous batching currently not supported for PEFT models")
         if kwargs.get("use_cache") is False:
             warnings.warn("Overriding to use_cache=True")
         kwargs["use_cache"] = True
-        obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
+
+        if kwargs.pop("finite_adapters", False):  # initialize through finite_adapters class
+            obj = QEffAutoLoraModelForCausalLM.from_pretrained(
+                pretrained_model_name_or_path=PeftConfig.from_pretrained(
+                    pretrained_name_or_path
+                ).base_model_name_or_path,
+                **kwargs,
+            )
+            if adapter_name := kwargs.pop("adapter_name", None):
+                obj.load_adapter(pretrained_name_or_path, adapter_name=adapter_name)
+                return obj
+            if len(args) == 0 or not isinstance(list(args)[0], str):
+                raise TypeError("Required adapter name argument in string format")
+            obj.load_adapter(pretrained_name_or_path, list(args)[0])
+        else:
+            obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
         return obj
 
     def export(self, export_dir: Optional[str] = None) -> str:
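A behavioural sketch of the branch added to from_pretrained above, with placeholder repo ids; only the finite_adapters handling mirrors the code in QEfficient/peft/auto.py. With finite_adapters=True, an adapter name is required, either via the adapter_name keyword (shown earlier) or as the first extra positional string; otherwise the new code raises TypeError("Required adapter name argument in string format").

# Hedged sketch; repo ids are placeholders.
from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM

# Positional form: the first extra positional argument is taken as the adapter name.
lora_model = QEffAutoPeftModelForCausalLM.from_pretrained(
    "some-org/lora-adapter-a", "adapter_a", finite_adapters=True
)

# Default path (no finite_adapters): the original PEFT wrapper, unchanged behaviour.
peft_model = QEffAutoPeftModelForCausalLM.from_pretrained("some-org/lora-adapter-a")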

QEfficient/peft/lora/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# ----------------------------------------------------------------------------
+
+from QEfficient.peft.lora.auto import QEffAutoLoraModelForCausalLM
+
+__all__ = [
+    "QEffAutoLoraModelForCausalLM",
+]
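The new subpackage re-exports QEffAutoLoraModelForCausalLM, so the finite-LoRA class can also be used directly without going through QEffAutoPeftModelForCausalLM. A minimal sketch, assuming from_pretrained and load_adapter behave as they are used in QEfficient/peft/auto.py above; repo ids are placeholders.

# Hedged sketch of using the re-exported class directly.
from QEfficient.peft.lora import QEffAutoLoraModelForCausalLM

base = QEffAutoLoraModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder base model
base.load_adapter("some-org/lora-adapter-a", adapter_name="adapter_a")            # placeholder adapter repo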
