diff --git a/.github/workflows/config/bloom-560m-ci.yaml b/.github/workflows/config/bloom-560m-ci.yaml
index 16a97d896..674644798 100644
--- a/.github/workflows/config/bloom-560m-ci.yaml
+++ b/.github/workflows/config/bloom-560m-ci.yaml
@@ -13,9 +13,3 @@ ipex:
model_description:
model_id_or_path: bigscience/bloom-560m
tokenizer_name_or_path: bigscience/bloom-560m
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
diff --git a/.github/workflows/config/gpt2-ci.yaml b/.github/workflows/config/gpt2-ci.yaml
index 1e6df57cb..7ed3f6972 100644
--- a/.github/workflows/config/gpt2-ci.yaml
+++ b/.github/workflows/config/gpt2-ci.yaml
@@ -12,10 +12,5 @@ ipex:
model_description:
model_id_or_path: gpt2
tokenizer_name_or_path: gpt2
- chat_processor: ChatModelGptJ
gpt_base_model: true
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
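
The `chat_template` values added above are standard Hugging Face Jinja chat templates: at serving time they are assigned to `tokenizer.chat_template` and rendered with `tokenizer.apply_chat_template`. A minimal, illustrative sketch of how such a template behaves, using a trimmed-down version of the CI template (no system-message handling) and an example message; assumes `transformers` is installed:

```python
from transformers import AutoTokenizer

# Illustrative only: attach a trimmed version of the CI template to the gpt2 tokenizer.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.chat_template = (
    "{% for message in messages %}"
    "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
    "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
    "{% endif %}"
    "{{ message['content'].strip() }}"
    "{% endfor %}"
)

messages = [{"role": "user", "content": "Which is bigger, the moon or the sun?"}]
# Renders the concatenated prompt string defined by the template (no tokenization).
print(tokenizer.apply_chat_template(messages, tokenize=False))
```
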
diff --git a/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
index 46be6eb57..d3d96a0e1 100644
--- a/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
+++ b/.github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml
@@ -16,13 +16,5 @@ ipex:
model_description:
model_id_or_path: meta-llama/Llama-2-7b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
diff --git a/.github/workflows/config/mpt_deltatuner.yaml b/.github/workflows/config/mpt_deltatuner.yaml
index 250004dc2..e0c0d6946 100644
--- a/.github/workflows/config/mpt_deltatuner.yaml
+++ b/.github/workflows/config/mpt_deltatuner.yaml
@@ -13,20 +13,7 @@ ipex:
model_description:
model_id_or_path: mosaicml/mpt-7b
tokenizer_name_or_path: EleutherAI/gpt-neox-20b
- chat_processor: ChatModelGptJ
peft_model_id_or_path: nathan0/mpt-7b-deltatuner-model
peft_type: deltatuner
- prompt:
- intro: 'Below is an instruction that describes a task, paired with an input that
- provides further context. Write a response that appropriately completes the request.
-
- '
- human_id: '
-
- ### Instruction'
- bot_id: '
-
- ### Response'
- stop_words: []
config:
trust_remote_code: true
diff --git a/.github/workflows/config/mpt_deltatuner_deepspeed.yaml b/.github/workflows/config/mpt_deltatuner_deepspeed.yaml
index 40051e0fa..a4fdd0709 100644
--- a/.github/workflows/config/mpt_deltatuner_deepspeed.yaml
+++ b/.github/workflows/config/mpt_deltatuner_deepspeed.yaml
@@ -13,20 +13,7 @@ ipex:
model_description:
model_id_or_path: mosaicml/mpt-7b
tokenizer_name_or_path: EleutherAI/gpt-neox-20b
- chat_processor: ChatModelGptJ
peft_model_id_or_path: nathan0/mpt-7b-deltatuner-model
peft_type: deltatuner
- prompt:
- intro: 'Below is an instruction that describes a task, paired with an input that
- provides further context. Write a response that appropriately completes the request.
-
- '
- human_id: '
-
- ### Instruction'
- bot_id: '
-
- ### Response'
- stop_words: []
config:
trust_remote_code: true
diff --git a/.github/workflows/config/opt-125m-ci.yaml b/.github/workflows/config/opt-125m-ci.yaml
index 047d0008c..96c9c345b 100644
--- a/.github/workflows/config/opt-125m-ci.yaml
+++ b/.github/workflows/config/opt-125m-ci.yaml
@@ -13,9 +13,4 @@ ipex:
model_description:
model_id_or_path: facebook/opt-125m
tokenizer_name_or_path: facebook/opt-125m
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/docs/finetune_parameters.md b/docs/finetune_parameters.md
index 5d24f42e6..ee3615d5e 100644
--- a/docs/finetune_parameters.md
+++ b/docs/finetune_parameters.md
@@ -15,6 +15,7 @@ The following are the parameters supported in the finetuning workflow.
|lora_config|task_type: CAUSAL_LM
r: 8
lora_alpha: 32
lora_dropout: 0.1|Will be passed to the LoraConfig `__init__()` method, then it'll be used as config to build Peft model object.|
|deltatuner_config|"algo": "lora"
"denas": True
"best_model_structure": "/path/to/best_structure_of_deltatuner_model"|Will be passed to the DeltaTunerArguments `__init__()` method, then it'll be used as config to build [Deltatuner model](https://github.com/intel/e2eAIOK/tree/main/e2eAIOK/deltatuner) object.|
|enable_gradient_checkpointing|False|enable gradient checkpointing to save GPU memory, but will cost more compute runtime|
+|chat_template|None|User-defined chat template in Hugging Face Jinja format. If unset, the tokenizer's own chat template is used, falling back to the built-in `default_chat_template`.|
## Dataset Parameters
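
The new `chat_template` finetuning parameter is resolved against the tokenizer's own template and the hard-coded `General.default_chat_template` (see the `GeneralProcesser.tokenize_function` change later in this diff). Roughly, the precedence is: user-supplied template, then the tokenizer's built-in template, then the default. An illustrative sketch of that selection logic (not the patch's code; names mirror the config keys):

```python
# Illustrative sketch of the template-selection precedence used during finetuning.
# `config` mirrors the dataprocesser config keys; `tokenizer` is a HF tokenizer.
def resolve_chat_template(config: dict, tokenizer) -> str:
    if config.get("chat_template") is not None:
        return config["chat_template"]       # user-defined template wins
    if tokenizer.chat_template is not None:
        return tokenizer.chat_template       # tokenizer ships its own template
    return config["default_chat_template"]   # fall back to the built-in default
```
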
diff --git a/examples/inference/api_server_openai/query_http_requests.py b/examples/inference/api_server_openai/query_http_requests.py
index 536deb30e..a2be3873f 100644
--- a/examples/inference/api_server_openai/query_http_requests.py
+++ b/examples/inference/api_server_openai/query_http_requests.py
@@ -58,7 +58,6 @@
body = {
"model": args.model_name,
"messages": [
- {"role": "assistant", "content": "You are a helpful assistant."},
{"role": "user", "content": args.input_text},
],
"stream": args.streaming_response,
diff --git a/examples/inference/api_server_openai/query_http_requests_tool.py b/examples/inference/api_server_openai/query_http_requests_tool.py
index 217f2b792..c9efd222d 100644
--- a/examples/inference/api_server_openai/query_http_requests_tool.py
+++ b/examples/inference/api_server_openai/query_http_requests_tool.py
@@ -73,7 +73,6 @@
messages = [
[
- {"role": "user", "content": "You are a helpful assistant"},
{"role": "user", "content": "What's the weather like in Boston today?"},
],
]
@@ -81,7 +80,7 @@
proxies = {"http": None, "https": None}
for message in messages:
- print(f"User: {message[1]['content']}")
+ print(f"User: {message[0]['content']}")
print("Assistant:", end=" ", flush=True)
body = {
diff --git a/examples/inference/api_server_simple/query_single.py b/examples/inference/api_server_simple/query_single.py
index 62bb4dc45..b6d935c9a 100644
--- a/examples/inference/api_server_simple/query_single.py
+++ b/examples/inference/api_server_simple/query_single.py
@@ -55,7 +55,10 @@
)
args = parser.parse_args()
-prompt = "Once upon a time,"
+# prompt = "Once upon a time,"
+prompt = [
+ {"role": "user", "content": "Which is bigger, the moon or the sun?"},
+]
+
+
config: Dict[str, Union[int, float]] = {}
if args.max_new_tokens:
config["max_new_tokens"] = int(args.max_new_tokens)
diff --git a/llm_on_ray/common/dataprocesser/general_processer.py b/llm_on_ray/common/dataprocesser/general_processer.py
index b963611e7..b2727e97b 100644
--- a/llm_on_ray/common/dataprocesser/general_processer.py
+++ b/llm_on_ray/common/dataprocesser/general_processer.py
@@ -99,10 +99,65 @@ def torch_call(self, examples):
class GeneralProcesser(DataProcesser):
+ def tokenize_function(self, examples, tokenizer):
+ if self.config.get("gpt_base_model"):
+ instruction = examples["instruction"]
+ response = examples["response"]
+ context = examples.get("context")
+ if not instruction:
+ raise ValueError(f"Expected an instruction in: {examples}")
+ if not response:
+ raise ValueError(f"Expected a response in: {examples}")
+ if context:
+ new_message = PROMPT_WITH_INPUT_FORMAT.format(
+ instruction=instruction, response=response, input=context
+ )
+ else:
+ new_message = PROMPT_NO_INPUT_FORMAT.format(
+ instruction=instruction, response=response
+ )
+ return tokenizer(
+ new_message, add_special_tokens=False, max_length=self.config.get("max_length")
+ )
+ else:
+ new_messages = [
+ {
+ "role": "user",
+ "content": "###Instruction:\n"
+ + examples["instruction"]
+ + "\n\n"
+ + "###context:\n"
+ + examples["context"]
+ + "\n\n",
+ },
+ {"role": "assistant", "content": examples["response"] + "\n\n"},
+ ]
+ if self.config.get("chat_template") is not None:
+ tokenizer.chat_template = self.config.get("chat_template")
+ new_tokenizer = tokenizer.apply_chat_template(
+ new_messages,
+ tokenize=False,
+ )
+ elif tokenizer.chat_template is not None:
+ new_tokenizer = tokenizer.apply_chat_template(
+ new_messages,
+ tokenize=False,
+ )
+ else:
+ tokenizer.chat_template = self.config.get("default_chat_template")
+ new_tokenizer = tokenizer.apply_chat_template(
+ new_messages,
+ tokenize=False,
+ )
+ tokenizer = tokenizer(
+ new_tokenizer, add_special_tokens=False, max_length=self.config.get("max_length")
+ )
+ return tokenizer
+
def prepare(self, tokenizer, dataset):
per_device_train_batch_size = self.config.get("per_device_train_batch_size")
per_device_eval_batch_size = self.config.get("per_device_eval_batch_size")
- max_length = self.config.get("max_length")
+
group = self.config.get("group")
block_size = self.config.get("block_size")
shuffle = self.config.get("shuffle")
@@ -114,38 +169,8 @@ def prepare(self, tokenizer, dataset):
if isinstance(dataset, datasets.DatasetDict):
column_names = dataset["train"].column_names
- if column_names and TEXT_COLUMN_NAME not in column_names:
-
- def prompt(rec):
- instruction = rec["instruction"]
- response = rec["response"]
- context = rec.get("context")
- if not instruction:
- raise ValueError(f"Expected an instruction in: {rec}")
- if not response:
- raise ValueError(f"Expected a response in: {rec}")
- if context:
- rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(
- instruction=instruction, response=response, input=context
- )
- else:
- rec["text"] = PROMPT_NO_INPUT_FORMAT.format(
- instruction=instruction, response=response
- )
- return rec
-
- dataset = dataset.map(
- prompt,
- load_from_cache_file=False,
- desc="Prompt",
- )
- column_names += [TEXT_COLUMN_NAME]
-
- def tokenize_function(examples):
- return tokenizer(examples[TEXT_COLUMN_NAME], max_length=max_length)
-
tokenized_datasets = dataset.map(
- tokenize_function,
+ lambda examples: self.tokenize_function(examples, tokenizer),
remove_columns=column_names,
load_from_cache_file=False,
desc="Tokenize dataset",
diff --git a/llm_on_ray/common/trainer/default_trainer.py b/llm_on_ray/common/trainer/default_trainer.py
index 366d6f28b..e3800333c 100644
--- a/llm_on_ray/common/trainer/default_trainer.py
+++ b/llm_on_ray/common/trainer/default_trainer.py
@@ -33,6 +33,7 @@
class DefaultTrainer(Trainer):
def __init__(self, config):
self.model = None
+ self.tokenizer = None
self.config = config
dataprocesser_config = config.get("dataprocesser")
dataprocesser_type = dataprocesser_config.get("type")
@@ -121,7 +122,7 @@ def _get_lr_scheduler(
def prepare(self, model, tokenizer, dataset, optimizer, accelerator):
self._coordinate(accelerator)
-
+ self.tokenizer = tokenizer
embedding_size = model.get_input_embeddings().weight.shape[0]
logger.info(f"model embedding size: {embedding_size}")
if len(tokenizer) > embedding_size:
@@ -288,6 +289,11 @@ def train(self):
is_main_process=self.accelerator.is_main_process,
save_function=self.accelerator.save,
)
+ self.tokenizer.save_pretrained(
+ output,
+ is_main_process=self.accelerator.is_main_process,
+ save_function=self.accelerator.save,
+ )
logger.info(f"finish save model to {output}")
self.accelerator.wait_for_everyone()
diff --git a/llm_on_ray/finetune/finetune.py b/llm_on_ray/finetune/finetune.py
index 37c0481d6..14422967b 100644
--- a/llm_on_ray/finetune/finetune.py
+++ b/llm_on_ray/finetune/finetune.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
-#!/usr/bin/env python
+# !/usr/bin/env python
import os
import argparse
@@ -248,6 +248,9 @@ def train_func(config: Dict[str, Any]):
"group": config["Dataset"].get("group", True),
"block_size": config["Dataset"].get("block_size", 512),
"shuffle": config["Dataset"].get("shuffle", False),
+ "gpt_base_model": config["General"].get("gpt_base_model", False),
+ "chat_template": config["General"]["chat_template"],
+ "default_chat_template": config["General"]["default_chat_template"],
},
"lr_scheduler": {
"enable": True,
diff --git a/llm_on_ray/finetune/finetune_config.py b/llm_on_ray/finetune/finetune_config.py
index a01095c16..136b698eb 100644
--- a/llm_on_ray/finetune/finetune_config.py
+++ b/llm_on_ray/finetune/finetune_config.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import os
from pydantic import BaseModel, validator
-from typing import Optional, List
+from typing import Optional, List, Dict
+from pydantic_yaml import parse_yaml_raw_as
PRECISION_BF16 = "bf16"
PRECISION_FP16 = "fp16"
@@ -60,6 +62,23 @@ class General(BaseModel):
lora_config: Optional[LoraConfig] = None
deltatuner_config: Optional[DeltatunerConfig] = None
enable_gradient_checkpointing: bool = False
+ chat_template: Optional[str] = None
+ default_chat_template: str = (
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ "{% if messages[0]['role'] == 'system' %}"
+ "{{ raise_exception('System role not supported') }}"
+ "{% endif %}"
+ "{% for message in messages %}"
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ '### Instruction: ' + message['content'] + eos_token }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '### Response:' + message['content'] + eos_token }}"
+ "{% endif %}{% endfor %}"
+ "{{'### End \n'}}"
+ )
class Dataset(BaseModel):
@@ -146,3 +165,19 @@ class FinetuneConfig(BaseModel):
General: General
Dataset: Dataset
Training: Training
+
+
+base_models: Dict[str, FinetuneConfig] = {}
+_models: Dict[str, FinetuneConfig] = {}
+
+_cur = os.path.dirname(os.path.abspath(__file__))
+_models_folder = _cur + "/models"
+for model_file in os.listdir(_models_folder):
+ file_path = _models_folder + "/" + model_file
+ if os.path.isdir(file_path):
+ continue
+ with open(file_path, "r") as f:
+ m: FinetuneConfig = parse_yaml_raw_as(FinetuneConfig, f)
+ _models[m.General.base_model] = m
+
+all_models = _models.copy()
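
The module now builds a registry of finetune configs from the YAML files under `llm_on_ray/finetune/models`, keyed by `General.base_model`. A usage sketch (the model name is illustrative and assumes a matching YAML exists in that folder):

```python
from llm_on_ray.finetune.finetune_config import all_models

# Look up a preloaded FinetuneConfig by base model name (illustrative key).
cfg = all_models.get("gpt2")
if cfg is not None:
    print(cfg.General.gpt_base_model, cfg.Training.num_training_workers)
```
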
diff --git a/llm_on_ray/inference/chat_process.py b/llm_on_ray/inference/chat_process.py
deleted file mode 100644
index 3ee238fb7..000000000
--- a/llm_on_ray/inference/chat_process.py
+++ /dev/null
@@ -1,222 +0,0 @@
-#
-# Copyright 2023 The LLM-on-Ray Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-
-class ChatModel:
- human_id = ""
- bot_id = ""
- unknown_id = ""
- MEANINGLESS_WORDS = ["", "", "<|endoftext|>", "
"]
- stop_words = [""]
-
- def __init__(self, intro, human_id, bot_id, stop_words) -> None:
- self.intro = intro
- self.human_id = human_id
- self.bot_id = bot_id
- self.stop_words = stop_words
- self.MEANINGLESS_WORDS.extend(self.stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- prompt = ""
- for msg in messages:
- role, content = msg.role, msg.content
- if role == "user":
- prompt += f"{self.human_id}: {content}\n"
- elif role == "assistant":
- prompt += f"{self.bot_id}: {content}\n"
- else:
- prompt += f"{self.unknown_id}: {content}\n"
- prompt += f"{self.bot_id}:"
- return prompt
-
- def convert_output(self, output: str):
- """Convert the model output to final answer."""
- human_id = self.human_id.strip()
- bot_id = self.bot_id.strip()
- if human_id != "":
- output = output.split(human_id)[0]
- if bot_id != "":
- output = output.split(bot_id)[0]
- for word in self.MEANINGLESS_WORDS:
- output = output.replace(word, "")
- text = output
- # remove partial human_id or bot id
- if "\n" in text and (
- human_id.startswith(text[text.rfind("\n") + 1 :])
- or bot_id.startswith(text[text.rfind("\n") + 1])
- ):
- text = text[: text.rfind("\n")]
- return text
-
- def get_prompt(self, messages):
- """Generate response based on messages."""
- prompt = self.prepare_prompt(messages)
- return prompt
-
-
-class ChatModelGptJ(ChatModel):
- def __init__(self, intro, human_id, bot_id, stop_words):
- super().__init__(intro, human_id, bot_id, stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- prompt = self.intro
- for msg in messages:
- msg = dict(msg)
- role, content = msg["role"], msg["content"]
- if role == "user":
- if self.human_id != "":
- prompt += f"{self.human_id}:\n{content}\n"
- else:
- prompt += f"{content}\n"
- elif role == "assistant":
- if self.bot_id != "":
- prompt += f"{self.bot_id}:\n{content}\n"
- else:
- prompt += f"{content}\n"
- else:
- prompt += f"### Unknown:\n{content}\n"
- if self.bot_id != "":
- prompt += f"{self.bot_id}:\n"
- return prompt
-
-
-class ChatModelLLama(ChatModel):
- def __init__(self, intro, human_id, bot_id, stop_words):
- super().__init__(intro, human_id, bot_id, stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- prompt = self.intro
- for msg in messages:
- msg = dict(msg)
- role, content = msg["role"], msg["content"]
- if role == "user":
- if self.human_id != "":
- prompt += self.human_id.format(msg=content)
- else:
- prompt += f"{content}\n"
- elif role == "assistant":
- prompt += f"{content}\n"
- elif role == "tool":
- prompt += f"{content}\n"
- elif role == "system":
- prompt += f"### system:\n{content}\n"
- else:
- prompt += f"### Unknown:\n{content}\n"
- if self.bot_id != "":
- prompt += f"{self.bot_id}:\n"
- return prompt
-
-
-class ChatModelwithImage(ChatModel):
- def __init__(self, intro, human_id, bot_id, stop_words):
- super().__init__(intro, human_id, bot_id, stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- from PIL import Image
- import requests
- from io import BytesIO
- import base64
- import re
-
- prompt = self.intro
- for msg in messages:
- msg = dict(msg)
- role, content = msg["role"], msg["content"]
- text_prompt = []
- image_prompt = []
- for item in content:
- if item["type"] == "text":
- text_prompt.append(item["text"])
- elif item["type"] == "image_url":
- image_prompt.append(item["image_url"])
- else:
- raise ValueError(f"Unknown content type {item['type']}")
-
- content = "\n".join(text_prompt)
- # prepare images
- images = []
- for img in image_prompt:
- if "url" not in img:
- continue
- is_data = len(re.findall("^data:image/.+;base64,", img["url"])) > 0
- if is_data:
- encoded_str = re.sub("^data:image/.+;base64,", "", img["url"])
- images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
- else:
- images.append(Image.open(requests.get(img["url"], stream=True).raw))
-
- if role == "user":
- if self.human_id != "":
- prompt += self.human_id.format(msg=content)
- else:
- prompt += f"{content}\n"
- elif role == "assistant":
- prompt += f"{content}\n"
- else:
- prompt += f"### Unknown:\n{content}\n"
- if self.bot_id != "":
- prompt += f"{self.bot_id}:\n"
- return prompt, images
-
-
-class ChatModelGemma(ChatModel):
- def __init__(self, intro, human_id, bot_id, stop_words):
- super().__init__(intro, human_id, bot_id, stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- prompt = self.intro
- for msg in messages:
- msg = dict(msg)
- role, content = msg["role"], msg["content"]
- if role == "user":
- if self.human_id != "":
- prompt += f"{self.human_id} {content}\n"
- else:
- prompt += f"{content}\n"
- elif role == "assistant":
- if self.bot_id != "":
- prompt += f"{self.bot_id} {content}\n"
- else:
- prompt += f"{content}\n"
- else:
- prompt += f"### Unknown:\n{content}\n"
- if self.bot_id != "":
- prompt += f"{self.bot_id}:\n"
- return prompt
-
-
-class ChatModelNoFormat(ChatModel):
- def __init__(self, intro, human_id, bot_id, stop_words):
- super().__init__(intro, human_id, bot_id, stop_words)
-
- def prepare_prompt(self, messages: list):
- """Prepare prompt from history messages."""
- prompt = ""
- for msg in messages:
- msg = dict(msg)
- prompt += msg["content"]
- return prompt
-
-
-if __name__ == "__main__":
- process_tool = ChatModelGptJ(
- "", "### Instruction", "### Response", stop_words=["##", "### Instruction"]
- )
diff --git a/llm_on_ray/inference/chat_template_process.py b/llm_on_ray/inference/chat_template_process.py
new file mode 100644
index 000000000..851004b01
--- /dev/null
+++ b/llm_on_ray/inference/chat_template_process.py
@@ -0,0 +1,84 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import List
+from llm_on_ray.inference.api_openai_backend.openai_protocol import ChatMessage
+
+
+class ChatTemplatePreprocess:
+ def __init__(self, predictor) -> None:
+ self.predictor = predictor
+
+ def get_prompt(self, input: List, is_mllm=False):
+ """Generate response based on input."""
+ self.predictor.tokenizer.chat_template = (
+ self.predictor.infer_conf.model_description.chat_template
+ or self.predictor.tokenizer.chat_template
+ or self.predictor.infer_conf.model_description.default_chat_template
+ )
+
+ if isinstance(input, list) and input and isinstance(input[0], (ChatMessage, dict)):
+ messages = (
+ [dict(chat_message) for chat_message in input]
+ if isinstance(input[0], ChatMessage)
+ else input
+ )
+ prompt = self.predictor.tokenizer.apply_chat_template(
+ messages, add_generation_prompt=True, tokenize=False
+ )
+ if is_mllm:
+ texts, images = self._extract_messages(messages)
+ image = self._prepare_image(images)
+ prompt = self.predictor.tokenizer.apply_chat_template(
+ texts, add_generation_prompt=True, tokenize=False
+ )
+ return prompt, image
+ return prompt
+
+ raise TypeError(f"Unsupported type {type(input)} for text. Expected dict or list of dicts.")
+
+ def _extract_messages(self, messages):
+ texts, images = [], []
+ for message in messages:
+ if message["role"] == "user" and isinstance(message["content"], list):
+ texts.append({"role": "user", "content": message["content"][0]["text"]})
+ images.append(
+ {"role": "user", "content": message["content"][1]["image_url"]["url"]}
+ )
+ else:
+ texts.append(message)
+ return texts, images
+
+ def _prepare_image(self, messages: list):
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ import base64
+ import re
+
+ images: List = []
+ for msg in messages:
+ msg = dict(msg)
+ content = msg["content"]
+ if "url" not in content:
+ continue
+ is_data = len(re.findall("^data:image/.+;base64,", content["url"])) > 0
+ if is_data:
+ encoded_str = re.sub("^data:image/.+;base64,", "", content["url"])
+ images.append(Image.open(BytesIO(base64.b64decode(encoded_str))))
+ else:
+ images.append(Image.open(requests.get(content["url"], stream=True).raw))
+
+ return images
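
A minimal usage sketch of the new preprocessor outside Ray Serve. The predictor here is a stand-in exposing only the attributes `get_prompt()` reads (`tokenizer` and `infer_conf.model_description`); in the deployment it is the real predictor object. Assumes `llm-on-ray` and `transformers` are installed; the fallback template is an example:

```python
from types import SimpleNamespace
from transformers import AutoTokenizer
from llm_on_ray.inference.chat_template_process import ChatTemplatePreprocess

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # gpt2 ships no chat template
model_description = SimpleNamespace(
    chat_template=None,
    default_chat_template="{% for m in messages %}{{ m['content'] }}\n{% endfor %}",
)
predictor = SimpleNamespace(
    tokenizer=tokenizer,
    infer_conf=SimpleNamespace(model_description=model_description),
)

# Falls back to default_chat_template and renders the messages into a prompt string.
prompt = ChatTemplatePreprocess(predictor).get_prompt(
    [{"role": "user", "content": "Hello!"}]
)
print(prompt)
```
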
diff --git a/llm_on_ray/inference/inference_config.py b/llm_on_ray/inference/inference_config.py
index 96833c24b..b7b58598a 100644
--- a/llm_on_ray/inference/inference_config.py
+++ b/llm_on_ray/inference/inference_config.py
@@ -112,6 +112,27 @@ class ModelDescription(BaseModel):
input_processor: str = "AutoProcessor"
model_loader: str = "AutoModel"
+ chat_model_with_image: bool = False
+ chat_template: Union[str, None] = None
+ default_chat_template: str = (
+ "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ "{% if messages[0]['role'] == 'system' %}"
+ "{% set loop_messages = messages[1:] %}"
+ "{% set system_message = messages[0]['content'] %}"
+ "{% else %}{% set loop_messages = messages %}"
+ "{% set system_message = false %}{% endif %}"
+ "{% for message in loop_messages %}"
+ "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}"
+ "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}"
+ "{% endif %}"
+ "{% if message['role'] == 'user' %}"
+ "{{ '### Instruction: ' + message['content'].strip() }}"
+ "{% elif message['role'] == 'assistant' %}"
+ "{{ '### Response:' + message['content'].strip() }}"
+ "{% endif %}{% endfor %}"
+ "{% if add_generation_prompt %}{{'### Response:\n'}}{% endif %}"
+ )
+
@validator("quantization_type")
def _check_quant_type(cls, v: str):
if v:
@@ -170,7 +191,6 @@ def _check_workers_per_group(cls, v: int):
all_models: Dict[str, InferenceConfig] = {}
-base_models: Dict[str, InferenceConfig] = {}
_models: Dict[str, InferenceConfig] = {}
_cur = os.path.dirname(os.path.abspath(__file__))
diff --git a/llm_on_ray/inference/models/CodeLlama-7b-hf.yaml b/llm_on_ray/inference/models/CodeLlama-7b-hf.yaml
index 9ea2d77db..5cad7e6aa 100644
--- a/llm_on_ray/inference/models/CodeLlama-7b-hf.yaml
+++ b/llm_on_ray/inference/models/CodeLlama-7b-hf.yaml
@@ -6,16 +6,11 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: codellama/CodeLlama-7b-hf
tokenizer_name_or_path: codellama/CodeLlama-7b-hf
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/bloom-560m.yaml b/llm_on_ray/inference/models/bloom-560m.yaml
index ba2a6d962..92dbb8809 100644
--- a/llm_on_ray/inference/models/bloom-560m.yaml
+++ b/llm_on_ray/inference/models/bloom-560m.yaml
@@ -6,16 +6,10 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
- enabled: true
+ enabled: false
precision: bf16
model_description:
model_id_or_path: bigscience/bloom-560m
tokenizer_name_or_path: bigscience/bloom-560m
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
diff --git a/llm_on_ray/inference/models/deepseek-coder-33b-instruct.yaml b/llm_on_ray/inference/models/deepseek-coder-33b-instruct.yaml
index 75e646a44..84f6d2a43 100644
--- a/llm_on_ray/inference/models/deepseek-coder-33b-instruct.yaml
+++ b/llm_on_ray/inference/models/deepseek-coder-33b-instruct.yaml
@@ -10,16 +10,6 @@ device: cpu
ipex:
enabled: false
precision: bf16
-model_description:
+model_description:
model_id_or_path: deepseek-ai/deepseek-coder-33b-instruct
tokenizer_name_or_path: deepseek-ai/deepseek-coder-33b-instruct
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
-    stop_words: ['<|EOT|>', "</s>"]
-
-
-
-
diff --git a/llm_on_ray/inference/models/deplot.yaml b/llm_on_ray/inference/models/deplot.yaml
index 4e732a4fe..acfbe3e87 100644
--- a/llm_on_ray/inference/models/deplot.yaml
+++ b/llm_on_ray/inference/models/deplot.yaml
@@ -6,22 +6,12 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: google/deplot
tokenizer_name_or_path: google/deplot
- chat_processor: ChatModelwithImage
- input_processor: 'AutoProcessor'
- model_loader: 'Pix2StructForConditionalGeneration'
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
- config:
- use_auth_token: ''
+ chat_model_with_image: true
+  chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/falcon-7b.yaml b/llm_on_ray/inference/models/falcon-7b.yaml
index 8176a2689..fbbbdda08 100644
--- a/llm_on_ray/inference/models/falcon-7b.yaml
+++ b/llm_on_ray/inference/models/falcon-7b.yaml
@@ -6,16 +6,11 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: tiiuae/falcon-7b
tokenizer_name_or_path: tiiuae/falcon-7b
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/fuyu8b.yaml b/llm_on_ray/inference/models/fuyu8b.yaml
index 551a85789..3f5fa7ab7 100644
--- a/llm_on_ray/inference/models/fuyu8b.yaml
+++ b/llm_on_ray/inference/models/fuyu8b.yaml
@@ -6,22 +6,12 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: adept/fuyu-8b
tokenizer_name_or_path: adept/fuyu-8b
- chat_processor: ChatModelwithImage
- input_processor: FuyuProcessor
- model_loader: FuyuForCausalLM
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
- config:
- use_auth_token: ''
+ chat_model_with_image: true
+  chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/gemma-2b.yaml b/llm_on_ray/inference/models/gemma-2b.yaml
index 8335857ca..b6d16b18c 100644
--- a/llm_on_ray/inference/models/gemma-2b.yaml
+++ b/llm_on_ray/inference/models/gemma-2b.yaml
@@ -6,20 +6,13 @@ cpus_per_worker: 2
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
enabled: true
precision: bf16
model_description:
model_id_or_path: google/gemma-2b
tokenizer_name_or_path: google/gemma-2b
- chat_processor: ChatModelGemma
- prompt:
- intro: ''
- human_id: 'user
- {msg}'
- bot_id: 'model
- {msg}'
- stop_words: []
config:
use_auth_token: ' '
+  chat_template: "{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}"
diff --git a/llm_on_ray/inference/models/gpt-j-6b.yaml b/llm_on_ray/inference/models/gpt-j-6b.yaml
index c7778c12e..3bdb9997f 100644
--- a/llm_on_ray/inference/models/gpt-j-6b.yaml
+++ b/llm_on_ray/inference/models/gpt-j-6b.yaml
@@ -6,7 +6,7 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
# false here for ci coverage
enabled: false
@@ -14,17 +14,4 @@ ipex:
model_description:
model_id_or_path: EleutherAI/gpt-j-6b
tokenizer_name_or_path: EleutherAI/gpt-j-6b
- chat_processor: ChatModelGptJ
gpt_base_model: true
- prompt:
- intro: 'Below is an instruction that describes a task. Write a response that appropriately
- completes the request.
-
- '
- human_id: '
-
- ### Instruction'
- bot_id: '
-
- ### Response'
- stop_words: []
diff --git a/llm_on_ray/inference/models/gpt2.yaml b/llm_on_ray/inference/models/gpt2.yaml
index 48287670a..9ad098c24 100644
--- a/llm_on_ray/inference/models/gpt2.yaml
+++ b/llm_on_ray/inference/models/gpt2.yaml
@@ -6,17 +6,12 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
enabled: true
precision: bf16
model_description:
model_id_or_path: gpt2
tokenizer_name_or_path: gpt2
- chat_processor: ChatModelGptJ
gpt_base_model: true
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
index d68da8428..ab411ff0e 100644
--- a/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-2-70b-chat-hf-hpu.yaml
@@ -10,13 +10,5 @@ device: hpu
model_description:
model_id_or_path: meta-llama/Llama-2-70b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-70b-chat-hf
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
index 374a98f77..b7b19f02a 100644
--- a/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
+++ b/llm_on_ray/inference/models/hpu/llama-2-7b-chat-hf-hpu.yaml
@@ -8,13 +8,5 @@ device: hpu
model_description:
model_id_or_path: meta-llama/Llama-2-7b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
diff --git a/llm_on_ray/inference/models/hpu/neural-chat-7b-v3-3.yaml b/llm_on_ray/inference/models/hpu/neural-chat-7b-v3-3.yaml
index 848358bec..35fadb820 100644
--- a/llm_on_ray/inference/models/hpu/neural-chat-7b-v3-3.yaml
+++ b/llm_on_ray/inference/models/hpu/neural-chat-7b-v3-3.yaml
@@ -14,13 +14,4 @@ ipex:
model_description:
model_id_or_path: Intel/neural-chat-7b-v3-3
tokenizer_name_or_path: Intel/neural-chat-7b-v3-3
- chat_processor: ChatModelGptJ
- prompt:
- intro: '### System:
- You are a chatbot developed by Intel. Please answer all questions to the best of your ability.'
- human_id: '
-
- ### User'
- bot_id: '
-
- ### Assistant'
+ chat_template: "'### System:You are a chatbot developed by Intel. Please answer all questions to the best of your ability.\n'{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:]}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### User: ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '### Assistant:' + message['content'].strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'### Assistant:\n'}}{% endif %}"
diff --git a/llm_on_ray/inference/models/ipex-llm/mistral-7b-v0.1-ipex-llm.yaml b/llm_on_ray/inference/models/ipex-llm/mistral-7b-v0.1-ipex-llm.yaml
index 6a8523467..2ad30d0b8 100644
--- a/llm_on_ray/inference/models/ipex-llm/mistral-7b-v0.1-ipex-llm.yaml
+++ b/llm_on_ray/inference/models/ipex-llm/mistral-7b-v0.1-ipex-llm.yaml
@@ -14,12 +14,7 @@ model_description:
model_id_or_path: mistralai/Mistral-7B-v0.1
ipexllm: true
tokenizer_name_or_path: mistralai/Mistral-7B-v0.1
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]'
- bot_id: ''
- stop_words: []
config:
trust_remote_code: true
load_in_4bit: true
+ chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}""
diff --git a/llm_on_ray/inference/models/ipex-llm/mpt-7b-ipex-llm.yaml b/llm_on_ray/inference/models/ipex-llm/mpt-7b-ipex-llm.yaml
index d352a6517..ecb129973 100644
--- a/llm_on_ray/inference/models/ipex-llm/mpt-7b-ipex-llm.yaml
+++ b/llm_on_ray/inference/models/ipex-llm/mpt-7b-ipex-llm.yaml
@@ -14,19 +14,6 @@ model_description:
model_id_or_path: mosaicml/mpt-7b-chat
ipexllm: true
tokenizer_name_or_path: EleutherAI/gpt-neox-20b
- chat_processor: ChatModelGptJ
- prompt:
- intro: 'Below is an instruction that describes a task, paired with an input that
- provides further context. Write a response that appropriately completes the request.
-
- '
- human_id: '
-
- ### Instruction'
- bot_id: '
-
- ### Response'
- stop_words: []
config:
trust_remote_code: true
load_in_4bit: true
diff --git a/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml b/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
index 4b3e11e98..4f2ed0194 100644
--- a/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
+++ b/llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml
@@ -6,20 +6,12 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: meta-llama/Llama-2-7b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
diff --git a/llm_on_ray/inference/models/mistral-7b-Instruct-v0.2.yaml b/llm_on_ray/inference/models/mistral-7b-Instruct-v0.2.yaml
index 1af9aad1b..ab901eb95 100644
--- a/llm_on_ray/inference/models/mistral-7b-Instruct-v0.2.yaml
+++ b/llm_on_ray/inference/models/mistral-7b-Instruct-v0.2.yaml
@@ -5,19 +5,13 @@ cpus_per_worker: 48
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: mistralai/Mistral-7B-Instruct-v0.2
- ipexllm: false
+ bigdl: false
tokenizer_name_or_path: mistralai/Mistral-7B-Instruct-v0.2
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]'
- bot_id: ''
- stop_words: []
config:
trust_remote_code: true
diff --git a/llm_on_ray/inference/models/mistral-7b-v0.1.yaml b/llm_on_ray/inference/models/mistral-7b-v0.1.yaml
index c8a0ff385..db2eec1e4 100644
--- a/llm_on_ray/inference/models/mistral-7b-v0.1.yaml
+++ b/llm_on_ray/inference/models/mistral-7b-v0.1.yaml
@@ -14,11 +14,6 @@ model_description:
model_id_or_path: mistralai/Mistral-7B-v0.1
ipexllm: false
tokenizer_name_or_path: mistralai/Mistral-7B-v0.1
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]'
- bot_id: ''
- stop_words: []
config:
trust_remote_code: true
+ chat_template: "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/mpt-7b.yaml b/llm_on_ray/inference/models/mpt-7b.yaml
index 4ea12adb3..80f062a82 100644
--- a/llm_on_ray/inference/models/mpt-7b.yaml
+++ b/llm_on_ray/inference/models/mpt-7b.yaml
@@ -6,25 +6,12 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: mosaicml/mpt-7b
tokenizer_name_or_path: EleutherAI/gpt-neox-20b
- chat_processor: ChatModelGptJ
- prompt:
- intro: 'Below is an instruction that describes a task, paired with an input that
- provides further context. Write a response that appropriately completes the request.
-
- '
- human_id: '
-
- ### Instruction'
- bot_id: '
-
- ### Response'
- stop_words: []
config:
trust_remote_code: true
diff --git a/llm_on_ray/inference/models/neural-chat-7b-v3-1.yaml b/llm_on_ray/inference/models/neural-chat-7b-v3-1.yaml
index 13a29676c..2d6ac4d29 100644
--- a/llm_on_ray/inference/models/neural-chat-7b-v3-1.yaml
+++ b/llm_on_ray/inference/models/neural-chat-7b-v3-1.yaml
@@ -6,20 +6,11 @@ cpus_per_worker: 24
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: CPU
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: Intel/neural-chat-7b-v3-1
tokenizer_name_or_path: Intel/neural-chat-7b-v3-1
- chat_processor: ChatModelGptJ
- prompt:
- intro: '### System:
- You are a chatbot developed by Intel. Please answer all questions to the best of your ability.'
- human_id: '
-
- ### User'
- bot_id: '
-
- ### Assistant'
+ chat_template: "'### System:You are a chatbot developed by Intel. Please answer all questions to the best of your ability.\n'{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:]%}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '### User: ' + message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ '### Assistant:' + message['content'].strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{'### Assistant:\n'}}{% endif %}"
diff --git a/llm_on_ray/inference/models/opt-125m.yaml b/llm_on_ray/inference/models/opt-125m.yaml
index 545cd2145..92fd30260 100644
--- a/llm_on_ray/inference/models/opt-125m.yaml
+++ b/llm_on_ray/inference/models/opt-125m.yaml
@@ -13,9 +13,4 @@ ipex:
model_description:
model_id_or_path: facebook/opt-125m
tokenizer_name_or_path: facebook/opt-125m
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/sqlcoder-7b-2.yaml b/llm_on_ray/inference/models/sqlcoder-7b-2.yaml
index 7130148a3..e4e629599 100644
--- a/llm_on_ray/inference/models/sqlcoder-7b-2.yaml
+++ b/llm_on_ray/inference/models/sqlcoder-7b-2.yaml
@@ -5,18 +5,13 @@ cpus_per_worker: 22
gpus_per_worker: 0
deepspeed: false
workers_per_group: 2
-device: cpu
+device: "cpu"
ipex:
enabled: false
precision: bf16
model_description:
model_id_or_path: defog/sqlcoder-7b-2
tokenizer_name_or_path: defog/sqlcoder-7b-2
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: ["```"]
config:
use_auth_token: ''
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/starcoder.yaml b/llm_on_ray/inference/models/starcoder.yaml
index 0da59ac02..a57ae351d 100644
--- a/llm_on_ray/inference/models/starcoder.yaml
+++ b/llm_on_ray/inference/models/starcoder.yaml
@@ -9,15 +9,10 @@ workers_per_group: 2
ipex:
enabled: false
precision: bf16
-device: cpu
+device: "cpu"
model_description:
model_id_or_path: bigcode/starcoder
tokenizer_name_or_path: bigcode/starcoder
- chat_processor: ChatModelGptJ
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
+ chat_template: "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ message['content'].strip() }}{% elif message['role'] == 'assistant' %}{{ message['content'].strip() }}{% endif %}{% endfor %}"
diff --git a/llm_on_ray/inference/models/template/inference_config_template.yaml b/llm_on_ray/inference/models/template/inference_config_template.yaml
index 137ddb2dc..1e6726a12 100644
--- a/llm_on_ray/inference/models/template/inference_config_template.yaml
+++ b/llm_on_ray/inference/models/template/inference_config_template.yaml
@@ -13,7 +13,7 @@ ipex:
precision: bf16
model_description:
model_id_or_path: null
- ipexllm:: false
+ ipexllm: false
tokenizer_name_or_path: null
chat_processor: null
gpt_base_model: false
@@ -22,11 +22,6 @@ model_description:
peft_model_id_or_path: null
peft_type: null
use_hpu_graphs: true
- prompt:
- intro: ''
- human_id: ''
- bot_id: ''
- stop_words: []
config:
trust_remote_code: false
use_auth_token: null
diff --git a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
index acbf58455..9302b9be2 100644
--- a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
+++ b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml
@@ -16,13 +16,5 @@ ipex:
model_description:
model_id_or_path: meta-llama/Llama-2-7b-chat-hf
tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf
- chat_processor: ChatModelLLama
- prompt:
- intro: ''
- human_id: '[INST] {msg} [/INST]
-
- '
- bot_id: ''
- stop_words: []
config:
use_auth_token: ''
diff --git a/llm_on_ray/inference/predictor_deployment.py b/llm_on_ray/inference/predictor_deployment.py
index 18b23d86b..74b9430e8 100644
--- a/llm_on_ray/inference/predictor_deployment.py
+++ b/llm_on_ray/inference/predictor_deployment.py
@@ -26,6 +26,8 @@
from starlette.requests import Request
from starlette.responses import StreamingResponse, JSONResponse
from fastapi import HTTPException
+
+from llm_on_ray.inference.chat_template_process import ChatTemplatePreprocess
from llm_on_ray.inference.inference_config import InferenceConfig
from llm_on_ray.inference.api_openai_backend.openai_protocol import (
ChatMessage,
@@ -51,31 +53,11 @@ def __init__(
max_batch_size=_DEFAULT_MAX_BATCH_SIZE,
):
self.device = torch.device(infer_conf.device)
- self.process_tool = None
- chat_processor_name = infer_conf.model_description.chat_processor
- prompt = infer_conf.model_description.prompt
self.handle_dynamic_batch.set_max_batch_size(max_batch_size)
-
- if chat_processor_name:
- try:
- module = __import__("chat_process")
- except Exception:
- sys.path.append(os.path.dirname(__file__))
- module = __import__("chat_process")
- chat_processor = getattr(module, chat_processor_name, None)
- if chat_processor is None:
- raise ValueError(
- infer_conf.name
- + " deployment failed. chat_processor("
- + chat_processor_name
- + ") does not exist."
- )
- self.process_tool = chat_processor(**prompt.dict())
-
self.use_deepspeed = infer_conf.deepspeed
self.use_vllm = infer_conf.vllm.enabled
- self.is_mllm = True if chat_processor_name in ["ChatModelwithImage"] else False
+ self.is_mllm = infer_conf.model_description.chat_model_with_image
# Used to determine if openai backend is used
self.use_openai = False
@@ -102,6 +84,7 @@ def __init__(
self.predictor = TransformerPredictor(infer_conf)
self.loop = asyncio.get_running_loop()
+ self.process_tool = ChatTemplatePreprocess(self.predictor)
def consume_streamer(self, streamer):
for text in streamer:
@@ -305,12 +288,13 @@ async def handle_static_batch(self, prompts: List[str], **config: Dict[str, Any]
preprocessing_time=0,
)
- def preprocess_prompts(self, input: Union[str, List], tools=None, tool_choice=None):
+    # TODO: Abstract the preprocess_prompts function into a class for handling chat templates
+ def preprocess_prompts(self, input: Union[str, list], tools=None, tool_choice=None):
"""
Preprocesses the input prompts.
Args:
- input (Union[str, List[str]]): The input prompt(s) to be preprocessed.
+ input (Union[str, List[dict]]): The input prompt(s) to be preprocessed.
tools (List[str]): The list of tools to be used.
tool_choice: The choice of tool to be used.
@@ -327,6 +311,7 @@ def preprocess_prompts(self, input: Union[str, List], tools=None, tool_choice=No
Raises:
HTTPException: If the input prompt format is invalid or not supported.
"""
+
if isinstance(input, str):
return input
elif isinstance(input, List):
@@ -352,7 +337,7 @@ def preprocess_prompts(self, input: Union[str, List], tools=None, tool_choice=No
# Process the input prompts with MLLM tool
if self.process_tool is not None:
if self.is_mllm:
- input, image = self.process_tool.get_prompt(input)
+ input, image = self.process_tool.get_prompt(input, self.is_mllm)
prompts.append(input)
images.extend(image)
return prompts, images
@@ -379,16 +364,15 @@ async def __call__(self, http_request: Request) -> Union[StreamingResponse, JSON
status_code=400,
content="Invalid JSON format from http request.",
)
-
streaming_response = json_request["stream"] if "stream" in json_request else False
input = json_request["text"] if "text" in json_request else ""
+
if input == "":
return JSONResponse(
status_code=400,
content="Empty prompt is not supported.",
)
config = json_request["config"] if "config" in json_request else {}
-
# return prompt or list of prompts preprocessed
prompts = self.preprocess_prompts(input)
@@ -408,9 +392,9 @@ async def openai_call(
tool_choice=None,
):
self.use_openai = True
+ print("openai_call")
+ print(input)
+ print(type(input))
# return prompt or list of prompts preprocessed
prompts = self.preprocess_prompts(input, tools, tool_choice)
# Handle streaming response
if streaming_response:
diff --git a/llm_on_ray/inference/utils.py b/llm_on_ray/inference/utils.py
index 91e311088..56b9146e5 100644
--- a/llm_on_ray/inference/utils.py
+++ b/llm_on_ray/inference/utils.py
@@ -166,7 +166,7 @@ def get_prompt_format(input: Union[List[str], List[dict], List[ChatMessage]]):
chat_format = True
prompts_format = True
for item in input:
- if isinstance(item, str) or isinstance(item, list):
+ if isinstance(item, str):
chat_format = False
elif isinstance(item, dict) or isinstance(item, ChatMessage):
prompts_format = False
diff --git a/llm_on_ray/ui/start_ui.py b/llm_on_ray/ui/start_ui.py
index c30851a8e..e7188b283 100644
--- a/llm_on_ray/ui/start_ui.py
+++ b/llm_on_ray/ui/start_ui.py
@@ -29,13 +29,11 @@
from ray.train.base_trainer import TrainingFailedError
from ray.tune.logger import LoggerCallback
from ray.util import queue
-from llm_on_ray.inference.inference_config import all_models, ModelDescription, Prompt
-from llm_on_ray.inference.inference_config import InferenceConfig as FinetunedConfig
-from llm_on_ray.inference.chat_process import (
- ChatModelGptJ,
- ChatModelLLama,
- ChatModelwithImage,
-)
+
+from llm_on_ray.finetune.finetune_config import base_models, FinetuneConfig
+from llm_on_ray.inference.inference_config import ModelDescription, all_models
+from llm_on_ray.inference.inference_config import InferenceConfig
+
from llm_on_ray.inference.predictor_deployment import PredictorDeployment
from llm_on_ray.ui.html_format import cpu_memory_html, ray_status_html, custom_css
from langchain.vectorstores import FAISS
@@ -113,8 +111,8 @@ def get_result(self):
class ChatBotUI:
def __init__(
self,
- all_models: Dict[str, FinetunedConfig],
- base_models: Dict[str, FinetunedConfig],
+ all_models: Dict[str, InferenceConfig],
+ base_models: Dict[str, FinetuneConfig],
finetune_model_path: str,
finetuned_checkpoint_path: str,
repo_code_path: str,
@@ -151,7 +149,6 @@ def __init__(
"What is Ray?",
"What is chatbot?",
]
- self.process_tool = None
self.finetune_actor = None
self.finetune_status = False
self.default_rag_path = default_rag_path
@@ -219,7 +216,6 @@ def user(self, user_message, history):
def model_generate(self, prompt, request_url, model_name, config, simple_api=True):
if simple_api:
- prompt = self.process_tool.get_prompt(prompt)
sample_input = {"text": prompt, "config": config, "stream": True}
else:
sample_input = {
@@ -231,28 +227,28 @@ def model_generate(self, prompt, request_url, model_name, config, simple_api=Tru
"top_p": config["top_p"],
"top_k": config["top_k"],
}
+
proxies = {"http": None, "https": None}
outputs = requests.post(request_url, proxies=proxies, json=sample_input, stream=True)
+
outputs.raise_for_status()
for output in outputs.iter_content(chunk_size=None, decode_unicode=True):
- # remove context
- if simple_api:
- if prompt in output:
- output = output[len(prompt) :]
- else:
- if output is None or output == "":
- continue
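+            # The OpenAI-compatible endpoint streams SSE chunks prefixed with "data: ",
+            # while the simple API streams plain text chunks that are yielded as-is.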
+ if not simple_api:
import json
import re
- chunk_data = re.sub("^data: ", "", output)
- if chunk_data != "[DONE]":
- decoded_output = json.loads(chunk_data)
- if "choices" in decoded_output:
- choices = decoded_output["choices"]
+ if output is not None and output != "":
+                    # Get data from response chunk
+ chunk_data = re.sub("^data: ", "", output)
+ if chunk_data.strip() != "[DONE]":
+ # Get message choices from data
+ choices = json.loads(chunk_data)["choices"]
+
+ # Pick content from first choice
output = choices[0]["delta"].get("content", "")
- else:
- output = ""
+
+ else:
+ output = ""
yield output
def bot(
@@ -299,11 +295,7 @@ def bot(
for output in outputs:
if len(output) != 0:
time_end = time.time()
- if simple_api:
- history[-1][1] += output
- history[-1][1] = self.process_tool.convert_output(history[-1][1])
- else:
- history[-1][1] += output
+ history[-1][1] += output
time_spend = round(time_end - time_start, 3)
token_num += 1
new_token_latency = f"""
@@ -554,7 +546,6 @@ def finetune(
cpus_per_worker_ftn,
):
if model_name == "specify other models":
- model_desc = None
origin_model_path = custom_model_name
tokenizer_path = custom_tokenizer_name
if "gpt" in model_name.lower() or "pythia" in model_name.lower():
@@ -562,22 +553,18 @@ def finetune(
else:
gpt_base_model = False
else:
- model_desc = self._base_models[model_name].model_description
- origin_model_path = model_desc.model_id_or_path
- tokenizer_path = model_desc.tokenizer_name_or_path
- gpt_base_model = model_desc.gpt_base_model
+            finetune_config = self._base_models[model_name]
+            gpt_base_model = finetune_config.General.gpt_base_model
+            # Known base models keep the paths from their finetune config,
+            # so the overrides below are skipped for them.
+            origin_model_path = None
+            tokenizer_path = None
+
+            finetune_config = finetune_config.dict()
last_gpt_base_model = False
finetuned_model_path = os.path.join(self.finetuned_model_path, model_name, new_model_name)
- finetuned_checkpoint_path = (
- os.path.join(self.finetuned_checkpoint_path, model_name, new_model_name)
- if self.finetuned_checkpoint_path != ""
- else None
- )
- finetune_config = self.config.copy()
- training_config = finetune_config.get("Training")
- exist_worker = int(training_config["num_training_workers"])
- exist_cpus_per_worker_ftn = int(training_config["resources_per_worker"]["CPU"])
+ exist_worker = int(finetune_config["Training"].get("num_training_workers"))
+
+ exist_cpus_per_worker_ftn = int(
+ finetune_config["Training"].get("resources_per_worker")["CPU"]
+ )
ray_resources = ray.available_resources()
if "CPU" not in ray_resources or cpus_per_worker_ftn * worker_num + 1 > int(
@@ -610,22 +597,18 @@ def finetune(
if gpt_base_model:
new_ray_init_config["runtime_env"]["pip"] = ["transformers==4.26.0"]
else:
- new_ray_init_config["runtime_env"]["pip"] = ["transformers==4.31.0"]
- last_gpt_base_model = gpt_base_model
- finetune_config["Training"]["num_training_workers"] = int(worker_num)
- finetune_config["Training"]["resources_per_worker"]["CPU"] = int(cpus_per_worker_ftn)
+ new_ray_init_config["runtime_env"]["pip"] = ["transformers==4.38.1"]
ray.init(**new_ray_init_config)
- exist_worker = worker_num
- exist_cpus_per_worker_ftn = cpus_per_worker_ftn
finetune_config["Dataset"]["train_file"] = dataset
- finetune_config["General"]["base_model"] = origin_model_path
+ if origin_model_path is not None:
+ finetune_config["General"]["base_model"] = origin_model_path
+ if tokenizer_path is not None:
+ finetune_config["General"]["tokenizer_name"] = tokenizer_path
finetune_config["Training"]["epochs"] = num_epochs
finetune_config["General"]["output_dir"] = finetuned_model_path
- finetune_config["General"]["config"]["trust_remote_code"] = True
- if finetuned_checkpoint_path:
- finetune_config["General"]["checkpoint_dir"] = finetuned_checkpoint_path
+
finetune_config["Training"]["batch_size"] = batch_size
finetune_config["Training"]["learning_rate"] = lr
if max_train_step != 0:
@@ -672,21 +658,19 @@ def finetune(
ray.kill(self.finetune_actor)
self.finetune_actor = None
- new_prompt = Prompt()
- new_prompt.intro = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n"
- new_prompt.human_id = "\n### Instruction"
- new_prompt.bot_id = "\n### Response"
- new_prompt.stop_words.extend(
- ["### Instruction", "# Instruction", "### Question", "##", " ="]
- )
- new_model_desc = ModelDescription(
- model_id_or_path=finetuned_model_path,
- tokenizer_name_or_path=tokenizer_path,
- prompt=new_prompt,
- chat_processor=model_desc.chat_processor if model_desc is not None else "ChatModelGptJ",
- )
+ if finetune_config["General"].get("lora_config", None) is not None:
+ new_model_desc = ModelDescription(
+ model_id_or_path=finetune_config["General"].get("base_model"),
+ tokenizer_name_or_path=finetuned_model_path,
+ peft_model_id_or_path=finetuned_model_path,
+ )
+ else:
+ new_model_desc = ModelDescription(
+ model_id_or_path=finetuned_model_path,
+ tokenizer_name_or_path=finetuned_model_path,
+ )
new_model_desc.config.trust_remote_code = True
- new_finetuned = FinetunedConfig(
+ new_finetuned = InferenceConfig(
name=new_model_name,
route_prefix="/" + new_model_name,
model_description=new_model_desc,
@@ -748,20 +732,8 @@ def deploy_func(
finetuned = self._all_models[model_name]
model_desc = finetuned.model_description
- prompt = model_desc.prompt
print("model path: ", model_desc.model_id_or_path)
- if model_desc.chat_processor is not None:
- chat_model = getattr(sys.modules[__name__], model_desc.chat_processor, None)
- if chat_model is None:
- return (
- model_name
- + " deployment failed. "
- + model_desc.chat_processor
- + " does not exist."
- )
- self.process_tool = chat_model(**prompt.dict())
-
finetuned_deploy = finetuned.copy(deep=True)
if hpus_per_worker_deploy > 0:
finetuned_deploy.device = "hpu"
@@ -780,7 +752,7 @@ def deploy_func(
elif "fuyu-8b" in model_name:
pip_env = "transformers==4.37.2"
else:
- pip_env = "transformers==4.31.0"
+ pip_env = "transformers==4.38.1"
if finetuned_deploy.device == "cpu":
ray_actor_options["runtime_env"] = {"pip": [pip_env]}
deployment = PredictorDeployment.options( # type: ignore
@@ -1803,9 +1775,10 @@ def _init_ui(self):
default_rag_path = args.default_rag_path
initial_model_list = {k: all_models[k] for k in sorted(all_models.keys())}
+ initial_base_model_list = {k: base_models[k] for k in sorted(base_models.keys())}
ui = ChatBotUI(
initial_model_list,
- initial_model_list,
+ initial_base_model_list,
finetune_model_path,
finetune_checkpoint_path,
repo_path,
diff --git a/pyproject.toml b/pyproject.toml
index b319045cc..451d2649d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,8 @@ dependencies = [
"deltatuner==1.1.9",
"py-cpuinfo",
"pydantic-yaml",
- "async_timeout",
+ "async-timeout",
+ "jinja2>=3.0.0",
"typer"
]
diff --git a/tests/finetune/test_chat_template.py b/tests/finetune/test_chat_template.py
new file mode 100644
index 000000000..a416d8f7b
--- /dev/null
+++ b/tests/finetune/test_chat_template.py
@@ -0,0 +1,156 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import transformers
+from transformers import AutoTokenizer
+from llm_on_ray.common.dataprocesser.general_processer import GeneralProcesser
+
+
+class TestTokenizeFunction(unittest.TestCase):
+ def setUp(self):
+ self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+ self.config = {
+ "gpt_base_model": True,
+ "max_length": 512,
+ "trust_remote_code": False,
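+            # Jinja chat template used by the tests below: user turns render as
+            # "### Instruction: <content>" and assistant turns as "### Response: <content>",
+            # followed by a closing "### End" marker.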
+ "chat_template": "Below is an instruction that describes a task. Write a response that appropriately "
+ "completes the request\n {% if messages[0]['role'] == 'system' %}{{ raise_exception("
+ "'System role not supported') }}{% endif %}{% for message in messages %}{% if (message["
+ "'role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles "
+ "must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] "
+ "== 'user' %}{{ '### Instruction: ' + message['content'] }}{% elif message['role'] == "
+ "'assistant' %}{{ '### Response: ' + message['content'] }}{% endif %}{% endfor %}{{'### "
+ "End \n'}}",
+ }
+ self.processer = GeneralProcesser(self.config)
+
+ def test_tokenize_function_with_gpt_model(self):
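+        # gpt_base_model=True exercises the legacy instruction/context/response
+        # prompt format rather than the chat template.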
+ self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request.\n"
+ "\n"
+ "### Instruction:\n"
+ "Test instruction\n"
+ "\n"
+ "Input:\n"
+ "Test context\n"
+ "\n"
+ "### Response:\n"
+ "Test response\n"
+ "\n"
+ "### End"
+ )
+
+ result = self.processer.tokenize_function(examples, self.tokenizer)
+ self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
+
+ def test_tokenize_function_with_custom_chat_template(self):
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "<|im_start|>user\n"
+ "###Instruction:\n"
+ "Test instruction\n"
+ "\n"
+ "###context:\n"
+ "Test context\n"
+ "\n"
+ "<|im_end|><|im_start|>assistant\n"
+ "Test response\n"
+ "\n"
+ "<|im_end|>"
+ )
+ # Set custom chat template
+ self.config["custom_chat_template"] = (
+ "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n'"
+ "+ message['content'] + '<|im_end|>'}}{% endfor %}"
+ )
+
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_function(examples, self.tokenizer)
+ self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
+
+ def test_tokenize_function_with_chat_template(self):
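+        # Uses the chat_template configured in setUp (gpt_base_model is disabled below).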
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ # Verify the format of the result
+ expected_result = (
+ "Below is an instruction that describes a task. Write a response that "
+ "appropriately completes the request\n"
+ "### Instruction: ###Instruction:\n"
+ "Test instruction\n"
+ "\n"
+ "###context:\n"
+ "Test context\n"
+ "\n"
+ "### Response: Test response\n"
+ "\n"
+ "### End \n"
+ )
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_function(examples, self.tokenizer)
+ self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
+
+ def test_tokenize_function_with_default_chat_template(self):
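+        # When the tokenizer ships its own chat template (gemma-2b-it), the output
+        # should match tokenizer.apply_chat_template on the equivalent messages.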
+ self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
+ examples = {
+ "instruction": "Test instruction",
+ "response": "Test response",
+ "context": "Test context",
+ }
+
+ chat_example = [
+ {
+ "role": "user",
+ "content": "###Instruction:\nTest instruction\n\n###context:\nTest context\n\n",
+ },
+ {
+ "role": "assistant",
+ "content": "Test response\n\n",
+ },
+ ]
+
+ # Verify the format of the result
+ expected_result = self.tokenizer.apply_chat_template(
+ chat_example, tokenize=False, max_length=self.config.get("max_length")
+ )
+
+ self.config["gpt_base_model"] = False
+ result = self.processer.tokenize_function(examples, self.tokenizer)
+ self.assertEqual(self.tokenizer.decode(result["input_ids"]), expected_result)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/inference/test_query_single.py b/tests/inference/test_query_single.py
new file mode 100644
index 000000000..d48727a30
--- /dev/null
+++ b/tests/inference/test_query_single.py
@@ -0,0 +1,123 @@
+#
+# Copyright 2023 The LLM-on-Ray Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import subprocess
+import pytest
+import os
+
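+# Bypass any configured proxies when querying the locally served endpoints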
+os.environ["no_proxy"] = "localhost,127.0.0.1"
+
+
+def start_serve(model_name):
+ current_path = os.path.dirname(os.path.abspath(__file__))
+
+ config_path = os.path.join(
+ current_path, "../../.github/workflows/config/" + model_name + "-ci.yaml"
+ )
+
+ cmd_serve = ["llm_on_ray-serve", "--config_file", config_path, "--simple"]
+
+ result_serve = subprocess.run(cmd_serve, capture_output=True, text=True)
+
+ # Ensure there are no errors in the serve script execution
+    assert result_serve.returncode == 0, (
+        "\nServe error stderr message:\n" + result_serve.stderr
+    )
+
+    # Print the serve output so it can be inspected in the test logs
+ print("\n" + "Serve message: " + "\n", result_serve.stdout)
+
+    # Ensure the serve script did not log any errors to stderr
+ assert "Error" not in result_serve.stderr
+
+
+def script_with_args(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+):
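+    # Run the query_single example against the served endpoint, passing only the
+    # generation options that were provided.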
+ current_path = os.path.dirname(os.path.abspath(__file__))
+
+ example_query_single_path = os.path.join(
+ current_path, "../../examples/inference/api_server_simple/query_single.py"
+ )
+
+ cmd_single = [
+ "python",
+ example_query_single_path,
+ "--model_endpoint",
+ base_url + model_name,
+ ]
+
+ if streaming_response:
+ cmd_single.append("--streaming_response")
+
+ if max_new_tokens is not None:
+ cmd_single.extend(["--max_new_tokens", str(max_new_tokens)])
+
+ if temperature is not None:
+ cmd_single.extend(["--temperature", str(temperature)])
+
+ if top_p is not None:
+ cmd_single.extend(["--top_p", str(top_p)])
+
+ if top_k is not None:
+ cmd_single.extend(["--top_k", str(top_k)])
+
+ result_query_single = subprocess.run(cmd_single, capture_output=True, text=True)
+
+    # Print the full query result so it can be inspected in the test logs
+ print(result_query_single)
+
+    # Ensure there are no errors in the query script execution
+ assert "Error" not in result_query_single.stderr
+
+ # Returncode should be 0 when there is no exception
+ assert result_query_single.returncode == 0
+
+
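+# Track models that have already been served so start_serve runs only once per model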
+executed_models = {}
+
+
+# Parametrize the test function with different combinations of parameters
+# TODO: more models and combinations will be added and tested.
+@pytest.mark.parametrize(
+    "base_url,model_name,streaming_response,max_new_tokens,temperature,top_p,top_k",
+ [
+ (base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k)
+ for base_url in ["http://localhost:8000/"]
+ for model_name in ["gpt2"]
+ for streaming_response in [None]
+ for max_new_tokens in [None]
+ for temperature in [None]
+ for top_p in [None]
+ for top_k in [None]
+ ],
+)
+def test_script(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+):
+ global executed_models
+
+    # Check whether start_serve has already been run for this model
+ if model_name not in executed_models:
+ start_serve(model_name)
+        # Mark that start_serve has been run for this model
+ executed_models[model_name] = True
+
+ script_with_args(
+ base_url, model_name, streaming_response, max_new_tokens, temperature, top_p, top_k
+ )
diff --git a/tests/test_getting_started.sh b/tests/test_getting_started.sh
index 6a900a553..052ac51bb 100755
--- a/tests/test_getting_started.sh
+++ b/tests/test_getting_started.sh
@@ -33,7 +33,7 @@ curl $ENDPOINT_URL/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt2",
- "messages": [{"role": "assistant", "content": "You are a helpful assistant."}, {"role": "user", "content": "Hello!"}],
+ "messages": [{"role": "user", "content": "Hello!"}],
"temperature": 0.7
}'