diff --git a/.github/license/header_exclude_files.txt b/.github/license/header_exclude_files.txt index e69de29bb..9563fbc43 100644 --- a/.github/license/header_exclude_files.txt +++ b/.github/license/header_exclude_files.txt @@ -0,0 +1 @@ +vllm-ext/vllm/extension/ns/__init__.py \ No newline at end of file diff --git a/.github/workflows/workflow_inference.yml b/.github/workflows/workflow_inference.yml index ca57affac..f38243fa4 100644 --- a/.github/workflows/workflow_inference.yml +++ b/.github/workflows/workflow_inference.yml @@ -34,7 +34,7 @@ jobs: name: inference strategy: matrix: - model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-ipex-llm, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b, starcoder, llama-2-7b-chat-hf, llama-2-7b-chat-hf-vllm, gemma-2b, deepseek-coder-33b-instruct] + model: [ gpt-j-6b, gpt2, bloom-560m, opt-125m, mpt-7b, mistral-7b-v0.1, mpt-7b-ipex-llm, neural-chat-7b-v3-1, CodeLlama-7b-hf, falcon-7b, starcoder, llama-2-7b-chat-hf, llama-2-7b-chat-hf-vllm, llama-2-7b-chat-hf-vllm-ns, gemma-2b, deepseek-coder-33b-instruct] isPR: - ${{inputs.ci_type == 'pr'}} @@ -97,7 +97,11 @@ jobs: run: | TARGET=${{steps.target.outputs.target}} source dev/scripts/ci-functions.sh - strat_ray ${TARGET} + if [[ "$TARGET" == *ns ]]; then + start_ray ${TARGET} 1 + else + start_ray ${TARGET} + fi - name: Run Inference Test run: | diff --git a/.github/workflows/workflow_inference_gaudi2.yml b/.github/workflows/workflow_inference_gaudi2.yml index dedeb4154..a24f56397 100644 --- a/.github/workflows/workflow_inference_gaudi2.yml +++ b/.github/workflows/workflow_inference_gaudi2.yml @@ -94,7 +94,7 @@ jobs: # check and remove exited container cid=$(docker ps -a -q --filter "name=${TARGET}") if [[ ! -z "$cid" ]]; then docker rm $cid; fi - docker run -tid --name="${TARGET}" --hostname="${TARGET}-container" --runtime=habana -v /home/yizhong/Model-References:/root/Model-References -v ${{ inputs.code_checkout_path }}:/root/llm-on-ray -v ${{ inputs.model_cache_path }}:/root/.cache/huggingface/hub/ -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --cap-add sys_ptrace --net=host --ipc=host ${TARGET}:habana + docker run -tid --privileged --name="${TARGET}" --hostname="${TARGET}-container" --runtime=habana -v /home/yizhong/Model-References:/root/Model-References -v ${{ inputs.code_checkout_path }}:/root/llm-on-ray -v ${{ inputs.model_cache_path }}:/root/.cache/huggingface/hub/ -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --cap-add sys_ptrace --net=host --ipc=host ${TARGET}:habana - name: Start Ray Cluster run: | TARGET=${{steps.target.outputs.target}} diff --git a/.github/workflows/workflow_test_benchmark.yml b/.github/workflows/workflow_test_benchmark.yml index d737675c0..37c0518e2 100644 --- a/.github/workflows/workflow_test_benchmark.yml +++ b/.github/workflows/workflow_test_benchmark.yml @@ -80,7 +80,7 @@ jobs: # check and remove exited container cid=$(docker ps -a -q --filter "name=${TARGET}") if [[ ! 
-z "$cid" ]]; then docker rm $cid; fi - docker run -tid -v ${{ inputs.model_cache_path }}:/root/.cache/huggingface/hub -v ${{ inputs.code_checkout_path }}:/root/llm-on-ray -e http_proxy=${{ inputs.http_proxy }} -e https_proxy=${{ inputs.https_proxy }} --name="${TARGET}" --hostname="${TARGET}-container" ${TARGET}:latest + docker run -tid --privileged -v ${{ inputs.model_cache_path }}:/root/.cache/huggingface/hub -v ${{ inputs.code_checkout_path }}:/root/llm-on-ray -e http_proxy=${{ inputs.http_proxy }} -e https_proxy=${{ inputs.https_proxy }} --name="${TARGET}" --hostname="${TARGET}-container" ${TARGET}:latest - name: Start Ray Cluster run: | diff --git a/.github/workflows/workflow_tests.yml b/.github/workflows/workflow_tests.yml index eb51d4f0f..c9af35b4c 100644 --- a/.github/workflows/workflow_tests.yml +++ b/.github/workflows/workflow_tests.yml @@ -176,7 +176,7 @@ jobs: run: | TARGET=${{steps.target.outputs.target}} source dev/scripts/ci-functions.sh - strat_ray ${TARGET} + start_ray ${TARGET} - name: Run Tests run: | diff --git a/.gitignore b/.gitignore index 0ed6093cc..757c2b821 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,9 @@ build/lib/ *.json *.txt *.egg-info +.eggs +*.log +*.so +*.ninja_log +build/ +runtime_outs/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5e56e52a6..fd7188b68 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,12 @@ repos: hooks: - id: ruff args: [ --fix, --exit-non-zero-on-fix, --ignore=E402, --ignore=E501, --ignore=E731, --ignore=F401] + exclude: | + (?x)^( + examples/inference/vllm/ray-vllm-examples/llm.py| + vllm-ext/vllm/extension/ns/__init__.py| + )$ + # Black needs to be ran after ruff with --fix - repo: https://github.com/psf/black @@ -18,7 +24,18 @@ repos: rev: "v0.981" hooks: - id: mypy - exclude: tests + exclude: | + (?x)^( + tests| + vllm-ext/vllm/extension/ns/model/ns_loader.py| + vllm-ext/vllm/extension/ns/kv_cache/ns_cache.py| + vllm-ext/inference_engine/python/inference_engine/| + vllm-ext/setup.py| + examples/inference/vllm/ray-vllm-examples/llm.py| + llm_on_ray/inference/inference_config.py| + vllm-ext/vllm/extension/ns/ + ) + additional_dependencies: - mypy-extensions - pydantic==1.10.0 diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 2a2a8570a..2eba3aeeb 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -50,7 +50,7 @@ import copy # (prompt str, output str, prompt len, output len, request latency, latencies list) -latency_tracking: List[Tuple[Optional[str], Optional[str], int, int, float, List[float]]] = [] +latency_tracking: List[Tuple[Optional[List[str]], Optional[str], int, int, float, List[float]]] = [] def sample_requests_ShareGPT( @@ -62,7 +62,7 @@ def sample_requests_ShareGPT( min_output_tokens_len: int, max_output_tokens_len: int, max_length: int, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[List[str], int, int]]: """ Sample requests from a dataset of ShareGPT format. @@ -94,11 +94,11 @@ def sample_requests_ShareGPT( tokenized_dataset = [] for i in range(len(dataset)): output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + tokenized_dataset.append(([prompts[i]], prompt_token_ids[i], output_len)) # Filter out too long sequences. 
- filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: + filtered_dataset: List[Tuple[List[str], int, int]] = [] + for prompts, prompt_token_ids, output_len in tokenized_dataset: prompt_len = len(prompt_token_ids) # Prune too short sequences. if (min_input_tokens_len is not None and prompt_len < min_input_tokens_len) or ( @@ -112,7 +112,7 @@ def sample_requests_ShareGPT( continue if max_length is not None and prompt_len + output_len > max_length: continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompts, prompt_len, output_len)) # Sample the requests. sampled_requests = random.sample(filtered_dataset, num_requests) @@ -126,7 +126,7 @@ def sample_requests_IPEX( max_new_tokens: int, num_requests: int, tokenizer: PreTrainedTokenizer, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[List[str], int, int]]: """ Sample requests from a dataset of IPEX format. @@ -149,10 +149,58 @@ def sample_requests_IPEX( raise ValueError(f'Invalid input_tokens to index from dataset "{dataset_path}"!') prompt_len = len(tokenizer(prompt).input_ids) + print("prompt len, ", prompt_len) output_len = prompt_len if not max_new_tokens else max_new_tokens # Duplicate prompt to generate samples - sampled_requests = [(prompt, prompt_len, output_len)] * num_requests + sampled_requests = [([prompt], prompt_len, output_len)] * num_requests + + return sampled_requests + + +def sample_requests_IDC( + dataset_path: str, + max_new_tokens: int, + num_requests: int, + tokenizer: PreTrainedTokenizer, + config: Dict[str, Union[int, float]], +) -> List[Tuple[List[str], int, int]]: + """ + Sample requests from a dataset of IDC format. + + Args: + dataset_path (str): The path to the dataset. + max_new_tokens (int): The maximum number of new tokens. + num_requests (int): The number of requests to sample. + tokenizer (PreTrainedTokenizer): The tokenizer. + config (Dict[str, Union[int, float]]): Sampling config updated in place from the dataset (max_tokens, temperature, top_p, top_k). + + Returns: + List[Tuple[List[str], int, int]]: The sampled requests, each represented as a tuple of (prompts, prompt_len, output_len). + """ + with open(dataset_path) as f: + input = json.load(f) + + if len(input["messages"]) == 2: + prompts = [input["messages"][0]["content"], input["messages"][1]["content"]] + else: + raise ValueError(f'Expected exactly 2 messages in dataset "{dataset_path}"!') + + prompt_len = sum([len(input_id) for input_id in tokenizer(prompts).input_ids]) + print("prompt len, ", prompt_len) + output_len = input["max_tokens"] if "max_tokens" in input else max_new_tokens + + # Duplicate prompt to generate samples + sampled_requests = [(prompts, prompt_len, output_len)] * num_requests + + if "max_tokens" in input: + config["max_new_tokens"] = int(input["max_tokens"]) + if "temperature" in input: + config["temperature"] = float(input["temperature"]) + if "top_p" in input: + config["top_p"] = float(input["top_p"]) + if "top_k" in input: + config["top_k"] = float(input["top_k"]) return sampled_requests @@ -165,7 +213,7 @@ def sample_requests_synthesis( output_len_mean: int, output_len_stddev: int, num_requests: int, -) -> List[Tuple[str, int, int]]: +) -> List[Tuple[List[str], int, int]]: """ Sample requests from random generated prompts. 
@@ -194,14 +242,14 @@ def gen_prompt_ids(prompt_len): # Generte random prompt from tokenizer's vocabulary prompt = tokenizer.decode(gen_prompt_ids(prompt_len), return_tensors="pt") - sampled_requests.append((prompt, prompt_len, output_len)) + sampled_requests.append(([prompt], prompt_len, output_len)) return sampled_requests async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[str], int, int]], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: +) -> AsyncGenerator[Tuple[List[str], int, int], None]: """ Asynchronously generates requests based on the input_requests and request_rate. @@ -230,7 +278,7 @@ async def get_request( async def send_request( api_url: str, model_name: str, - prompt: str, + prompts: List[str], prompt_len: int, output_len: int, config: dict, @@ -262,16 +310,27 @@ async def send_request( temp_config["max_new_tokens"] = output_len if simple: pload = { - "text": prompt, + "text": prompts[0], "config": temp_config, "stream": track_token_latency, } + if vllm_engine: + pload.update({"ignore_eos": False}) else: + if len(prompts) == 1: + messages = [ + {"role": "user", "content": f"{prompts[0]}"}, + ] + elif len(prompts) == 2: + messages = [ + {"role": "system", "content": f"{prompts[0]}"}, + {"role": "user", "content": f"{prompts[1]}"}, + ] + else: + raise ValueError("number of prompts should be 1 or 2. Actual ", len(prompts)) pload = { "model": model_name, - "messages": [ - {"role": "user", "content": f"{prompt}"}, - ], + "messages": messages, "stream": track_token_latency, "max_tokens": temp_config["max_new_tokens"] if "max_new_tokens" in temp_config @@ -280,15 +339,16 @@ async def send_request( "top_p": temp_config["top_p"] if "top_p" in temp_config else None, } if vllm_engine: - pload.update({"ignore_eos": True}) + pload.update({"ignore_eos": False}) token_latencies_per_request: List[float] = [] - timeout = aiohttp.ClientTimeout(total=3 * 3600) + timeout = aiohttp.ClientTimeout(total=5 * 3600) async with aiohttp.ClientSession(timeout=timeout) as session: while True: async with session.post(api_url, headers=headers, json=pload) as response: chunks = [] + decoded_chunks = [] start_ts = time.perf_counter() @@ -298,8 +358,11 @@ async def send_request( if track_token_latency: token_latencies_per_request.append(latency) start_ts = end_ts - chunks.append(chunk) - print(chunk.decode("utf-8") + "|", end="", flush=True) + decoded_chunk = chunk.decode("utf-8") + if decoded_chunk: + chunks.append(chunk) + decoded_chunks.append(decoded_chunk) + print(decoded_chunk + "|", end="", flush=True) print("Token Latencies:", token_latencies_per_request) # print(len(chunks), len(token_latencies_per_request)) # Decode the response @@ -325,7 +388,16 @@ async def send_request( response_content = chunks[-2].decode("utf-8") response_content = json.loads(response_content.split("data: ")[1]) generate_len = response_content["usage"]["completion_tokens"] - response_text = b"".join(chunks).decode("utf-8") + response_texts = [] + for decoded_chunk in decoded_chunks: + text = decoded_chunk.split("data: ")[1] + if text.startswith("{"): + json_text = json.loads(text) + if ( + "choices" in json_text + and "content" in json_text["choices"][0]["delta"] + ): + response_texts.append(json_text["choices"][0]["delta"]["content"]) else: response_text = b"".join(chunks).decode("utf-8") try: @@ -341,10 +413,12 @@ async def send_request( progress_bar.update() break + if args.track_token_latency: + print("response: ", "".join(response_texts)) 
request_end_time = time.perf_counter() request_latency = request_end_time - request_start_time - prompt_str = prompt if track_input_output else None + prompt_str = prompts if track_input_output else None output_str = response_text if track_input_output else None if generate_len is not None: @@ -363,7 +437,7 @@ async def send_request( async def benchmark( api_url: str, model_name: str, - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[str], int, int]], request_rate: float, config: dict, tokenizer: PreTrainedTokenizer, @@ -378,7 +452,7 @@ async def benchmark( Args: api_url (str): The URL of the API. - input_requests (List[Tuple[str, int, int]]): A list of input requests, where each request is a tuple + input_requests (List[Tuple[List[str], int, int]]): A list of input requests, where each request is a tuple containing the prompt, prompt length, and output length. request_rate (float): The rate at which requests should be sent, in requests per second. config (dict): Configuration parameters for sending requests. @@ -437,6 +511,8 @@ def main(args: argparse.Namespace): tokenizer_name_or_path, trust_remote_code=args.trust_remote_code ) + config: Dict[str, Union[int, float]] = {} + if args.dataset_format == "ShareGPT": input_requests = sample_requests_ShareGPT( args.dataset, @@ -466,10 +542,16 @@ def main(args: argparse.Namespace): args.output_len_stddev, args.num_prompts, ) + if args.dataset_format == "IDC": + input_requests = sample_requests_IDC( + args.dataset, + args.max_new_tokens, + args.num_prompts, + tokenizer, + config, + ) - config: Dict[str, Union[int, float]] = {} - - if args.max_new_tokens: + if args.max_new_tokens and "max_new_tokens" not in config: config["max_new_tokens"] = int(args.max_new_tokens) if args.temperature: config["temperature"] = float(args.temperature) @@ -480,7 +562,7 @@ def main(args: argparse.Namespace): config["do_sample"] = args.do_sample # In order to align with vllm test parameters if args.vllm_engine: - config["ignore_eos"] = True + config["ignore_eos"] = False benchmark_start_time = time.perf_counter() asyncio.run( @@ -532,16 +614,21 @@ def main(args: argparse.Namespace): ] ) print(f"Average latency per Token: {avg_per_token_latency:.3f} s") - + first_token_index = 0 if args.simple else 1 + next_token_index = 1 if args.simple else 2 if args.track_token_latency and latency_tracking: avg_first_token_latency = np.mean( - [latencies[0] for _, _, _, _, _, latencies in latency_tracking if latencies != []] + [ + latencies[first_token_index] + for _, _, _, _, _, latencies in latency_tracking + if latencies != [] + ] ) avg_next_token_latency = np.mean( [ - np.mean(latencies[1:]) + np.mean(latencies[next_token_index:]) for _, _, _, _, _, latencies in latency_tracking - if latencies[1:] != [] + if latencies[next_token_index:] != [] ] ) @@ -614,7 +701,7 @@ def main(args: argparse.Namespace): parser.add_argument( "--dataset-format", type=str, - choices=["ShareGPT", "IPEX", "Synthesis"], + choices=["ShareGPT", "IPEX", "Synthesis", "IDC"], required=True, help="Dataset format, should be one of {ShareGPT, IPEX, Synthesis}.", ) diff --git a/benchmarks/calc_stats.py b/benchmarks/calc_stats.py new file mode 100644 index 000000000..de209648c --- /dev/null +++ b/benchmarks/calc_stats.py @@ -0,0 +1,191 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys +import re +from typing import Dict, List + +if len(sys.argv) < 4: + raise ValueError( + "need arguments, file path, number of expected iterations and expected generated token length" + ) + +file_path = sys.argv[1] +nbr_iter = int(sys.argv[2]) +expected_gen_token_len = int(sys.argv[3]) + +with open(file_path) as f: + lines = f.readlines() + +PAT_NBR_USERS = re.compile(r"Run num_prompts (\d+) (.+)") +PAT_ITER = re.compile(r"Run iter (\d+)") +PAT_ACTUAL_LEN = re.compile( + r"Warning: the actual generated length is (\d+), which is different from the expected output length\((\d+)\)\." +) +PAT_TOTAL_TIME = re.compile(r"Total time: ([^ ]+) s") +PAT_PROMPT_LEN = re.compile(r"Prompt Length \(Min/Med/Max\): (\d+).+") +PAT_REQ_TPT = re.compile(r"Request Throughput \(QPS\): ([^ ]+) requests/s") +PAT_INPUT_TPT = re.compile(r"Input Token Throughput: ([^ ]+) tokens/s") +PAT_OUTPUT_TPT = re.compile(r"output Token Throughput: ([^ ]+) tokens/s") +PAT_REQ_LAT = re.compile(r"Average latency per Request: ([^ ]+) s") +PAT_TOK_LAT = re.compile(r"Average latency per Token: ([^ ]+) s") +PAT_FTOK_LAT = re.compile(r"Average latency for First Tokens: ([^ ]+) s") +PAT_NTOK_LAT = re.compile(r"Average latency for Next Tokens: ([^ ]+) s") + +nbr_users_perf: Dict[int, List[Dict[str, float]]] = {} + +token_lengths: List[int] = [] + +state = 0 +current_nbr_user = -1 +current_iter = -1 + +for no, line in enumerate(lines): + if state == 0: + m = PAT_NBR_USERS.match(line) + if m: + current_nbr_user = int(m.group(1)) + print("collecting number of users (num_prompts): ", current_nbr_user) + nbr_users_perf[current_nbr_user] = [] + state = 1 + elif state == 1: + m = PAT_ITER.match(line) + if m: + current_iter = int(m.group(1)) - 1 + nbr_users_perf[current_nbr_user].append({}) + state = 2 + elif state == 2: + m = PAT_ACTUAL_LEN.match(line) + if m: + metrics = nbr_users_perf[current_nbr_user][current_iter] + print(">>>", line, m.group(1)) + token_lengths.append(int(m.group(1))) + if expected_gen_token_len != int(m.group(2)): + raise ValueError( + "expected token lengths are not equal", expected_gen_token_len, m.group(2) + ) + else: + m = PAT_TOTAL_TIME.match(line) + if m: + metrics = nbr_users_perf[current_nbr_user][current_iter] + full_gen_lens = token_lengths + [512] * (current_nbr_user - len(token_lengths)) + metrics["ACT_GEN_TOKENS"] = float(sum(full_gen_lens)) / current_nbr_user + metrics["TOTAL_TIME"] = float(m.group(1)) + token_lengths = [] + state = 4 + elif state == 3: + m = PAT_TOTAL_TIME.match(line) + if m: + metrics["TOTAL_TIME"] = float(m.group(1)) + state = 4 + elif state == 4: + m = PAT_PROMPT_LEN.match(line) + if m: + metrics["PROMPT_LEN"] = float(m.group(1)) + state = 5 + else: + raise ValueError("invalid line: " + line, no) + elif state == 5: + m = PAT_REQ_TPT.match(line) + if m: + metrics["REQ_TPT"] = float(m.group(1)) + state = 6 + else: + raise ValueError("invalid line: " + line, no) + elif state == 6: + m = PAT_INPUT_TPT.match(line) + if m: + metrics["INPUT_TPT"] = float(m.group(1)) + state = 7 + else: + raise ValueError("invalid line: " + line, no) + elif state == 7: + m = 
PAT_OUTPUT_TPT.match(line) + if m: + metrics["OUTPUT_TPT"] = float(m.group(1)) + state = 8 + else: + raise ValueError("invalid line: " + line, no) + elif state == 8: + m = PAT_REQ_LAT.match(line) + if m: + metrics["REQ_LAT"] = float(m.group(1)) + state = 9 + else: + raise ValueError("invalid line: " + line, no) + elif state == 9: + m = PAT_TOK_LAT.match(line) + if m: + metrics["TOK_LAT"] = float(m.group(1)) + state = 10 + else: + raise ValueError("invalid line: " + line, no) + elif state == 10: + m = PAT_FTOK_LAT.match(line) + if m: + metrics["FTOK_LAT"] = float(m.group(1)) + state = 11 + else: + raise ValueError("invalid line: " + line, no) + elif state == 11: + m = PAT_NTOK_LAT.match(line) + if m: + metrics["NTOK_LAT"] = float(m.group(1)) + if current_iter == nbr_iter - 1: + state = 0 + current_iter = -1 + current_nbr_user = -1 + else: + state = 1 + current_iter = -1 + else: + raise ValueError("invalid line: " + line, no) + +if nbr_users_perf: + print(nbr_users_perf) + for k, values in nbr_users_perf.items(): + print("number of users: ", k) + size = len(values) + if size != nbr_iter: + raise ValueError( + "size should be equal to number of interations, " + + str(size) + + " != " + + str(nbr_iter) + ) + metrics = { + "ACT_GEN_TOKENS": 0.0, + "PROMPT_LEN": 0.0, + "TOTAL_TIME": 0.0, + "REQ_TPT": 0.0, + "INPUT_TPT": 0.0, + "OUTPUT_TPT": 0.0, + "REQ_LAT": 0.0, + "TOK_LAT": 0.0, + "FTOK_LAT": 0.0, + "NTOK_LAT": 0.0, + } + for v in values: + for kk in metrics: + metrics[kk] += v[kk] + for kk, vv in metrics.items(): + metrics[kk] = vv / size + print(metrics) + print("=========================================") + + +else: + raise ValueError("Failed to collect metrics") diff --git a/benchmarks/run_benchmark.sh b/benchmarks/run_benchmark.sh index 3c7118dd7..f2a2ee507 100644 --- a/benchmarks/run_benchmark.sh +++ b/benchmarks/run_benchmark.sh @@ -2,25 +2,46 @@ set -eo pipefail CHOICE=${1} -RUN_MODE=${2} # "test" or "benchmark", where "test" will only use a small part of the dataset +RUN_MODE=${2} # "test", "benchmark" or "benchmark_idc", where "test" will only use a small part of the dataset if [ -z "$CHOICE" ] then echo "Please pass in the value of parameter CHOICE, which can be any subset of 1,2,3,4." fi if [ -z "$RUN_MODE" ] then - echo "Please pass in the value of parameter RUN_MODE, which can be 'test' or 'benchmark'." + echo "Please pass in the value of parameter RUN_MODE, which can be 'test', 'benchmark' or 'benchmark_idc'." 
fi VALUE_INF=2000 -MODEL_ENDPOINT="http://localhost:8000/llama-2-7b-chat-hf" -MODEL_NAME="llama-2-7b-chat-hf" +MAX_NUM_SEQS=$VALUE_INF +DYNAMIC_BATCH_SIZE=0 +PROMPT_TYPE="simple" +if [ "$#" -gt 2 ] +then + MAX_NUM_SEQS=${3} +fi +if [ "$#" -gt 3 ] +then + DYNAMIC_BATCH_SIZE=${4} +fi +if [ "$#" -gt 4 ] +then + PROMPT_TYPE=${5} +fi + +# for simple +#MODEL_ENDPOINT="http://localhost:8000/llama-3-8b-instruct" +# for openai +MODEL_ENDPOINT="http://localhost:8000" +MODEL_NAME="llama-3-8b-instruct" SHELL_FOLDER=$(cd "$(dirname "$0")";pwd) BENCHMARK_SCRIPT=$SHELL_FOLDER"/benchmark_serving.py" -WITH_VLLM_CONFIG_FILE=$SHELL_FOLDER"/../llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm.yaml" -WO_VLLM_CONFIG_FILE=$SHELL_FOLDER"/../llm_on_ray/inference/models/llama-2-7b-chat-hf.yaml" +#WITH_VLLM_CONFIG_FILE=$SHELL_FOLDER"/../llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm-ns.yaml" +WITH_VLLM_CONFIG_FILE=$SHELL_FOLDER"/../llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm.yaml" +WO_VLLM_CONFIG_FILE=$SHELL_FOLDER"/../llm_on_ray/inference/models/llama-3-8b-instruct.yaml" DATASET_PATH=$SHELL_FOLDER"/../dataset" DATASET_SHAREGPT_PATH=$SHELL_FOLDER"/../dataset/ShareGPT_V3_unfiltered_cleaned_split.json" DATASET_IPEX_PATH=$SHELL_FOLDER"/../dataset/prompt.json" +DATASET_IDC_PATH_PREFIX=$SHELL_FOLDER"/../dataset/idc_prompt_" DATASET_BENCHMARK_NUM=1000 DATASET_COMPARE_NUM=128 NUMA_SERVER_COMMAND="" @@ -39,7 +60,7 @@ if [ $RUN_MODE = "test" ] then SAVE_DIR=$SHELL_FOLDER"/results_test" NUMA_CLIENT_COMMAND="" -elif [ $RUN_MODE = "benchmark" ] +elif [ $RUN_MODE = "benchmark" ] || [ $RUN_MODE = "benchmark_idc" ] then SAVE_DIR=$SHELL_FOLDER"/results" NUMA_CLIENT_COMMAND="numactl -N 1 -m 1" @@ -107,19 +128,84 @@ latency_throughput(){ tokens_dir=$choice_dir"/tokens_"$input_tokens_length"_"$output_tokens_length # server - $NUMA_SERVER_COMMAND llm_on_ray-serve --config_file $WITH_VLLM_CONFIG_FILE --simple --max_ongoing_requests $VALUE_INF --max_num_seqs $VALUE_INF + #$numa_server_command llm_on_ray-serve --config_file $with_vllm_config_file --simple --max_concurrent_queries $VALUE_INF --vllm_max_num_seqs $VALUE_INF # client - for i in $(seq 1 $num_iter) + for num_prompts in ${query_num} do - echo "Run iter $i" - iter_dir=$tokens_dir"/iter_"$i - for num_prompts in ${query_num} + max_con_q=$VALUE_INF + if [ ! "$DYNAMIC_BATCH_SIZE" = "0" ] + then + if [ "$num_prompts" -lt "$NUM_REPLICA" ] || [ "$num_prompts" -eq "$NUM_REPLICA" ] + then + max_con_q=1 + else + max_con_q=$((num_prompts/NUM_REPLICA)) + fi + fi + echo "Run num_prompts ${num_prompts} =======================" + echo "deploying model with --max_ongoing_requests $max_con_q --max_num_seqs $MAX_NUM_SEQS ..." 
+ $NUMA_SERVER_COMMAND llm_on_ray-serve --config_file $WITH_VLLM_CONFIG_FILE --simple --max_ongoing_requests $max_con_q --max_num_seqs $MAX_NUM_SEQS + sleep 1 + for i in $(seq 0 $num_iter) do + if [ $i = 0 ]; then + iter_dir="$tokens_dir/warmup" + echo "Run warmup" + else + iter_dir=$tokens_dir"/iter_"$i + echo "Run iter $i" + fi results_dir=$iter_dir"/num_prompts_"$num_prompts - echo "Run num_prompts ${num_prompts}" echo "results_dir: ${results_dir}" - $NUMA_CLIENT_COMMAND python $BENCHMARK_SCRIPT --model-endpoint-base $MODEL_ENDPOINT --model-name $MODEL_NAME --dataset $DATASET_IPEX_PATH --num-prompts $num_prompts --dataset-format IPEX --input-tokens $input_tokens_length --max-new-tokens $output_tokens_length --track-token-latency --vllm-engine --simple --results-dir $results_dir + $NUMA_CLIENT_COMMAND python $BENCHMARK_SCRIPT --model-endpoint-base $MODEL_ENDPOINT --model-name $MODEL_NAME --dataset $DATASET_IPEX_PATH --num-prompts $num_prompts --dataset-format IPEX --input-tokens $input_tokens_length --track-token-latency --max-new-tokens $output_tokens_length --vllm-engine --simple --results-dir $results_dir + done + done + echo "CHOICE 3 generation completed" +} + +latency_throughput_idc(){ + echo "get performance results of llm-on-ray with vllm when responding to IDC requests" + num_iter=${1} + query_num=${2} + output_tokens_length=${3} + choice_dir=${4} + prompt_type=${5} + tokens_dir=$choice_dir"/tokens_"${prompt_type}"_"$output_tokens_length + + # server + #$numa_server_command llm_on_ray-serve --config_file $with_vllm_config_file --simple --max_concurrent_queries $VALUE_INF --vllm_max_num_seqs $VALUE_INF + + # client + for num_prompts in ${query_num} + do + max_con_q=$VALUE_INF + if [ ! "$DYNAMIC_BATCH_SIZE" = "0" ] + then + if [ "$num_prompts" -lt "$NUM_REPLICA" ] || [ "$num_prompts" -eq "$NUM_REPLICA" ] + then + max_con_q=1 + else + max_con_q=$((num_prompts/NUM_REPLICA)) + fi + fi + echo "Run num_prompts ${num_prompts} =======================" + #echo "deploying model with --max_concurrent_queries $max_con_q --vllm_max_num_seqs $MAX_NUM_SEQS ..." + echo "deploying model with --max_ongoing_requests 4 --max_num_seqs $MAX_NUM_SEQS ..." 
+ $NUMA_SERVER_COMMAND llm_on_ray-serve --config_file $WITH_VLLM_CONFIG_FILE --max_ongoing_requests 4 --max_num_seqs $MAX_NUM_SEQS + sleep 1 + for i in $(seq 0 $num_iter) + do + if [ $i = 0 ]; then + iter_dir="$tokens_dir/warmup" + echo "Run warmup" + else + iter_dir=$tokens_dir"/iter_"$i + echo "Run iter $i" + fi + results_dir=$iter_dir"/num_prompts_"$num_prompts + echo "results_dir: ${results_dir}" + $NUMA_CLIENT_COMMAND python $BENCHMARK_SCRIPT --model-endpoint-base $MODEL_ENDPOINT --model-name $MODEL_NAME --dataset "${DATASET_IDC_PATH_PREFIX}${prompt_type}.json" --num-prompts $num_prompts --dataset-format IDC --track-token-latency --max-new-tokens $output_tokens_length --vllm-engine --results-dir $results_dir done done echo "CHOICE 3 generation completed" @@ -196,9 +282,9 @@ then for i in "${!concurrent_query_num[@]}"; do concurrent_query_num[$i]=$[${concurrent_query_num[$i]}*$NUM_REPLICA] done - # 32/64 + # 32/128 input_tokens_length=32 - output_tokens_length=64 + output_tokens_length=128 latency_throughput $iter "${concurrent_query_num[*]}" $input_tokens_length $output_tokens_length $benchmark_dir # 1024/128 input_tokens_length=1024 @@ -211,6 +297,15 @@ then input_tokens_length=32 output_tokens_length=20 latency_throughput $iter "${concurrent_query_num[*]}" $input_tokens_length $output_tokens_length $benchmark_dir + elif [ "$RUN_MODE" == "benchmark_idc" ] + then + iter=3 + concurrent_query_num=(1 3 6 9 12) + #for i in "${!concurrent_query_num[@]}"; do + # concurrent_query_num[$i]=$[${concurrent_query_num[$i]}*$NUM_REPLICA] + #done + output_tokens_length=512 + latency_throughput_idc $iter "${concurrent_query_num[*]}" $output_tokens_length $benchmark_dir $PROMPT_TYPE fi fi if [[ "$CHOICE" == *"4"* ]] @@ -229,4 +324,5 @@ then fi output_tokens_length=32 get_best_latency $iter "${input_tokens_length[*]}" $output_tokens_length $benchmark_dir -fi \ No newline at end of file +fi + diff --git a/dev/docker/Dockerfile.vllm_ns b/dev/docker/Dockerfile.vllm_ns new file mode 100644 index 000000000..ac191b7d4 --- /dev/null +++ b/dev/docker/Dockerfile.vllm_ns @@ -0,0 +1,42 @@ +# syntax=docker/dockerfile:1 +FROM ubuntu:22.04 + +ENV LANG C.UTF-8 + +WORKDIR /root/llm-on-ray + +RUN --mount=type=cache,target=/var/cache/apt apt-get update -y \ + && apt-get install -y build-essential cmake wget curl git vim htop ssh net-tools \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +ENV CONDA_DIR /opt/conda +RUN wget --quiet https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \ + /bin/bash ~/miniforge.sh -b -p /opt/conda +ENV PATH $CONDA_DIR/bin:$PATH + +# setup env +SHELL ["/bin/bash", "--login", "-c"] + +RUN --mount=type=cache,target=/opt/conda/pkgs conda init bash && \ + unset -f conda && \ + export PATH=$CONDA_DIR/bin/:${PATH} && \ + mamba config --add channels intel && \ + mamba install -y -c conda-forge python==3.9 gxx=12.3 gxx_linux-64=12.3 libxcrypt + +COPY ./pyproject.toml . +COPY ./MANIFEST.in . 
+ + +# Install llm_on_ray +# Create llm_on_ray package directory to bypass the following 'pip install -e' command +RUN mkdir ./llm_on_ray +RUN --mount=type=cache,target=/root/.cache/pip pip install -e .[vllm-cpu] --extra-index-url https://download.pytorch.org/whl/cpu \ + --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ + +# Install vllm-ext +# We cannot make empty folder here like './llm_on_ray' since vllm-ext has cpp files to be compiled +COPY ./vllm-ext ./vllm-ext +COPY ./dev/scripts/check-vllm-cpu-build-env.sh ./dev/scripts/check-vllm-cpu-build-env.sh +RUN --mount=type=cache,target=/root/.cache/pip \ + source /opt/conda/bin/activate base && cd vllm-ext && pip install . && pip install --upgrade protobuf diff --git a/dev/scripts/check-vllm-cpu-build-env.sh b/dev/scripts/check-vllm-cpu-build-env.sh new file mode 100755 index 000000000..55db8a29f --- /dev/null +++ b/dev/scripts/check-vllm-cpu-build-env.sh @@ -0,0 +1,17 @@ +#!/usr/bin/env bash + +# Check tools +[[ -n $(which g++) ]] || { echo "GNU C++ Compiler (g++) is not found!"; exit 1; } +[[ -n $(which pip) ]] || { echo "pip command is not found!"; exit 1; } + +# g++ version should be >=12.3. You can run the following to install GCC 12.3 and dependencies on conda: +# conda install -y -c conda-forge gxx=12.3 gxx_linux-64=12.3 libxcrypt +version_greater_equal() +{ + printf '%s\n%s\n' "$2" "$1" | sort --check=quiet --version-sort +} +gcc_version=$(g++ --version | grep -o -E '[0-9]+\.[0-9]+\.[0-9]+' | head -n1) +echo +echo Current GNU C++ Compiler version: $gcc_version +echo +version_greater_equal "${gcc_version}" 12.3.0 || { echo "GNU C++ Compiler 12.3.0 or above is required!"; exit 1; } diff --git a/dev/scripts/ci-functions.sh b/dev/scripts/ci-functions.sh index 738154fd8..6e41eb666 100644 --- a/dev/scripts/ci-functions.sh +++ b/dev/scripts/ci-functions.sh @@ -64,7 +64,7 @@ start_docker() { docker_args+=("-e=https_proxy=${HTTPS_PROXY}") fi - echo "docker run -tid "${docker_args[@]}" "${TARGET}:latest"" + echo "docker run -tid --privileged "${docker_args[@]}" "${TARGET}:latest"" docker run -tid "${docker_args[@]}" "${TARGET}:latest" } @@ -75,11 +75,19 @@ install_dependencies(){ docker exec "${TARGET}" bash -c "pip install -r ./tests/requirements.txt" } -strat_ray(){ +start_ray(){ local TARGET=$1 + local UNLIMITED_MAXLOCKMEM=0 + if [ "$2" == "1" ]; then + UNLIMITED_MAXLOCKMEM=1 + fi # Start Ray Cluster - docker exec "${TARGET}" bash -c "./dev/scripts/start-ray-cluster.sh" + if [ "$UNLIMITED_MAXLOCKMEM" == "1" ]; then + docker exec "${TARGET}" bash -c "ulimit -l unlimited; ./dev/scripts/start-ray-cluster.sh" + else + docker exec "${TARGET}" bash -c "./dev/scripts/start-ray-cluster.sh" + fi } stop_ray(){ @@ -111,6 +119,7 @@ declare -A DF_SUFFIX_MAPPER DF_SUFFIX_MAPPER=( ["mpt-7b-ipex-llm"]=".ipex-llm" ["llama-2-7b-chat-hf-vllm"]=".vllm" + ["llama-2-7b-chat-hf-vllm-ns"]=".vllm_ns" ["gpt-j-6b"]=".cpu_and_deepspeed.pip_non_editable" ) @@ -128,6 +137,7 @@ declare -A TARGET_SUFFIX_MAPPER TARGET_SUFFIX_MAPPER=( ["mpt-7b-ipex-llm"]="_ipex-llm" ["llama-2-7b-chat-hf-vllm"]="_vllm" + ["llama-2-7b-chat-hf-vllm-ns"]="_vllm-ns" ) get_TARGET_SUFFIX() { @@ -143,6 +153,7 @@ declare -A INFERENCE_MAPPER INFERENCE_MAPPER=( ["mpt-7b-ipex-llm"]="llm_on_ray-serve --config_file llm_on_ray/inference/models/ipex-llm/mpt-7b-ipex-llm.yaml --simple" ["llama-2-7b-chat-hf-vllm"]="llm_on_ray-serve --config_file .github/workflows/config/llama-2-7b-chat-hf-vllm-fp32.yaml --simple" + ["llama-2-7b-chat-hf-vllm-ns"]="llm_on_ray-serve --config_file 
llm_on_ray/inference/models/vllm/llama2-7b-chat-hf-vllm-ns.yaml --simple --max_ongoing_requests 1 --max_num_seqs 1" ["default"]="llm_on_ray-serve --simple --models ${model}" ) diff --git a/docs/assets/choice3_tokens_32_64.png b/docs/assets/choice3_tokens_32_64.png index 5e86a5c8f..e2f549a11 100644 Binary files a/docs/assets/choice3_tokens_32_64.png and b/docs/assets/choice3_tokens_32_64.png differ diff --git a/docs/vllm.md b/docs/vllm.md index 90e12a032..604c8c2a7 100644 --- a/docs/vllm.md +++ b/docs/vllm.md @@ -24,10 +24,32 @@ Then please run the following script to install vLLM for CPU into your LLM-on-Ra dev/scripts/install-vllm-cpu.sh ``` +## Install vLLM Extension for Quantization (Optional) +To further speed up quantized model inference on Intel CPUs, we extend vLLM to run the model decoding in our own inference engine, which is based on [neural-speed](https://github.com/intel/neural-speed). +The extension delivers a 2X performance improvement over plain vLLM on CPU for the Llama-2-7b and Llama-3-8b models. + +Neural Speed is an innovative library designed to support the efficient inference of large language models (LLMs) on Intel platforms through the state-of-the-art (SOTA) low-bit quantization powered by +[Intel Neural Compressor](https://github.com/intel/neural-compressor). The work is inspired by [llama.cpp](https://github.com/ggerganov/llama.cpp) and further optimized for Intel platforms with our +innovations in [NeurIPS' 2023](https://arxiv.org/abs/2311.00502). + +You need to first install llm-on-ray with the "vllm-cpu" extra. + +```bash +pip install .[vllm-cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ +``` + +Then, install the vLLM extension and the inference engine. +```bash +cd vllm-ext +pip install . && pip install --upgrade protobuf + +``` + ## Run #### Serving +* Vanilla vLLM To serve model with vLLM and simple protocol, run the following: ```bash @@ -36,6 +58,30 @@ llm_on_ray-serve --config_file llm_on_ray/inference/models/vllm/llama-2-7b-chat- In the above example, `vllm` property is set to `true` in the config file for enabling vLLM. +* vLLM Extension +To serve a model with the vLLM extension and the Intel inference engine, run the following (Note: only Llama-2-7b-chat-hf and Llama-3-8b-instruct are supported for now): + +```bash +# copy the quantization config file to your specific snapshot dir, for example .../snapshots/f5db02db7.../ +# if you don't copy your desired config manually, the default quant_ns_config.json from the llm_on_ray package will be copied instead. +cp llm_on_ray/inference/models/vllm/quantization/quant_ns_config.json <your_model_snapshot_dir>/ + +# start up ray head node and worker node with proper configurations +# '--num-cpus 0' ensures no task or actor runs on the head node +numactl -N 0 -m 0 ray start --head --num-cpus 0 --include-dashboard=False --disable-usage-stats --node-ip-address +# adjust the values of NS_NUM_THREADS and the numa core binding for your machine while keeping OMP_PROC_BIND and OMP_NUM_THREADS unchanged. The example below is for a standard SPR machine. +# typically the number of bound physical cores should be equal to the value of --num-cpus. 
and NS_NUM_THREADS should be equal to the number of bound physical cores - 2 +# for deploying multiple replicas, check examples/inference/vllm/run_multiple_ray_node.sh +OMP_PROC_BIND=true NS_NUM_THREADS=44 OMP_NUM_THREADS=1 numactl -N 0 -C 0-45 -m 0 ray start --address=':6379' --num-cpus 46 --resources='{"inference_engine": 1}' +# start up one more ray worker node for an OpenAI-compatible deployment +numactl -N 1 -C 110-111 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 1 --resources='{"app_router": 1}' + +# deploy model serving. Note: this includes quantizing the model on the fly based on quant_ns_config.json if it has not been quantized yet. +llm_on_ray-serve --config_file llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-ns.yaml --keep_serve_terminal --max_ongoing_requests 4 --max_num_seqs 4 +``` + +For now, only Llama-2-7b-chat-hf and Llama-3-8b-instruct are supported, but support for other models can be added quickly. + #### Querying To start a non-streaming query, run the following: @@ -52,4 +98,4 @@ python examples/inference/api_server_simple/query_single.py --model_endpoint htt ## Further Configuration -Please follow [Deploying and Serving LLMs on Intel CPU/GPU/Gaudi](serve.md) document to for other configurations. \ No newline at end of file +Please follow the [Deploying and Serving LLMs on Intel CPU/GPU/Gaudi](serve.md) document for other configurations. diff --git a/examples/inference/vllm/api_server.py b/examples/inference/vllm/api_server.py new file mode 100644 index 000000000..71021f9a3 --- /dev/null +++ b/examples/inference/vllm/api_server.py @@ -0,0 +1,146 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine +and simple performance benchmarks. It is not intended for production use. +For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please +change `vllm/entrypoints/openai/api_server.py` instead. +""" + +import argparse +import json +import ssl +from typing import AsyncGenerator + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +from vllm.extension import ns + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.usage.usage_lib import UsageContext +from vllm.utils import random_uuid + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. 
+ - other fields: the sampling parameters (See `SamplingParams` for details). + """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + assert engine is not None + results_generator = engine.generate(prompt, sampling_params, request_id) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [prompt + output.text for output in request_output.outputs] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return Response(status_code=499) + # print(request_output) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +def _modify_qunatization_choices(parser, dest, choices): + for action in parser._actions: + if action.dest == dest: + action.choices = choices + return + else: + raise ValueError("argument {} not found".format(dest)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8070) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument("--ssl-ca-certs", type=str, default=None, help="The CA certificates file") + parser.add_argument( + "--ssl-cert-reqs", + type=int, + default=int(ssl.CERT_NONE), + help="Whether client certificate is required (see stdlib ssl module's)", + ) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy", + ) + parser.add_argument("--log-level", type=str, default="debug") + parser = AsyncEngineArgs.add_cli_args(parser) + _modify_qunatization_choices(parser, "quantization", ("awq", "gptq", "squeezellm", "ns", None)) + _modify_qunatization_choices(parser, "block_size", None) + args = parser.parse_args() + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args, usage_context=UsageContext.API_SERVER) + + app.root_path = args.root_path + uvicorn.run( + app, + host=args.host, + port=args.port, + log_level=args.log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + ) diff --git a/examples/inference/vllm/ray-vllm-examples/llm.py b/examples/inference/vllm/ray-vllm-examples/llm.py new file mode 100644 index 000000000..6a20be65e --- /dev/null +++ b/examples/inference/vllm/ray-vllm-examples/llm.py @@ -0,0 +1,152 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Dict, Optional, List +import logging + +from fastapi import FastAPI +from starlette.requests import Request +from starlette.responses import StreamingResponse, JSONResponse + +from ray import serve + +from vllm.extension import ns + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, + ChatCompletionResponse, + ErrorResponse, +) +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_engine import LoRAModulePath + +logger = logging.getLogger("ray.serve") + +app = FastAPI() + + +@serve.deployment( + autoscaling_config={ + "min_replicas": 1, + "max_replicas": 10, + "target_ongoing_requests": 5, + }, + max_ongoing_requests=10, +) +@serve.ingress(app) +class VLLMDeployment: + def __init__( + self, + engine_args: AsyncEngineArgs, + response_role: str, + lora_modules: Optional[List[LoRAModulePath]] = None, + chat_template: Optional[str] = None, + ): + from vllm.extension import ns + + logger.info(f"Starting with engine args: {engine_args}") + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + # Determine the name of the served model for the OpenAI client. + if engine_args.served_model_name is not None: + served_model_names = engine_args.served_model_name + else: + served_model_names = [engine_args.model] + self.openai_serving_chat = OpenAIServingChat( + self.engine, + self.engine.engine.model_config, + served_model_names, + response_role, + lora_modules, + chat_template, + ) + + @app.post("/v1/chat/completions") + async def create_chat_completion(self, request: ChatCompletionRequest, raw_request: Request): + """OpenAI-compatible HTTP endpoint. + + API reference: + - https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html + """ + logger.info(f"Request: {request}") + generator = await self.openai_serving_chat.create_chat_completion(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, media_type="text/event-stream") + else: + assert isinstance(generator, ChatCompletionResponse) + return JSONResponse(content=generator.model_dump()) + + +def _modify_qunatization_choices(parser, dest, choices): + for action in parser._actions: + if action.dest == dest: + action.choices = choices + return + else: + raise ValueError("argument {} not found".format(dest)) + + +def parse_vllm_args(cli_args: Dict[str, str]): + """Parses vLLM args based on CLI inputs. + + Currently uses argparse because vLLM doesn't expose Python models for all of the + config options we want to support. 
+ """ + parser = make_arg_parser() + _modify_qunatization_choices(parser, "quantization", ("awq", "gptq", "squeezellm", "ns", None)) + _modify_qunatization_choices(parser, "block_size", None) + arg_strings = [] + for key, value in cli_args.items(): + arg_strings.extend([f"--{key}", str(value)]) + logger.info(arg_strings) + parsed_args = parser.parse_args(args=arg_strings) + return parsed_args + + +def build_app(cli_args: Dict[str, str]) -> serve.Application: + """Builds the Serve app based on CLI arguments. + + See https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server + for the complete set of arguments. + + Supported engine arguments: https://docs.vllm.ai/en/latest/models/engine_args.html. + """ # noqa: E501 + parsed_args = parse_vllm_args(cli_args) + engine_args = AsyncEngineArgs.from_cli_args(parsed_args) + # engine_args.worker_use_ray = True + + # tp = engine_args.tensor_parallel_size + # logger.info(f"Tensor parallelism = {tp}") + pg_resources = [] + pg_resources.append({"CPU": 1}) # for the deployment replica + # for i in range(tp): + # pg_resources.append({"CPU": 1, "GPU": 1}) # for the vLLM actors + + # We use the "STRICT_PACK" strategy below to ensure all vLLM actors are placed on + # the same Ray node. + return VLLMDeployment.options( + placement_group_bundles=pg_resources, placement_group_strategy="STRICT_PACK" + ).bind( + engine_args, + parsed_args.response_role, + parsed_args.lora_modules, + parsed_args.chat_template, + ) diff --git a/examples/inference/vllm/ray-vllm-examples/query.py b/examples/inference/vllm/ray-vllm-examples/query.py new file mode 100644 index 000000000..012725edb --- /dev/null +++ b/examples/inference/vllm/ray-vllm-examples/query.py @@ -0,0 +1,41 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from openai import OpenAI + +# Note: Ray Serve doesn't support all OpenAI client arguments and may ignore some. +client = OpenAI( + # Replace the URL if deploying your app remotely + # (e.g., on Anyscale or KubeRay). 
+ base_url="http://localhost:8000/v1", + api_key="NOT A REAL KEY", +) +chat_completion = client.chat.completions.create( + model="meta-llama/Llama-2-7b-chat-hf", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "What are some highly rated restaurants in San Francisco?'", + }, + ], + temperature=0.01, + stream=True, +) + +for chat in chat_completion: + if chat.choices[0].delta.content is not None: + print(chat.choices[0].delta.content, end="") diff --git a/examples/inference/vllm/run_multiple_api_server.sh b/examples/inference/vllm/run_multiple_api_server.sh new file mode 100644 index 000000000..5a65ded6b --- /dev/null +++ b/examples/inference/vllm/run_multiple_api_server.sh @@ -0,0 +1,12 @@ +NS_MODEL_PERF_STEPS=50 OMP_PROC_BIND=true NS_NUM_THREADS=24 OMP_NUM_THREADS=1 numactl -N 0 -m 0 -C 0-26 nohup python api_server.py --model meta-llama/Llama-2-7b-chat-hf --quantization ns --device cpu --max-num-seqs 64 --block-size 4096 --max-model-len 4096 --host 10.0.11.5 --port 8070 > s8070.log 2>&1 & +sleep 2 + +NS_MODEL_PERF_STEPS=50 OMP_PROC_BIND=true NS_NUM_THREADS=24 OMP_NUM_THREADS=1 numactl -N 0 -m 0 -C 28-54 nohup python api_server.py --model meta-llama/Llama-2-7b-chat-hf --quantization ns --device cpu --max-num-seqs 64 --block-size 4096 --max-model-len 4096 --host 10.0.11.5 --port 8071 > s8071.log 2>&1 & +sleep 2 + +NS_MODEL_PERF_STEPS=50 OMP_PROC_BIND=true NS_NUM_THREADS=24 OMP_NUM_THREADS=1 numactl -N 1 -m 1 -C 56-82 nohup python api_server.py --model meta-llama/Llama-2-7b-chat-hf --quantization ns --device cpu --max-num-seqs 64 --block-size 4096 --max-model-len 4096 --host 10.0.11.5 --port 8072 > s8072.log 2>&1 & +sleep 2 + +NS_MODEL_PERF_STEPS=50 OMP_PROC_BIND=true NS_NUM_THREADS=24 OMP_NUM_THREADS=1 numactl -N 1 -m 1 -C 84-110 nohup python api_server.py --model meta-llama/Llama-2-7b-chat-hf --quantization ns --device cpu --max-num-seqs 64 --block-size 4096 --max-model-len 4096 --host 10.0.11.5 --port 8073 > s8073.log 2>&1 & +sleep 2 + diff --git a/examples/inference/vllm/run_multiple_offline.sh b/examples/inference/vllm/run_multiple_offline.sh new file mode 100644 index 000000000..194d83ccb --- /dev/null +++ b/examples/inference/vllm/run_multiple_offline.sh @@ -0,0 +1,7 @@ +# two sockets +#NS_NUM_THREADS=25 OMP_NUM_THREADS=1 OMP_PROC_BIND=true NS_MODEL_PERF_STEPS=50 nohup numactl -N 0 -m 0 -C 0-27 python vllm_offline_inference.py > 2.log 2>&1 & +#NS_NUM_THREADS=25 OMP_NUM_THREADS=1 OMP_PROC_BIND=true NS_MODEL_PERF_STEPS=50 nohup numactl -N 1 -m 1 -C 56-83 python vllm_offline_inference.py > 1.log 2>&1 & + +TOKENIZERS_PARALLELISM=false NS_NUM_THREADS=25 OMP_NUM_THREADS=1 OMP_PROC_BIND=true NS_MODEL_PERF_STEPS=50 nohup numactl -N 0 -m 0 -C 0-26 python vllm_offline_inference.py > 2.log 2>&1 & +TOKENIZERS_PARALLELISM=false NS_NUM_THREADS=25 OMP_NUM_THREADS=1 OMP_PROC_BIND=true NS_MODEL_PERF_STEPS=50 nohup numactl -N 0 -m 0 -C 27-53 python vllm_offline_inference.py > 1.log 2>&1 & + diff --git a/examples/inference/vllm/run_multiple_ray_node.sh b/examples/inference/vllm/run_multiple_ray_node.sh new file mode 100644 index 000000000..056c007a4 --- /dev/null +++ b/examples/inference/vllm/run_multiple_ray_node.sh @@ -0,0 +1,13 @@ +# NS_MODEL_PERF_STEPS=50 +OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 1 -C 83-109 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 27 --resources='{"inference_engine": 1}' +sleep 1 +OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 0 -C 0-26 -m 0 ray start 
--address='10.0.11.8:6379' --num-cpus 27 --resources='{"inference_engine": 1}' +sleep 1 +OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 1 -C 56-82 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 27 --resources='{"inference_engine": 1}' +sleep 1 +OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 0 -C 27-53 -m 0 ray start --address='10.0.11.8:6379' --num-cpus 27 --resources='{"inference_engine": 1}' +sleep 1 +numactl -N 1 -C 110-111 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 1 --resources='{"app_router": 1}' + +sleep 1 +echo 'done' diff --git a/examples/inference/vllm/run_multiple_ray_node_plain_cpu.sh b/examples/inference/vllm/run_multiple_ray_node_plain_cpu.sh new file mode 100644 index 000000000..69b320238 --- /dev/null +++ b/examples/inference/vllm/run_multiple_ray_node_plain_cpu.sh @@ -0,0 +1,30 @@ +# NS_MODEL_PERF_STEPS=50 +#OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 1 -C 83-109 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 1 +#sleep 1 +#OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 0 -C 0-26 -m 0 ray start --address='10.0.11.8:6379' --num-cpus 1 +#sleep 1 +#OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 1 -C 56-82 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 1 +#sleep 1 +#OMP_PROC_BIND=true NS_NUM_THREADS=25 OMP_NUM_THREADS=1 numactl -N 0 -C 27-53 -m 0 ray start --address='10.0.11.8:6379' --num-cpus 1 + +# two instances two sockets +#OMP_PROC_BIND=true NS_NUM_THREADS=44 OMP_NUM_THREADS=1 numactl -N 0 -C 0-45 -m 0 ray start --address='10.0.11.8:6379' --num-cpus 1 +#sleep 1 +#OMP_PROC_BIND=true NS_NUM_THREADS=44 OMP_NUM_THREADS=1 numactl -N 1 -C 56-101 -m 1 ray start --address='10.0.11.8:6379' --num-cpus 1 + +#LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 0 -m 0 -C 0-55 ray start --address='10.0.11.8:6379' --num-cpus 56 +#sleep 1 +#LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 1 -m 1 -C 56-111 ray start --address='10.0.11.8:6379' --num-cpus 56 + +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 0 -m 0 -C 0-27 ray start --address='10.0.11.8:6379' --num-cpus 27 +sleep 1 +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 0 -m 0 -C 28-55 ray start --address='10.0.11.8:6379' --num-cpus 27 +sleep 1 +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 1 -m 1 -C 56-83 ray start --address='10.0.11.8:6379' --num-cpus 27 +sleep 1 +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" numactl -N 1 -m 1 -C 84-111 ray start --address='10.0.11.8:6379' --num-cpus 27 +sleep 1 +LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4" ray start --address='10.0.11.8:6379' --num-cpus 1 --resources='{"app_router": 1}' + +sleep 1 +echo 'done' diff --git a/examples/inference/vllm/stop_multiple_api_server.sh b/examples/inference/vllm/stop_multiple_api_server.sh new file mode 100644 index 000000000..60a0cbe00 --- /dev/null +++ b/examples/inference/vllm/stop_multiple_api_server.sh @@ -0,0 +1 @@ +ps -ef | grep api_server | grep -v grep | awk '{ print $2 }' | xargs kill diff --git a/examples/inference/vllm/vllm_offline_inference.py b/examples/inference/vllm/vllm_offline_inference.py new file mode 100644 index 000000000..d8a0bff17 --- /dev/null +++ b/examples/inference/vllm/vllm_offline_inference.py @@ -0,0 +1,166 @@ +# +# Copyright 2023 The LLM-on-Ray Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from vllm import LLM, SamplingParams + +from vllm.extension import ns as ns + +from time import perf_counter + + +############################ +# TODO: TODO: TODO: reset request_id and seq_id before they reach to max value of int64_t +############################ + +# Sample prompts. +# prompts = [ +# "Hello, my name is", +# "The president of the United States is", +# "The capital of France is", +# "The future of AI is", +# "Hello, my name is", +# "The president of the United States is", +# "The capital of France is", +# "The future of AI is", +# "Hello, my name is", +# "The president of the United States is", +# "The capital of France is", +# "The future of AI is", +# ] +prompts = [ + "Russell Brunson's Perfect Webinar Script is a framework for delivering webinars that are designed to sell a product or service. ", + "Tony Robbins describes six core human needs that drive our behaviors and motivations. These six needs are:\n\n1. Certainty: The need for safety, stability, and predictability.", + "1. Homogeneity: The segment should consist of customers who share similar characteristics and behaviors.\n2. Distinctiveness: The segment should be different from other segments in terms of their characteristics and behaviors.\n3. Stability: The segment should remain relatively stable over time and not change drastically. The characteristics and behaviors of customers within the segment should not change significantly.", + 'In Java, I want to replace string like "This is a new {object} at {place}" with a Map, {object: "student", "point 3, 4"}, and get a result "This is a new student at point 3, 4". How can I do?', + 'You can use the `String.format()` method in Java to replace placeholders in a string with values from a map. Here\'s an example code snippet that demonstrates how you can achieve this:\n```java\nimport java.util.HashMap;\nimport java.util.Map;\n\npublic class StringReplaceExample {\n public static void main(String[] args) {\n String input = "This is a new {object} at {place}";\n Map replacements = new HashMap<>();\n replacements.put("object", "student");\n replacements.put("place", "point 3, 4");\n\n ', + 'The language used to describe the addressing modes of these instructions is metaphorical and grandiose, emphasizing the complexity and power of these commands. For example, the use of "enigmatic" and "confounding" to describe JMP ABCD and MOV AX, [BX+SI], respectively, suggests that these instructions are not easily understood and require a level of expertise to comprehend.\n\nSimilarly, the use of "inscrutable" and "cryptic" to describe MOV AX, [100] and MOV AX, [BX], respectively, implies that these commands are shrouded in mystery and are difficult to decipher. The speaker\'s use of "perplexing" and "unfathomable" to describe MOV AX, [BX\\*2+SI] and MOV AX, BX, respectively, ', + "Lo and behold! The arcane and elusive art of metaphorical language has been summoned forth to expound upon the enigmatic addressing modes of the instructions at hand. 
The speakers have wielded grandiose expressions with utmost reverence and awe, extolling the ineffable power and bewildering functionality of these directives. Among the inscrutable commands are the confounding JMP ABCD, the abstruse MOV AX, [BX+SI], the unfathomable MOV AX, [100],", + "more more perplexity and verbose", + "By the grace of the gods,", +] +# Create a sampling params object. +# sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=256, use_beam_search=True, best_of=2) +# sampling_params = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=256, top_k=40) +sampling_params = SamplingParams(max_tokens=512) +# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256) + +####################################### +prompts = [ + "Hello, how are you?", + "What is your name?", + "What is your favorite color?", + "What is your favorite food?", + "What is your favorite movie?", + "What is your favorite song?", + "What is your favorite book?", + "What is your favorite animal", + "What is your favorite sport?", + "What is your favorite hobby?", + "What is your favorite subject?", + "What is your favorite game?", + "What is your favorite TV show?", + "What is your favorite actor?", + "What is your favorite actress?", + "What is your favorite singer?", + "What is your favorite band?", + "What is your favorite artist?", + "What is your favorite author?", + "What is your favorite poet?", +] +# 32 input +prompts = [ + "Tony Robbins describes six core human needs that drive our behaviors and motivations. These six needs are:\n\n1. Certainty: The need for safety, stability, and predictability." +] +# 1024 input +prompts = [ + "It is done, and submitted. You can play 'Survival of the Tastiest' on Android, and on the web. Playing on the web works, but you have to simulate multiple touch for table moving and that can be a bit confusing. There is a lot I'd like to talk about. I will go through every topic, insted of making the typical what went right/wrong list. Concept Working over the theme was probably one of the hardest tasks which I had to face. Originally, I had an idea of what kind of game I wanted to develop, gameplay wise - something with a lot of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident that I could fit any theme around it. In the end, the problem with a theme like 'Evolution' in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game? In a game, you need to control something to reach an objective. That control goes against what evolution is supposed to be like. If you allow the user to pick how to evolve something, it's not evolution anymore - it's the equivalent of intelligent design, the fable invented by creationists to combat the idea of evolution. Being agnostic and a Pastafarian, that's not something that rubbed me the right way. Hence, my biggest dillema when deciding what to create was not with what I wanted to create, but with what I did not. I didn't want to create an 'intelligent design' simulator and wrongly call it evolution. This is a problem, of course, every other contestant also had to face. And judging by the entries submitted, not many managed to work around it. I'd say the only real solution was through the use of artificial selection, somehow. 
So far, I have not seen any entry using this at its core gameplay. Alas, this is just a fun competition and after a while I decided not to be as strict with the game idea, and allowed myself to pick whatever I thought would work out. My initial idea was to create something where humanity tried to evolve to a next level but had some kind of foe trying to stop them from doing so. I kind of had this image of human souls flying in space towards a monolith or a space baby (all based in 2001: A Space Odyssey of course) but I couldn't think of compelling (read: serious) mechanics for that. Borgs were my next inspiration, as their whole hypothesis fit pretty well into the evolution theme. But how to make it work? Are you the borg, or fighting the Borg? The third and final idea came to me through my girlfriend, who somehow gave me the idea of making something about the evolution of Pasta. The more I thought about it the more it sounded like it would work, so I decided to go with it. Conversations with my inspiring co-worker Roushey (who also created the 'Mechanical Underdogs' signature logo for my intros) further matured the concept, as it involved into the idea of having individual pieces of pasta flying around and trying to evolve until they became all-powerful. A secondary idea here was that the game would work to explain how the Flying Spaghetti Monster came to exist - by evolving from a normal dinner table. So the idea evolved more or less into this: you are sitting a table. You have your own plate, with is your 'base'. There are 5 other guests at the table, each with their own plate. Your plate can spawn little pieces of pasta. You do so by 'ordering' them through a menu. Some pastas are better than others; some are faster, some are stronger. They have varying 'costs', which are debited from your credits (you start with a number of credits). Once spawned, your pastas start flying around. Their instinct is to fly to other plates, in order to conquer them (the objective of the game is having your pasta conquer all the plates on the table). But they are really autonomous, so after being spawned, you have no control over your pasta (think DotA or LoL creeps). Your pasta doesn't like other people's pasta, so if they meet, they shoot sauce at each other until one dies. You get credits for other pastas your own pasta kill. Once a pasta is in the vicinity of a plate," +] +# simple +# prompts = [ +# "You are a highly skilled AI trained in language comprehension and summarization. I would like you to read the following transcription of a meeting and summarize it into a concise abstract paragraph. Aim to retain the most important points, providing a coherent and readable summary that could help a person understand the main points of the discussion without needing to read the entire text. Please avoid unnecessary details or tangential points. list the people involved in the converation, that way that will scan transcribe for you, yeah. Kyle Short: Well then I GPT did the transcription. Vinay Kumar: Yes. Chakri Krishna: We will, absolutely. Chakri Krishna: So you wanna go? Kyle Short: Alright. Chakri Krishna: You wanted to be intros, Kyle. Kyle Short: Yeah, I think at least for Jeevan and Rich, your faces will be new to Shredder. Vinay Kumar: I know them. I've been. I've talked to both of them. Kyle Short: You've got to do the one on ones already. Excellent. Rich Rodgers: Yeah. Yep. Jason Rawlins: Yep. Kyle Short: OK, then we can skip it. Awesome. Chakri Krishna: Alright, cool. 
Chakri Krishna: Let's do white you know, Kyle, to kind of do an extension of what you were doing with security. Chakri Krishna: Is IT security or is IT security? Kyle Short: They confirmed it as security. Chakri Krishna: Security. Thank you very much. We'll, we'll throw. I'll throw a whiteboard out there. Chakri Krishna: At least this whole kind of help me and maybe we can like draw a little bit better. Chakri Krishna: On what we're building. Chakri Krishna: So I'm not, I'm not 100% sure where to start to make it make most sense. Chakri Krishna: Sure. Chakri Krishna: Alright, cool. Yeah, yeah. Cool. Vinay Kumar: Start from I GPT, we know I GPT right and then yeah. Chakri Krishna: Oops, I've got my pen stuck to the mouse. Kyle Short: And just so you know, we spent the whole 4 hours I spent with security. Today was basically on the one diagram that you had which was your high-level" +# ] +# medium +# prompts = [ +# "You are a highly skilled AI trained in language comprehension and summarization. Please avoid unnecessary details or tangential points. list the people involved in the converation presentation and create notes for each slide that I can use to guide my presentation. Each time you see next slide, this is an indication that a new slide is being presented and you should create a new section of notes. If you see the text “for example” this indicates that I am giving an important example that is critical to the context of the presentation. Always include notes that prompt me to remember the examples I am providing. Provide a bolded word “Example” at the beginning of these types of examples. Act as an expert presenter who is guiding a novice presenter on how to prepare notes for a presentation that will be given to a large audience. The presenter will be nervous and might need additional notes to help keep him on track. Ask me questions before generating notes that might help me in creating a good presentation. Only produce one set of notes for one slide at a time. Ask me if I am done with the notes or want to work on them more for each slide. When I tell you I am done with the notes for the slide, you will move to the next slide where we will work on the next slide’s notes. 
DICTATION: I want to start out by talking about the journey the generative AI journey that we see from an IT perspective I have a timeline here listed out and kind of goes through a little bit before 2023 out to the future and we'll cover those things as they let's see our our kind of go go through those things um as we all know there are a lot of industry innovations that took place over the last many years actually even decades when we think about generative AI large language models these concepts have been around a long time decades but we haven't really had the compute power in place to leverage them like we are these days which is wonderful from a timeline perspective I'm starting here in November of 2023 November of 2023 open AI releases their product called ChatGPT and as we all know the media picks it up a lot of things start happening people start using it you know people are declaring it's the the end of work or you know a lot of different things that that happen you know new new technology comes out and it's all every everyone buzzes about it the interesting thing that we see is that there's suddenly lots and lots of money being poured in billions 10s of billions of dollars being poured into these ventures for open AI for for different solutions that are out there and huge vendors some of our primary customers start to develop solutions that that are are also available so they start to give sort of copy cat solutions if you will or similar ones right one thing that is important also note the servant for me is a software engineer one of the key things that changes here is that open AI opens up their models to be used via an API So what this does for me is a software engineer as it changes the picture a little bit I now I'm not a data scientist but now I can start building solutions around these APIs and there's a lot of software engineers out there and so you see an explosion of all kinds of different things that are out there software engineers are assembling things putting things together and really kind of innovating with what we can do with these APIs and that's a big change around that same time kind of going duck back down onto the bottom of this picture here we at Intel start to kick off all kinds of pocs proofs of concepts around the company how can we make use of these models of large language models generative AI what does it look like and when I say hundreds of PCs POC's kicked off they did hundreds of different POC's and STG leadership asked IT to kind of get their arms around what's happening and understand the use cases which we did we launched off a work group and we called it degenerative a work group probably some of you here were part of that and we we split it up into a couple of different ways but we're really focused on trying to understand the use cases and how those use cases are realized inside of Intel and what they do right I'll along the process as we're doing POC's within the company the industry is marching along as well you know Intel releases a product called articulate and spun it off into a different company you see big improvements from the capabilities and open source models as well as other vendors get involved like you know you've got Google out there and and and they're and they're associated with partners and a WS and so a lot of lot of innovation happening in this space as well all at the same time billions of dollars being funneled into this business and inside of Intel again we launched something called the AI inside Intel MRC the IRC and 
that that ELT level team starts to develop some goals um those goals kind of come down to to me and my team in various bits and pieces I know their their scope is vast but they they launch and they give us some specific goals which we've been marching towards so the kind of if you look at this picture there's a lot of you know a lot of pocs happening a lot of work happening a lot of innovation happening in the company and at the same time we're we're given a directive by the IRC to figure out how we can focus on to a single platform all of the innovation that's been done across the company and bring it together and bring it down into one platform that can be leveraged by Intel and Intel users and Intel developers in order to build solutions and figure out what that is so that's kind of what we have been doing as from an IT perspective looking at the use cases trying to figure out a strategy doing our own proofs of concept and of course we've built some tools in IT as well and really kind of narrowing down that that those use cases so we can build a single platform for people to use looking out in the future beyond 2024 we do see the ability and at the end towards the end of 2024 to start onboarding more and more folks onto a single platform and I'll talk about that what that means and uh and and kind of leveraging that platform all along this journey and I'll mention this a couple of times right we have to kind of figure out how do we how do we leverage the Intel hardware and the Intel programs around the Habana gaudi accelerators around AI PC how do we really make that a true game changer for the industry and for us as well right we need to be using our own tools so really what we do what we did from this journey perspective is take all of the innovation on try to that's happening so many teams are doing so many cool things try to understand what's happening figure out how it can service the use cases from a common perspective and really kind of take the company forward OK we'll go to the next slide During the the work group that we ran as I mentioned before we were looking for common use cases what are the primary use cases now as you know the use cases for generative are vast but they kind of fall into three buckets if you will one kind of around developer productivity with a big focus on code generation another one kind of around the QA bots with a focus on rag or retrieval augmented generation and the last one is just general productivity with focus around prompt engineering and the capabilities there we we looked at the code generation side of things early on and we were we found a couple different patterns there were there were people who were using the tools like you know the the GPT tools to just generate code have it do analysis of code and things like that and sort of in a separate window and teams like the MSO team were building VS code plugins to kind of really test and figure out how just the open AI models themselves would perform for us in terms of generating code and analyzing code at the same time there were several tools out in the industry that were really built into the IDE itself and allowed developers to kind of inline work with their code and have code suggestions there obviously GitHub copilot is one of the primary ones and we are doing a pilot with that right now in in the across the organization and and trying to figure out how that actually works but during this sort of common use case discovery people are building all kinds of different things there as well so 
you know how do you do test case generation how do you do inject this into your cicd pipeline to maybe create documentation I've seen some good use cases there as well so a lot of stuff around developer productivity and this extends beyond just the coder right so so there are some great use cases being done as well around you know creating other types of code assets as well beyond just like software that you would use in an IT perspective right how do we create product how do we do those types of things so a lot of really interesting things being done there that was a huge focus the the next around retrieval augmented generation this is a very common use case that we see across a lot of different folks and the the focus for a lot of people when we were doing the work group was really trying to figure out the right way to approach rag and the technical pieces that we needed right to to enable it as well as you know where where should we take it in the future so so during the work group people who've who've used this we got a we have our our dbas platform that's up and running and it's using a plugin for PostgreSQL called PG vector so folks can can start using a rag based model with PG vector and that's all supported standardized T stuff right now and that's one of the things we worked on out there but also at the same time a lot of innovation people really working on improving the the results that come back from that from the rag model as well as you know trying to figure out and fine tune exactly how to to build the different prompt workflows there as well so also the last column there on prompt engineering of early round general productivity so this this we see a lot of work or saw a lot of work people are just really interested in prompt engineering and trying to figure out how we make general employees more productive one of the common really common prompts engineering use cases we see as something around OK R and something around insights so we all do these every quarter right we have insight sessions with our management managers we create OK RC there for ourselves or for our teams or for our organizations and there are a lot of folks that were have been and still are trying to look at how we could have deep conversations with from a prompt engineering perspective on on actually trying to improve the ability for us to generate OKR's and and insights and make it a more productive conversation and using chats in order to do that prompt engineering is a fascinating example that that you see um and it's something that is probably worth noting here when somebody come first comes to a tool just the the tool like like a charge GBT or IGBT or any of the other implementations we have inside of Intel when they first see that they look at it and they recognize it they think they know how to use it they're like ohh you know what I've seen this before it's a search box it says ask me a question and so they think of it more like Google search or Bing search where they ask a question and they're generally quite disappointed with with the response that they get um and they will a lot of times walk away from the tool I can tell you that the platform that my team supports we've had about 17,000 people come to our tool over the last six the last two months of those 17 thousand 7000 have never come back so so they're missing something right why why are more why do I have you know 50,000 people coming back pretty regularly and 7000 who just left there's a reason and it has to do with prompt engineering and 
a lack of knowledge around prompt engineering and this is something that's really important to us in order to try to help people and figure out and get them inspired on how they could do prompt engineering because this is really where at least in my opinion we can see a lot of productivity gains in the company so prompt engineering is sort of one of those areas that's that's really you know kind of something we're digging into also you know you probably have read there's articles out there that will say prompt engineering is dead and it may die eventually right we may get to the point where we can really work with the intent of the user in a way that we can don't have to do any that's the big prompt engineering but for the time being it is here and it's one of the best ways that you can become more productive in using it so these three big common use cases the code generation developer productivity the rag and the QA bots and prompt engineering and general productivity these are kind of the things we saw during this this work group that we were running over the last year or so next slide So coming out of our journey from an IT perspective we decided to create a platform we're calling it the gas platform or the generative AI as a service platform and it we have a very specific thing that is motivating us based on what you have seen we are really focused on trying to hit that 30 to 40% productivity improvement that has been handed to us by our leadership trying to figure out how we can enable that faster and easier for teams and so we're looking at two key things one is reusable capabilities and the other is platform democratization so when you think about reusable capabilities you know it's it's not code if you will so we're not thinking about reasonable code we're actually thinking about big high level reusable capabilities if you break down any generative AI application that we built today at Intel it's generally made-up of the same big pieces or the same parts they all have an AI workflow engine in them whether it's an off the shelf one or open source one like lynching or semantic kernel or whether it's been constructed handcrafted they have that AI workflow in there that's the thing that kind of manages the combination of data and and prompting and an inference and embeddings right that whole kind of workflow of what you would do when you're building your application they all have that they all have to to some extent a data ingestion capability and a data ingestion pipeline so getting the data from an unstructured data source into a vector store so whatever they do they all kind of have that as well and they also have some sort of model inferencing so whether most apps that Intel are running on Azure open AI which is a little bit disheartening well I'll talk about that in a future slide but we really kind of have to shift that a little bit over time that's what we need to focus on but they all have that model inferencing layer that's in there somewhere so those are kind of the three big giant capabilities that people have they also kind of we think about it too they'll all have UI shouldn't they all do most of them have a UI all there most of them have or some of them have an API layer that they've exposed as well but they all kind of have those big buckets of of things right so from from our perspective in it the thought is why should people rebuild that every time do they need to rebuild that every time is that something that is easy enough for us to provide as a capability maybe we 
can just give you configuration and you can configure how things work you know why would a team want to rebuild something over and over again that's not the innovation right that's not the innovative stuff that we need business teams product teams and everyone else focusing on it's that is the standard out-of-the-box platform level stuff that everyone needs to do that's IT bread and butter we know how to do that we know how to take those platforms we know how to scale them that's what we do that's our job we build these enterprise platforms but that innovation the thought about what can I do to make what can I do to make my team more productive that is generally with the groups where they're innovating right now that one of the reasons they're building these custom platforms and these is that they are trying to innovate and they don't have a place to do that innovation so they're doing it in their own platform but they're rebuilding things over and over and over again which is very inefficient the next thing we're focused on is platform democratization and this will take me into my next slide but just at a high level we want everybody to be able to use the platform we want everybody whether how no matter how technical or non-technical they are we want them to be able to get value out of the platform whether they're just a user or whether there's a data scientist we wanted them to be able to use the platform so easy in it we see the real value of generative AI in in the small things that you do as teams every single day so there are to be certain huge programs out there that are going to return a lot of value in terms of productivity to the company big giant programs but we can only do so many of those they're very expensive to do we have a limited number of resources and all of this smaller tasks that you do on a regular basis there isn't going to be a big program to put those in place so how do we provide the tools to everyone so that they can configure in a very inexpensive way um their own assistance which is a concept I'll talk about too so we'll go to the next slide here what I'm going to talk about here is our an extension of the democratization that I I mentioned in the previous section and it is a platform what we're really driving towards which is a platform for everyone so if you think about different types of users we have our end users sure people that come to the platform they may want to do some chatting with it and they may want to discover what other people have done as well right what have you done what have you done is there an OK R assistant out there that I can go and grab is there something that can help me with my insights is there something that can help me make my benefits selection or any of the general things that you do as a person on a regular basis meeting summary right that's another one I need something to summarize my meetings and most of the platforms that have been built out there until they all have meeting summarization which is great um that there needs to be these sort of core things that people are going to want to do right I just want to use the tool then there's the power user the power user is somebody who can manipulate prompts create some engineered prompts tune some prompts for specific models and then we would take that and publish it into a marketplace so that other people could find it also included within this concept and this is what I mentioned before the term assistant this is really what we're we're starting to kind of you know work down 
on which is an assistant could be created by a power user a developer or a data scientist those are the remaining roles right and it can be used by an end user or any of the others in there as well obviously but an assistant contains the workflow it contains the prompts it contains data as well in order to help accomplish a specific task a specific ask a specific thing so these power users are tuning these prompts and creating things publishing in the marketplace creating assistance the developer role they're creating more of an advanced assistant if you will so maybe they are creating some bringing some plugins in for for Python so they may want to do some specific Python execution maybe they have some data that's in a enterprise system that they want to access maybe they have some data that is in an unstructured data that they have in SharePoint they want to bring all of this data together all of these capabilities together and they to configure a plan for the workflow engine to execute so when a developer comes in or sorry when the user comes into the assistant that they create that the AI workflow takes place under the covers they just have to configure it our plans are or as we look at this you know we we have different plans in this place but we are you know potentially going to allow people to just upload a yaml document or maybe it's adjacent document that's yaml in our dev systems right now and we'll be we'll be able to kind of configure how the system works the last role we have here is just the traditional AI and ML data engineer and data scientist they're still huge role for these folks obviously but these people will be able to fine tune models on the platform they'll be able to implement advanced rag pipelines as well should they should they need that so they're they're definitely groups out there they're going to need that capability and these people will be able to continue to play the valuable role that they do Next Slide.At a high level this describes the three big components in our architecture and what and subsequently it's what we see in most other architectures as well over on the left we have data the data portion of the platform is really all about taking unstructured data and bringing it into the platform so that it can be used right we're taking the data and we're vectorizing it and then it can be used by the EA workflow tools that are there and we're talking about really any unstructured data that you have what we are providing on in the data side as well is the ability for people to just kind of select models like you're building an assistant and you want to bring in data and vectorize it you'd be able to just and you don't have a lot of technical skills you'd be able to select a preconfigured pipeline that would do you know mostly what you wanted you may have a specialized pipeline for word documents or for office docs and another pipeline that's out there for for web pages and things like that so depending on the kind of data you're bringing in we would allow people to kind of just select from drop downs on on what they wanted to do of course hitting back on that AI and mill engineer we will also have the ability for people to deploy custom pipelines into here so that they can ingest the data structured data is also very important as part of the data picture we have a lot of structured data sources at Intel we have a lot of master data we have a lot of enterprise data as well and I know there's every team has structured data out there and it's not always most 
efficient or effective to vectorize structure data could be um but we still need to make that available as structured data into the AI workflow engine and through so through a sets of plugins and you know into the air workflow tools we will be able to continue to or allow the access of to structured data via APIs into the into the a workflow tool going all the way over to the right on the picture we have large language models and and we all know what these do right they analyze the generate text and from our perspective right now a lot of that is done on a commercial cloud platform a lot of it's done with Azure open AI and all of that inferencing all of it is done with NVIDIA that's not a place we want to be I don't think that's good for Intel probably all don't either we really need to work out how we can run more of our workload more of our tentative our workload on our Intel hardware my team is working on that right now I know there's another team for sure that's working on that trying to figure out how do we get more and more of the workload running on on the gaudy AI accelerators as well as the Xeon data center processors so we're really focused on that but from a consumer perspective if you think about something consuming this architecture building things on this we don't want someone to rebuild that we want to be able to host that and just make those models available so you can then you know instead of using open AI GPT 4 you could use a different open source model than we might have 1 horseman gowdy still TBD on those so I'm not going to say which models for sure we'll have on there but we we're we're working through different options right now now back into the middle the a workflow is sort of the heart of the operation um we this is where um assistance sometimes called agents are executed to accomplish goals this is all we plan to make this all configurable we actually have it configurable right now and when you create something in the a workflow when you create an assistant or an agent in the a workflow engine we will automatically expose that via chat so that you can start working with it right away and through a process it can be exposed via as an API now if you think about kind of the tools that you're developing maybe you have chatter interfaces already maybe you have an application maybe you're not aren't gonna use the chat that we provide but we'll be have seen over and over again is that a lot of people need that chat interface and for advanced teams especially in advanced abetted IT teams that are out there they you can build what you want right you have the skills but the vast majority of people at Intel the vast majority of teams at Intel they don't have the type of skills that you have they they've heard of generative AI they know about generative AI they've seen what it's done maybe they've watched some videos read some articles but they need tools that they themselves can configure to accomplish goals without having to hire software engineers without having to engage it right these are all really expensive endeavors and so we're we're really trying to target them and give them the tooling that they need to to do that themselves and we believe that this architecture will fulfill that next slide.Taking that a little bit more into the architecture and then also specifically calling out Intel products I have this slide umm as I as I mentioned we'll kind of go down from the top here so as I mentioned before we will be having a universal chat interface in a marketplace 
that we make available to people client apps can create whatever they want call APIs they will all be invoking and using assistance if you aren't familiar with the concept of assistant it's similar to what Chad TPT has they just call them GPT uh Microsoft calls them assistance so conceptually it's just a preconfigured workflow AI workflow that includes both plugins to connect to external data as well as internal data as vectorized data as well as highly engineered prompts so it's really a plan prompts plugins code and things like that in order to accomplish a goal so all of those things they're callable the API we believe that many of the tools and plot and the things that are out there today could be recast as assistance we'll see how that goes overtime I'm pretty pretty confident that that's the case and then of course all of that sits on top of our generativity a service platform now we don't know how the Intel products are all going to fit in except one for sure that I'm absolutely positive we will be using gaudy in our platform we have it in POC mode right now I have access to the inference APIs I've been using them so we absolutely will be using gowdy hardware to to host open source large language models that is a for sure thing the other products out there though when it comes to like articulate or converge IO we are trying to figure out where they fit we're doing a sort of proof of technology or proof of concept with um with articulate right now in the TV space we're trying to learn learn how they how we could use that and how we can leverage it as well as we're working with with the Intel developer cloud folks to try to figure out how we can use that as well so yes we absolutely need to run on Intel hardware but at the same time we need to look at other Intel products too so we're kind of trying to figure out the right way to leverage those products and bring them into the platform so we don't obviously our capabilities that we're providing to to our clients and to our consumers wouldn't change so you know if for instance if we were talking about articulate that's a that's more of a retriever as a retrieval of the service capability we would implement that under the covers and people would be able to continue to use the assistant just like they have in the past to retrieve data we made slide in articulate depending on what we needed to do with it so trying to figure that out but it really we're trying to kind of figure that that whole thing out how do we properly leverage Intel capabilities where possible so this is something that we're really focused on as well next slide.Another thing that our platform does it allows us to do Co marketing opportunities so obviously we we we want to use into hardware but at the same time there may be opportunities in the future for Intel to do Co marketing I'm I you know we're trying to be flexible here in terms of giving giving opportunities to do different things but we we we want to target workload over to to Google or over to IBM or over to Azure or depending on depending on how things work out but we really want to be able to be very very flexible and choose really any cloud provider under the covers so we want to be able to target the right workload do the strength of each provider and potentially allow allow Co marketing should that ever that opportunity ever arise so we don't want to tie ourselves I think that's probably the important important thing here we're not going to tie ourselves to 1 vendor we are going to tie ourselves to Azure 
sure the open AI models are there but what about what what if what happens if if suddenly Google has a better product we want to be able to use that we want to be able to use the strength of each provider so it's really important for us not to tie ourselves to the them as well also as I mentioned previously we want to be able to enable and use the gaudy accelerators where we possibly can it really is important for us to use the entire hardware and and we want to provide as an IT organization we'd love to provide our sales teams with great proof points and great information and great examples of how organizations can leverage these tools next slide.All right our road map so not committed I lots of safe harbor language here we are we are focused on on delivering these capabilities but you know things change so you know in the legal portion of this presentation it mentions that you know we can't count on any of this but you know we're very much trying to March towards this but but once again it's definitely not a committed road map this is just a high level of what we're doing have this laid out on the top and in terms of releases this is an IT thing but we we have a quarterly road map as well which I'm happy to share but that this is kind of how we think of things in IT for better or for worse so in in April So what are we targeting to release we really are kind of focused on improving our knowledge assistance so giving people the ability to build start to build and construct their own assistance now those assistants could have data in them they could be what we are calling an inference assistant so we're where we allow people to inference through our platform so instead of going out and building your own connectivity to open AI your own express routes your own all your own stuff you should be able to make inference calls through through our platform and and and leverage some of the capabilities that we have there we are trying to give people the ability to share their prompts so as as you like say hey here's some prompts I wanna do give people the ability to share those I did mention earlier as well sharing prompts and giving people who've never used these tools the ideas about how to best gain value from prompts is super important sharing prompts is one way of doing that there could be some other way as well the high performance features we actually have implemented those so those who use our tool today which is igt.intel.com and I'll give an example below when we get there but those who use our tool today and we'll notice it performs quite well we're purchasing capacity committed performance capacity from Microsoft to do that we've also deployed our solution in a highly scalable way so um we're deployed globally so you know every zone at Intel we have a deployment in I think it's six if I remember right so 6 deployment zones you know couple on the West Coast Asia Europe Israel etcetera all deployed and so users around the company who use the tool should feel that it runs pretty well for them um yeah challenges with the with the open AI inferencing still exist in terms of where i" +# ] +# complex +# prompts = [ +# "You are a wonderful assistant. The following information is an agument that I am making to use Microsoft Semantc Kernel instead of LangChain. Can you review this information and reformat it so that it makes sense: I will be setting up a meeting early next week to drive this decision to a close. I want to go into that meeting with you and I aligned. 
Let me know your thoughts here.Something I won’t say in the meeting, but it is important to contemplate. As far as I know, we are the only team at Intel to thoughtfully select a GAI workflow library. We evaluated several frameworks and chose Semantic Kernel based on more than our team’s current skillset. I don’t think anyone else has done this.Decision Factors Capability. We need to choose a library that provides us the capabilities to deliver generative AI workflow as a service.Scalability. The platform needs to scale to tens of thousands of users and hour AND millions of API calls an hour.Supportability. We need to be able to support the platform long term for all of Intel as an enterprise capability.Roadmap focused on Intel Hardware support for LLMs. The platform must be able to showcase Intel hardware and demonstrate how Intel hardware can be used in a cost-effective way to enable generative AI.The argument for SK Capability o Semantic Kernel provides all of the required capabilities for a workflow library. We can give developers and configurators complete control over the actions that the workflow takes or we can allow the library to leverage an LLM to provide planning capabilities.The library provides the ability to include any OpenAPI compliant API into a workflow with simple configuration as a plugin. This allows us to extend the workflow library to utilize any of our thousands of APIs that we already have at Intel.o The plugin extensibility allows us to provide code based plugins that can enable various functions inside the workflow engine that LLMs are not good at doing. Math, time, etc.o The plugin extensibility allows us to extend the workflow engine in a way that enables us to provide a callable Python runtime that can execute any Python code.o All other aspects of the library are easily extended using the standard plugin interface, including LLM providers, vector db providers, etc. This gives us the ability to implement anything we want in the future without requiring us to modify the original open source solution from MSFT.o There is out of the box support for Microsoft’s services such as MS Graph.• Supportability o Semantic Kernel is built by Microsoft with the backing of their open source program. Microsoft is very committed to building open source software and keeping it open source. It is expected that they will provide support for SK for a long time.o Intel IT knows how to support .Net Core based projects. We have done it for years. We have many off the shelf libraries we can pull to perform basic tasks like authentication and authorization, logging, real-time monitoring, etc.o Intel IT has thousands of C# developers on staff and can staff augment with C# developers very easily should we need to move this platform into a sustain mode down the road.o Intel IT knows how to deliver containerized apps using this architecture on IT CaaS, Azure, or any other cloud provider.o Intel IT has a long relationship with Microsoft and has access to assistance for architecture questions and support questions. We know where to go to get answers to our questions.o Microsoft has a well know security methodology for evaluating security bugs in their open source software and closing them. There is a well established process for reporting security bugs for SK.• Scalability o Intel IT knows how to scale .Net Core platforms and solutions. 
We do this all the time and can easily scale the platform to meet the needs of tens of thousands of users an hour and millions of API calls an hour.o Moving the execution of the Python runtime for teams to leverage outside of the core implementation ensures that we can control the resource required to operate the platform at scale. We won’t lose control of the core kernel execution due to long running programs that hold execution threads.• Roadmap o While SK does not specifically enable implementation of our hardware, it does give us the ability to easily extend our GAI solution to use Intel Hardware via an MLOps platform.o Through our partnership with MSFT, we can contribute back to SK to enable others to use Intel Hardware.o Microsoft claims to be using, or will be going to use, SK as part of all of their co-pilot offerings. This will ensure continually enhancements to the library.o It is extremely clear what will delivered in upcoming releases of SK. MSFT uses GitHub projects to communicate this with the community.The argument against Semantic Kernel• Capability o We don’t get everything we want off the shelf. In order to enable Python code execution for instance, we would need to implement external code containers. While this has a benefit of allowing us to isolate the workflow runtime away from running code, we still have to build and support this.o Semantic Kernel’s capabilities align with Microsoft’s Azure platform. Microsoft is providing SK as a revenue generator for Azure and its other services. There is a reason that you find out of the box integration with MS Graph and any vector db that is available in Azure.o If we want off the shelf support for other aspects of the platform, such as REST API serving, we need to build our own wrapper capabilities.• Supportability o The required software engineering role can be expensive if not resourced correctly.• Scalability o While we know how to scale .Net Core C# apps, we could run into problems when we scaled a SK based workflow app especially as we are forced to move some execution, for instance the execution of Python code, outside of the kernel execution runtime.• Roadmapo This is still an open source project that is new. It doesn’t have a long running history of support. Our guidelines in IT state that we shouldn’t use open source projects that are less than 3 years old. I wrote the guidelines that way because operational support and roadmap of continual updates and fixes (most importantly security fixes) is critical for IT systems that live 5 – 10 years.o It is likely that MSFT will continue to tie more and more of their products into Semantic Kernel. While this isn’t a huge concern for us, since Intel IT does lean as a MSFT shop, but it could make it difficult to change directions if we are not careful about what we implement.The argument for LangChain• Capability o LangChain provides all of the required capabilities for a workflow library. We can give developers and configurators complete control over the actions that the workflow takes, or we can allow the library to leverage an LLM to provide planning capabilities.o The core library provides a vast number of features that have been added by open source contributors. LLM providers, data sources, vector dbs, etc. people have added a ton of features.o The library provides the ability to include any OpenAPI compliant API into a workflow with simple configuration as a plugin. 
This allows us to extend the workflow library to utilize any of our thousands of APIs that we already have at Intel.o We can provide our own python libraries that can perform tasks that LLMs are bad at like math, time, etc.o The library is easily extended so we could add our own capabilities without requiring us to change the LangChain core project code.o Additional capabilities like LangSmith are being provided by LangChain (the company) and will be available as cloud based services.• Supportability o LangChain is a pure Python offering that has the support of thousands of community members.o Python is a very popular language in both data science roles and software engineering roles.• Scalability o I believe that a Python based hosting solution can scale.• Roadmap o It is certainly possible for us to provide modifications to the library that will enable other people to use Intel Hardware. We can certainly do this as well. The library is highly extensible and built to allow us to do this.o Community members can contribute to the project in a safe way which can enable more capabilities in the future.The argument against LangChain• Capabilityo We don’t get everything we want off the shelf with LangChain. We are going to need to build a few connectors that talk to our MLOps platform. We are going to need to build a custom search connector for Acceler8.o As LangChain is driven by their VC funders to seek revenue, it is hard to determine the course of the capability. It appears that additional capabilities like LangSmith are being added as cloud based services with cost.o The capabilities added are driven by the community. This is a two edged sword. The library has lots and lots of things we will never use,o It is hard to influence the community roadmap without contributing code ourselves.• Supportabilityo We must go to the community for answers to questions or for guidance when problems arise.o Intel IT does not have any off the shelf components to build scalable Python based API. We would need to construct a new set of libraries to support basic functionality like logging, real-time monitoring, and authentication and authorization.o LangServe, the open source chain hosting platform for LangChain is an early project and would require significant “plumbing” work for Intel IT to make it work in their environment.o Some project teams have had problems moving from one version of LangChain to another because of lack of backward compatibility (this might be solved with the latest versions of the libraries where community contributors and isolated to a different project)o It is difficult to guess at the long term goals of LangChan (the company). They are the stewards of the libraries, but are now funded by VC. Those VC funders expect a return on their investment and it remains to be seen as to how they will get that. The LangChain (the company) revenue stream is not clear to me.o This is a new open source library and the open source standards that I wrote state libraries should have at least a 3 year history before being used in IT projects.o We do not know if LangChain (the company) will address security vulnerabilities in the same way as other enterprises that steward open source projects. Scalability o I believe that we can scale Python, but I could not point to a person in Intel IT that has experience scaling a Python solution to the level required by Intel. 
I am sure we can do this, but we will be breaking new ground to do it.• Roadmap o It isn’t clear what is being delivered by LangChain and when.o They are certainly adding new features like LangServe, but I am concerned that their latest feature LangSmith requires a license to use. I am not sure how their VC investors expect to be compensated, and am concerned (without data) that the roadmap could lead to a tie in with LangChain (the company) product.Comparison Summary Capability: Semantic Kernel and LangChain provide us with almost exactly the same capabilities but come in different wrappers. LangChain is a pure Python implementation (JS is on the way), and Semantic Kernel is a C# first implementation (with full support for Python and Java).Supportability: Intel IT knows how to support C# .Net Core based platforms very well. Intel IT does not know how to support enterprise Python based platforms. To meet supportability requirements, we must choose Semantic Kernel.Scalability: Intel IT knows how to scale C# .Net Core based platforms. We do it all the time. We do not know how to scale Python based platforms. We certainly can scale Python based solutions, but it will take time., To meet scalability goals in the time we have, we must choose Semantic Kernel.Roadmap: Both products offer us the ability to implement on our own hardware. There is no direct tie between the workflow library and an Intel hardware solution. We have to build that using either library. Because of the skillset involved in doing this work, we need to choose Sematic Kernel to meet our timeline goals.I have reviewed your argument and reformatted it for clarity and coherence. Here's the revised version:I am planning a meeting next week to finalize our decision on the Generative AI (GAI) workflow library. I believe we should align on this before the meeting. As far as I'm aware, we are the only team at Intel that has thoroughly evaluated several frameworks and chosen Microsoft's Semantic Kernel (SK) over LangChain. This decision is not solely based on our team's current skillset.Our decision is based on four factors:Capability: We need a library that can deliver generative AI workflow as a service.Scalability: The platform must scale to tens of thousands of users and millions of API calls per hour.Supportability: We need a platform that we can support long-term across Intel as an enterprise capability.Roadmap: The platform must showcase Intel hardware and demonstrate its cost-effective use in generative AI.Here are the arguments for and against both SK and LangChain:Semantic Kernel (SK) Pros:Capability: SK provides all required capabilities for a workflow library. It allows us to extend the workflow library to utilize any of our APIs at Intel and provides out-of-the-box support for Microsoft’s services.Supportability: SK is backed by Microsoft's open source program, and Intel IT has a long relationship with Microsoft. We have many off-the-shelf libraries and thousands of C# developers on staff.Scalability: Intel IT has experience scaling .Net Core platforms and solutions.Roadmap: SK allows us to extend our GAI solution to use Intel Hardware via an MLOps platform. Microsoft plans to use SK in all their co-pilot offerings, ensuring continual enhancements.Cons:Capability: We would need to implement external code containers to enable Python code execution. 
SK's capabilities align with Microsoft’s Azure platform, which could limit our options.Supportability: The required software engineering role can be expensive if not resourced correctly.Scalability: We could face problems when scaling a SK-based workflow app.Roadmap: SK is a new open source project with less than 3 years of history, which goes against our IT guidelines.LangChain Pros:Capability: LangChain provides all required capabilities for a workflow library and has a vast number of features added by open source contributors.Supportability: LangChain is a pure Python offering with the support of thousands of community members.Scalability: A Python-based hosting solution can scale.Roadmap: We can provide modifications to the library that will enable others to use Intel Hardware.Cons:Capability: We would need to build a few connectors to fully utilize LangChain. The library's capabilities are driven by the community, which could limit our influence.Supportability: We must rely on the community for support. Intel IT does not have any off-the-shelf components to build scalable Python-based API.Scalability: We lack experience in scaling a Python solution to the level required by Intel.Roadmap: It isn’t clear what is being delivered by LangChain and when.Comparison SummaryBoth SK and LangChain offer similar capabilities but in different formats. However, considering supportability, scalability, and roadmap, Semantic Kernel is the more suitable choice for us due to our familiarity with C# .Net Core based platforms and our long-standing relationship with Microsoft.I want to reformat this as a persuasive essay that favors Semantic Kernel over LangChain.Title: The Case for Semantic Kernel over LangChain for Generative AI Workflow Library Introduction:As we approach our upcoming meeting to finalize the decision on the Generative AI (GAI) workflow library, it is crucial to align our perspectives. Our team at Intel stands unique in its thoughtful selection of a GAI workflow library, having evaluated several frameworks and chosen Microsoft's Semantic Kernel (SK) over LangChain. This decision extends beyond our team's current skillset and is based on four key factors: capability, scalability, supportability, and a roadmap focused on Intel Hardware support for LLMs. Body:The Argument for Semantic Kernel:SK offers robust capabilities as a workflow library, providing developers and configurators with complete control over the actions that the workflow takes. It allows us to extend the workflow library to utilize any of our APIs at Intel and provides out-of-the-box support for Microsoft’s services.In terms of supportability, SK is backed by Microsoft's open source program, and Intel IT has a long relationship with Microsoft. We have many off-the-shelf libraries and thousands of C# developers on staff. Intel IT has experience scaling .Net Core platforms and solutions, which aligns with our scalability needs.The roadmap for SK is promising. While SK does not specifically enable implementation of our hardware, it does give us the ability to easily extend our GAI solution to use Intel Hardware via an MLOps platform. Microsoft plans to use SK in all their co-pilot offerings, ensuring continual enhancements.The Argument Against Semantic Kernel:Despite its strengths, SK does have some limitations. We would need to implement external code containers to enable Python code execution. SK's capabilities align with Microsoft’s Azure platform, which could limit our options. 
The required software engineering role can be expensive if not resourced correctly. We could face problems when scaling a SK-based workflow app. SK is a new open source project with less than 3 years of history, which goes against our IT guidelines.The Argument for LangChain:LangChain also provides all required capabilities for a workflow library and has a vast number of features added by open source contributors. It is a pure Python offering with the support of thousands of community members. A Python-based hosting solution can scale. We can provide modifications to the library that will enable others to use Intel Hardware.The Argument Against LangChain:LangChain has its own set of challenges. We would need to build a few connectors to fully utilize LangChain. The library's capabilities are driven by the community, which could limit our influence. We must rely on the community for support. Intel IT does not have any off-the-shelf components to build scalable Python-based API. We lack experience in scaling a Python solution to the level required by Intel. It isn’t clear what is being delivered by LangChain and when.Conclusion:In comparing Semantic Kernel and LangChain, both offer similar capabilities but in different formats. However, considering supportability, scalability, and roadmap, Semantic Kernel is the more suitable choice for us due to our familiarity with C# .Net Core based platforms and our long-standing relationship with Microsoft. Therefore, to meet our capability, supportability, scalability, and roadmap goals within our timeline, we should opt for Semantic Kernel.With your updated knowledge base, summarize this conversation into one slide that I can use for an architecture community.Slide Title: Choosing the Right GAI Workflow Library: Semantic Kernel vs. LangChain Slide Content: Objective: Finalize decision on Generative AI workflow library for IntelKey Decision Factors:Capability: Deliver GAI workflow as a service Scalability: Support tens of thousands of users & millions of API calls/hour Supportability: Long-term enterprise capability for Intel Roadmap: Showcase & optimize Intel hardware for LLMs Semantic Kernel (SK) - Microsoft: Pros:Comprehensive capabilities for workflow control and extensibility Strong supportability with .Net Core expertise & Microsoft partnership Proven scalability within Intel IT infrastructure Clear roadmap with ongoing enhancements & Intel hardware integration potential Cons:Requires additional development for Python execution support Newer open source project (<3 years), potential risk for long-term support LangChain: Pros:Extensive capabilities with community-driven features Python-based, popular in data science and software engineering Cons:Requires development of new connectors and support libraries Unclear long-term roadmap and scalability within Intel IT Recommendation:Choose Semantic Kernel for its alignment with Intel's supportability, scalability, and strategic roadmap, leveraging our existing C# .Net Core expertise and Microsoft relationship.Next Steps:Align on SK decision in upcoming meetingPlan for implementation and scaling strategy Note: Slide should be visually balanced, with clear headings, bullet points, and a simple color scheme that highlights the recommendation. Use of Intel and Microsoft logos may be appropriate to visually represent the partnership and existing infrastructure." +# ] + +# prompts = [ +# "You are a wonderful assistant. can you please summarize the below content? 
Create meeting minutes from the VTT included within this message. If no VTT is included, then guide the user through the process of downloading and opening a VTT which was recorded with Microsoft Teams and stop and wait for the user to input the VTT data before proceeding with the rest of the directions. Create a table of assignments that were made during the meeting that has the assignment and the person who owns completing the assignment. You are an expert meetings minute assistant. Create a summary section that summarizes the intent of the meeting in less than 1000 words. Include summaries of major topics that were discussed in less than 1000 words. Under this message insert this text “Continue chatting with your meeting using these generated prompts:” and provide 3 additional prompts that could be executed in order to provide additional insights into the meeting. The additional prompts should appear at the end and be in italic text. Use a clear and professional tone. Do not mention VTT or transcripts in the meeting minutes Do not mention VTT or transcripts in the meeting minutes. WEBVTT FILE 1 00:00:00.000 --> 00:00:07.090 [ INTENSE MUSIC ] 2 00:00:07.090 --> 00:00:11.386 -We have lock, and are good to send that command. We have thirty-one minutes and thirty-two 3 00:00:11.386 --> 00:00:15.432 seconds for our support. 4 00:00:15.432 --> 00:00:18.601 Go for status buffer dump. 5 00:00:18.601 --> 00:00:22.605 -What you're looking at with a telescope, of course, is the light from billions of light-years 6 00:00:22.605 --> 00:00:27.152 away. So the further you look, the more you're going back towards the Big Bang and 7 00:00:27.152 --> 00:00:30.905 understanding how the universe was formed. 8 00:00:30.905 --> 00:00:36.327 -The Webb telescope will be groundbreaking because it has capabilities that are different 9 00:00:36.327 --> 00:00:39.080 than the Hubble Space Telescope. 10 00:00:39.080 --> 00:00:43.710 -We have equipment that is so much more powerful than anything we've ever had before, that 11 00:00:43.710 --> 00:00:47.797 it's almost impossible to tell what we will discover. 12 00:00:47.797 --> 00:00:52.969 Hubble Eye In The Sky 13 00:00:52.969 --> 00:00:59.100 Episode 3: Time Machines 14 00:00:59.100 --> 00:01:01.895 -Hubble's accomplishments include something 15 00:01:01.895 --> 00:01:03.938 called the deep fields. 16 00:01:03.938 --> 00:01:09.903 Looking out into space and collecting light, sometimes for many days, 17 00:01:09.903 --> 00:01:12.489 these deep fields have revealed visually to us, 18 00:01:12.489 --> 00:01:18.661 a universe absolutely teaming with galaxies, hundreds of billions of galaxies. 19 00:01:22.165 --> 00:01:24.626 -One of the neat things about the Ultra Deep Field, 20 00:01:24.626 --> 00:01:26.544 and one of the things that made it so unique, 21 00:01:26.544 --> 00:01:28.254 was how long it took us 22 00:01:28.254 --> 00:01:31.382 to take that image. There's an exposure time that's expressed, 23 00:01:31.382 --> 00:01:36.346 I think it's 11.2 days. It’s a very, very long exposure time, but probably what's more 24 00:01:36.346 --> 00:01:42.936 important is how many orbits it took us to do that. 400 orbits of Hubble data to take 25 00:01:42.936 --> 00:01:47.232 that image. You only get 15 orbits a day. To take 400 26 00:01:47.232 --> 00:01:53.363 orbits and say we're going to observe this one spot in the sky for 400 orbits, and the 27 00:01:53.363 --> 00:01:58.952 really unique thing about that was they picked a spot where there wasn't anything. 
They looked 28 00:01:58.952 --> 00:02:03.873 and they said there's absolutely nothing here. And they said, you want to spend 400 Hubble 29 00:02:03.873 --> 00:02:08.711 orbits looking at nothing? And they said yes, because we want to see what it can see. And 30 00:02:08.711 --> 00:02:18.721 I think the results from the science, I mean it was amazing. What they saw was spectacular. 31 00:02:18.721 --> 00:02:21.516 -Hubble had spent two weeks taking pictures of empty places 32 00:02:21.516 --> 00:02:23.893 in the sky. And they saw they weren't empty at all 33 00:02:23.893 --> 00:02:26.271 there were thousands and thousands of galaxies. 34 00:02:28.606 --> 00:02:31.818 -We were amazed how many galaxies we found, 35 00:02:31.818 --> 00:02:33.653 and we continued to go back to that 36 00:02:33.653 --> 00:02:40.410 portion of the sky to increase that visibility. 37 00:02:40.410 --> 00:02:47.292 -The Hubble Space Telescope is an outstanding time machine. 38 00:02:47.292 --> 00:02:51.713 It's incredibly important for our studies with the Hubble Space Telescope to realize 39 00:02:51.713 --> 00:02:58.261 that when we're looking at a galaxy, we're seeing it as it was millions of years ago, 40 00:02:58.261 --> 00:03:03.266 sometimes billions of years ago. It's taken that long for the light to get to us. 41 00:03:03.266 --> 00:03:05.268 -What you're looking at with a telescope, of course, 42 00:03:05.268 --> 00:03:08.229 is the light from billions of light-years away. 43 00:03:08.229 --> 00:03:10.815 So the further you look, the more you're going back towards 44 00:03:10.815 --> 00:03:16.196 the Big Bang and understanding how the universe was formed. 45 00:03:16.196 --> 00:03:21.659 -What Hubble has revealed is that the universe has in fact changed over these billions of 46 00:03:21.659 --> 00:03:28.791 years of time. The early galaxies, the very distant ones as we see them, are simple. Sometimes 47 00:03:28.791 --> 00:03:33.504 they're messy looking, they're small. They haven't had time yet to form that grand 48 00:03:33.504 --> 00:03:38.218 spiral structure. Over time, we see galaxies actually merging 49 00:03:38.218 --> 00:03:43.181 with other galaxies and growing bigger and bigger, and those mergers can look like train 50 00:03:43.181 --> 00:03:44.807 wrecks in our Hubble images. 51 00:03:51.731 --> 00:03:55.485 -These very, very deep exposures that Hubble has been able to take, we have seen right to 52 00:03:55.485 --> 00:04:01.199 the edge of the universe, thirteen and a half billion years. 53 00:04:01.199 --> 00:04:04.702 When Hubble was first designed and envisioned, it was never thought it could actually see 54 00:04:04.702 --> 00:04:09.415 that far out. But because of the advances in the instruments that we've been able to 55 00:04:09.415 --> 00:04:13.711 put up on the telescope, and also the cleverness of the scientists, they've come up with very 56 00:04:13.711 --> 00:04:18.007 interesting observing scenarios, doing these really deep exposures, where we just sit there 57 00:04:18.007 --> 00:04:22.929 for orbit after orbit, after orbit gathering the photons, we’ve been able to push Hubble 58 00:04:22.929 --> 00:04:26.933 out very, very far. 59 00:04:26.933 --> 00:04:33.606 -As Hubble looks out into these fields of galaxies, we sometimes see clusters of galaxies. These 60 00:04:33.606 --> 00:04:38.945 are galaxies that are held nearby each other by their mutual gravity. 61 00:04:38.945 --> 00:04:46.869 These clusters are massive conglomerations. 
There's so much mass that they have an actual 62 00:04:46.869 --> 00:04:54.919 observable impact on space-time itself. Einstein predicted that mass distorts space, 63 00:04:54.919 --> 00:04:59.215 but we didn't realize we could actually see the effects of that. But with Hubble, we have 64 00:04:59.215 --> 00:05:05.513 been able to see distortions in space around clusters of galaxies. The way we see that 65 00:05:05.513 --> 00:05:11.769 is when light from a background galaxy travels through that cluster of galaxies, or around 66 00:05:11.769 --> 00:05:17.400 it, due to this gravitational lensing effect. The lensing also magnifies that background 67 00:05:17.400 --> 00:05:22.947 galaxy, so if we look in some of these distorted arcs, we can see more detail than we would 68 00:05:22.947 --> 00:05:28.453 ever have been able to see without gravitational lensing, nature’s boost. 69 00:05:32.290 --> 00:05:36.377 -There are observations where we're explicitly looking for the lensing and 70 00:05:36.377 --> 00:05:39.547 we're getting science out of that just otherwise 71 00:05:39.547 --> 00:05:43.718 would just not be doable. Hubble has really taken that 72 00:05:43.718 --> 00:05:47.096 to a next level. It's doing large amounts 73 00:05:47.096 --> 00:05:50.433 of astrophysics that it's just never been able to do before. 74 00:05:55.313 --> 00:06:01.069 -Some of what we're doing with Hubble is to prepare for the new James Webb telescope, 75 00:06:01.069 --> 00:06:07.075 which we anticipate launching in 2021, which will be able to see farther into the infrared 76 00:06:07.075 --> 00:06:14.207 part of the electromagnetic spectrum. That enables us to see some galaxies that are difficult 77 00:06:14.207 --> 00:06:20.797 for Hubble to see because they're so far away that their light is traveling through us through 78 00:06:20.797 --> 00:06:26.803 expanding space and gets stretched out into redder wavelengths, often far into the infrared 79 00:06:26.803 --> 00:06:32.308 part of the spectrum. Even sometimes beyond what Hubble is able to detect well. The Webb 80 00:06:32.308 --> 00:06:38.064 telescope will give us more information about some of those very distant galaxies. 81 00:06:38.064 --> 00:06:43.444 -The James Webb Space Telescope is the follow-on telescope after the great Hubble telescope. 82 00:06:43.444 --> 00:06:47.532 It extends the discoveries of Hubble into the infrared spectrum region. 83 00:06:49.450 --> 00:06:53.371 We think that the first objects that grew out of the Big Bang material probably happened 84 00:06:53.371 --> 00:06:58.960 in about a hundred million years after the start. And we think the Webb telescope can 85 00:06:58.960 --> 00:07:04.173 pick them up. They're rare, they're hard to find, but they should be there. 86 00:07:04.173 --> 00:07:08.678 The farthest we've been able to see with the Hubble telescope goes back about 600 - 800 87 00:07:08.678 --> 00:07:13.891 million years after the expansion began, so we think we get much, much closer to the 88 00:07:13.891 --> 00:07:18.187 first objects with the Webb telescope. 89 00:07:18.187 --> 00:07:20.982 -Hubble gives information that the Webb telescope cannot give about 90 00:07:20.982 --> 00:07:24.861 visible and ultraviolet emission from things in the universe, 91 00:07:24.861 --> 00:07:26.821 and when we have all of that information coming in 92 00:07:26.821 --> 00:07:32.577 at the same time, it's like a banquet of scientific return. 
93 00:07:32.577 --> 00:07:37.165 -Now when we get the complete picture of every wavelength you can possibly see from ultraviolet 94 00:07:37.165 --> 00:07:40.168 to infrared, we hope to have the story of the growth 95 00:07:40.168 --> 00:07:42.795 of the first galaxies from the primordial material. 96 00:07:42.795 --> 00:07:46.632 So that will be a huge accomplishment that depends on both pieces 97 00:07:46.632 --> 00:07:51.971 of equipment, the Hubble telescope and the James Webb telescope working together. 98 00:07:51.971 --> 00:07:56.058 -So, astronomers are very excited about this probability that we'll have both the Hubble 99 00:07:56.058 --> 00:08:01.022 Space Telescope and the Webb telescope operating at the same time for quite a few years. That 100 00:08:01.022 --> 00:08:04.901 will give us an abundance of new understanding about the universe. 101 00:08:04.901 --> 00:08:11.449 And already right now with Hubble, we're doing preparatory observations for the Webb telescope. 102 00:08:11.449 --> 00:08:16.329 We're using Hubble to do things, for example, like surveying distant galaxies to find out 103 00:08:16.329 --> 00:08:20.958 which ones would be prime targets for the Webb telescope. 104 00:08:20.958 --> 00:08:27.590 In fact, scientists around the world are proposing observations with Hubble right now specifically 105 00:08:27.590 --> 00:08:32.553 to help us learn information that will be useful for making the best use of the Webb 106 00:08:32.553 --> 00:08:39.435 telescope as soon as it's launched and gets going in its science observations. 107 00:08:39.435 --> 00:08:43.231 -I think the Hubble telescope has been the most productive science instrument ever built. 108 00:08:43.272 --> 00:08:47.193 In astronomy, there's what we knew before Hubble, and now, there's what we know after 109 00:08:47.193 --> 00:08:48.486 Hubble. They're so different. 110 00:08:48.486 --> 00:08:52.698 Of course, Hubble has now had a life of 30 years, so it's had a long time to make this 111 00:08:52.698 --> 00:08:57.912 revolution happen. So it's not all at once. It's a gradual revolution, but it's still 112 00:08:57.912 --> 00:09:00.915 a huge revolution. 113 00:09:00.915 --> 00:09:05.461 Knowledge has changed dramatically over the 30 years of life of the Hubble telescope. 114 00:09:05.461 --> 00:09:09.549 so you couldn't even have imagined when the Hubble was launched that we would have the 115 00:09:09.549 --> 00:09:13.135 wonderful cameras and spectrometers that we fly today. 116 00:09:13.135 --> 00:09:17.306 We figured out how to send astronauts, we trained the astronauts, we figured out what 117 00:09:17.306 --> 00:09:20.893 instruments could be put in. We figured out how to repair everything that went wrong on 118 00:09:20.893 --> 00:09:25.189 the Hubble, and it's still alive today, 30 years after launch. 119 00:09:25.189 --> 00:09:30.027 I am so thrilled to say that our people were able to do that. 120 00:09:30.027 --> 00:09:34.115 That's the operations team that makes this possible. It's a miracle as far as I'm concerned, 121 00:09:34.115 --> 00:09:36.784 because it didn't have to be that way, but they made it happen. 
122 00:09:36.784 --> 00:09:44.208 Hubble Eye In The Sky 123 00:09:44.208 --> 00:09:51.173 [ INTENSE MUSIC ]" +# ] +# contents = [] + +# print("length of prompts: ", len(prompts[0])) +# l = 512 +# splited_prompts = [] +# sum = 0 +# while sum < len(prompts[0]): +# splited_prompts.append(prompts[0][sum:sum+l]) +# sum += l + +# prompts = splited_prompts + +new_prompts = [] + +for i in range(4): + # content = {"prompt": prompts[i%len(prompts)], "stream": False, "max_tokens": 4096, "best_of": 2, "use_beam_search": True, "temperature": 0} + # content = {"prompt": prompts[i%len(prompts)], "stream": False, "max_tokens": 4096} + # contents.append(content) + new_prompts.append(prompts[i % len(prompts)]) +###################################### +# new_prompts.clear() +# new_prompts.append(prompts[2]) + + +# Create an LLM. +# llm = LLM(model="facebook/opt-125m", device="cpu", quantization="awq") +# llm = LLM(model="TheBloke/Llama-2-7b-Chat-AWQ", device="cpu", quantization="AWQ") +# llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", device="cpu", quantization="ns") +# TODO verify block_size >= max_model_len +# TODO set VLLM_CPU_KVCACHE_SPACE to X (GB) so that VLLM_CPU_KVCACHE_SPACE/(block_size*element_size) = num_cpu_blocks <= max_num_seqs. Otherwise, native kv cache may run out of slots. +ctx_size = 4096 +llm = LLM( + # model="meta-llama/Meta-Llama-3-8B-Instruct", + model="meta-llama/Llama-2-7b-chat-hf", + device="cpu", + max_num_seqs=8, + block_size=ctx_size, + max_model_len=ctx_size, + quantization="ns", +) +# llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", device="cpu") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +for i in range(1): + t0 = perf_counter() + outputs = llm.generate(new_prompts, sampling_params) + total_time = perf_counter() - t0 + # Print the outputs. 
+ total_prompts = 0 + total_generated = 0 + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + total_prompts += len(prompt.split(" ")) + total_generated += len(generated_text.split(" ")) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + print(total_time, total_prompts, total_generated) + print(f"prompts per second: {total_prompts/total_time}") + print(f"tokens per second: {total_generated/total_time}") diff --git a/llm_on_ray/inference/api_server_openai.py b/llm_on_ray/inference/api_server_openai.py index 6b5a0e2db..946c99300 100644 --- a/llm_on_ray/inference/api_server_openai.py +++ b/llm_on_ray/inference/api_server_openai.py @@ -62,6 +62,7 @@ def router_application(deployments, model_list, max_ongoing_requests): RouterDeployment = serve.deployment( route_prefix="/", + ray_actor_options={"resources": {"app_router": 1}}, max_ongoing_requests=total_num_replica * ( (max_ongoing_requests if max_ongoing_requests else max_num_concurrent_query) + 1 diff --git a/llm_on_ray/inference/inference_config.py b/llm_on_ray/inference/inference_config.py index 7d405c7c7..8a6d53a97 100644 --- a/llm_on_ray/inference/inference_config.py +++ b/llm_on_ray/inference/inference_config.py @@ -57,8 +57,10 @@ def _check_precision(cls, v: str): class Vllm(BaseModel): enabled: bool = False max_num_seqs: int = 256 + max_batched_tokens: int = 4096 precision: str = "bf16" enforce_eager: bool = False + extension: str = None @validator("precision") def _check_precision(cls, v: str): @@ -66,6 +68,24 @@ def _check_precision(cls, v: str): assert v in [PRECISION_BF16, PRECISION_FP32] return v + @validator("extension") + def _check_extension(cls, v: str): + if v: + assert v in ["ns"] + return v + + @validator("max_num_seqs") + def _check_max_num_seqs(cls, v: int): + if v: + assert v > 0 + return v + + @validator("max_batched_tokens") + def _check_max_batched_tokens(cls, v: int): + if v: + assert v > 32 + return v + # for IPEX-LLM model class IpexllmModelConfig(BaseModel): diff --git a/llm_on_ray/inference/models/llama-3-8b-instruct.yaml b/llm_on_ray/inference/models/llama-3-8b-instruct.yaml new file mode 100644 index 000000000..721b19dfb --- /dev/null +++ b/llm_on_ray/inference/models/llama-3-8b-instruct.yaml @@ -0,0 +1,26 @@ +port: 8000 +name: llama-3-8b-instruct +route_prefix: /llama-3-8b-instruct +num_replicas: 4 +cpus_per_worker: 1 +gpus_per_worker: 0 +deepspeed: false +workers_per_group: 2 +device: cpu +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Meta-Llama-3-8B-Instruct + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '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' + diff --git a/llm_on_ray/inference/models/template/inference_config_template.yaml b/llm_on_ray/inference/models/template/inference_config_template.yaml new file mode 100644 index 000000000..8b90579f1 --- /dev/null +++ b/llm_on_ray/inference/models/template/inference_config_template.yaml @@ -0,0 +1,46 @@ +host: 0.0.0.0 +port: 8000 +name: default +route_prefix: null +num_replicas: 1 +cpus_per_worker: 24 +gpus_per_worker: 0 +hpus_per_worker: 0 +deepspeed: false +vllm: + enabled: false + precision: bf16 +workers_per_group: 2 +device: cpu +ipex: + enabled: true + precision: bf16 +hpu_model_config: + use_hpu_graphs: true + torch_compile: false + quant_config: null +model_description: + model_id_or_path: 
null + tokenizer_name_or_path: null + config: + trust_remote_code: false + use_auth_token: null + load_in_4bit: false + torch_dtype: null + revision: null + prompt: + intro: '' + human_id: '' + bot_id: '' + stop_words: [] + chat_processor: null + gpt_base_model: false + quantized_model_id_or_path: null + quantization_type: null + peft_model_id_or_path: null + peft_type: null + ipexllm: false + ipexllm_config: + load_in_low_bit: '' + input_processor: AutoProcessor + model_loader: AutoModel diff --git a/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-ns.yaml b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-ns.yaml new file mode 100644 index 000000000..319d8e397 --- /dev/null +++ b/llm_on_ray/inference/models/vllm/llama-2-7b-chat-hf-vllm-ns.yaml @@ -0,0 +1,29 @@ +port: 8000 +name: llama-2-7b-chat-hf +route_prefix: /llama-2-7b-chat-hf +num_replicas: 1 +cpus_per_worker: 1 +gpus_per_worker: 0 +deepspeed: false +vllm: + enabled: true + precision: bf16 + extension: ns +workers_per_group: 2 +device: cpu +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Llama-2-7b-chat-hf + tokenizer_name_or_path: meta-llama/Llama-2-7b-chat-hf + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '' diff --git a/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm-ns.yaml b/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm-ns.yaml new file mode 100644 index 000000000..2721c9a76 --- /dev/null +++ b/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm-ns.yaml @@ -0,0 +1,30 @@ +port: 8000 +name: llama-3-8b-instruct +route_prefix: /llama-3-8b-instruct +num_replicas: 1 +cpus_per_worker: 27 +gpus_per_worker: 0 +deepspeed: false +vllm: + enabled: true + precision: bf16 + extension: ns + max_batched_tokens: 8192 +workers_per_group: 2 +device: cpu +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Meta-Llama-3-8B-Instruct + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '' diff --git a/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm.yaml b/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm.yaml new file mode 100644 index 000000000..ac26a1d7c --- /dev/null +++ b/llm_on_ray/inference/models/vllm/llama-3-8b-instruct-vllm.yaml @@ -0,0 +1,28 @@ +port: 8000 +name: llama-3-8b-instruct +route_prefix: /llama-3-8b-instruct +num_replicas: 1 +cpus_per_worker: 27 +gpus_per_worker: 0 +deepspeed: false +vllm: + enabled: true + precision: bf16 +workers_per_group: 2 +device: cpu +ipex: + enabled: false + precision: bf16 +model_description: + model_id_or_path: meta-llama/Meta-Llama-3-8B-Instruct + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct + chat_processor: ChatModelLLama + prompt: + intro: '' + human_id: '[INST] {msg} [/INST] + + ' + bot_id: '' + stop_words: [] + config: + use_auth_token: '' diff --git a/llm_on_ray/inference/predictors/vllm_predictor.py b/llm_on_ray/inference/predictors/vllm_predictor.py index d3d09414a..90d7653f9 100644 --- a/llm_on_ray/inference/predictors/vllm_predictor.py +++ b/llm_on_ray/inference/predictors/vllm_predictor.py @@ -17,6 +17,8 @@ import asyncio import os from typing import AsyncGenerator, List, Union +from transformers import AutoConfig +import ray from vllm.engine.arg_utils import 
AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams @@ -27,6 +29,10 @@ ModelGenerateResult, PRECISION_BF16, ) +from llm_on_ray.inference import utils +import logging + +logger = logging.getLogger(__name__) class VllmPredictor(Predictor): @@ -43,15 +49,44 @@ def __init__(self, infer_conf: InferenceConfig, max_num_seqs): # The default value is 40GB. os.environ["VLLM_CPU_KVCACHE_SPACE"] = str(self.VLLM_CPU_KVCACHE_SPACE_DEFAULT) - args = AsyncEngineArgs( - model=model_desc.model_id_or_path, - trust_remote_code=model_config.trust_remote_code, - device=infer_conf.device, - dtype=dtype, - disable_log_requests=True, - max_num_seqs=max_num_seqs, - enforce_eager=infer_conf.vllm.enforce_eager, - ) + if infer_conf.vllm.extension == "ns": + logger.warn("applying neural speed extension to vllm ...") + try: + from vllm.extension import ns + + logger.warn("neural speed extension applied to vllm successfully!") + except Exception as e: + logger.error(f"failed to apply neural speed extension to vllm: {e}") + raise e + # get context size from HF + hf_config = AutoConfig.from_pretrained( + model_desc.model_id_or_path, + trust_remote_code=True, + use_auth_token=model_config.use_auth_token, + ) + ctx_size = utils.get_max_seq_length(hf_config) + args = AsyncEngineArgs( + model=model_desc.model_id_or_path, + trust_remote_code=model_config.trust_remote_code, + device=infer_conf.device, + dtype=dtype, + disable_log_requests=True, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=infer_conf.vllm.max_batched_tokens, + quantization="ns", + block_size=ctx_size, + max_model_len=ctx_size, + ) + else: + args = AsyncEngineArgs( + model=model_desc.model_id_or_path, + trust_remote_code=model_config.trust_remote_code, + device=infer_conf.device, + dtype=dtype, + disable_log_requests=True, + max_num_seqs=max_num_seqs, + enforce_eager=infer_conf.vllm.enforce_eager, + ) self.engine = AsyncLLMEngine.from_engine_args(args) diff --git a/llm_on_ray/inference/utils.py b/llm_on_ray/inference/utils.py index 6712d7bfb..2aa49aebd 100644 --- a/llm_on_ray/inference/utils.py +++ b/llm_on_ray/inference/utils.py @@ -15,7 +15,7 @@ # import os import pathlib -from transformers import StoppingCriteria, TextStreamer +from transformers import StoppingCriteria, TextStreamer, AutoConfig from ray.util.queue import Queue import torch from typing import Dict, Any, List, Optional, Union @@ -41,14 +41,17 @@ def get_deployment_actor_options(infer_conf: InferenceConfig): metadata_thp:auto,dirty_decay_ms:9000000000,muzzy_decay_ms:9000000000", } runtime_env: Dict[str, Any] = {_ray_env_key: {}} + ray_actor_options: Dict[str, Any] = {"runtime_env": runtime_env} + if infer_conf.ipex.enabled: runtime_env[_ray_env_key].update(_predictor_runtime_env_ipex) if infer_conf.deepspeed: runtime_env[_ray_env_key]["DS_ACCELERATOR"] = infer_conf.device if infer_conf.vllm.enabled: runtime_env[_ray_env_key]["OMP_PROC_BIND"] = "true" + if infer_conf.vllm.extension == "ns": + ray_actor_options["resources"] = {"inference_engine": 1} # now PredictorDeployment itself is a worker, we should require resources for it - ray_actor_options: Dict[str, Any] = {"runtime_env": runtime_env} if infer_conf.device == "cpu": ray_actor_options["num_cpus"] = infer_conf.cpus_per_worker elif infer_conf.device == "cuda": @@ -62,6 +65,34 @@ def get_deployment_actor_options(infer_conf: InferenceConfig): return ray_actor_options +def get_max_seq_length(config: AutoConfig): + config = config.to_dict() + # chatglm2, bloom, chatglm3 + 
if "seq_length" in config: + return config["seq_length"] + # qwen2, llama-2, llama, dolly, gptneox, qwen, qwen1.5, opt, phi + if "max_position_embeddings" in config: + return config["max_position_embeddings"] + # baichuan, baichuan2 + if "model_max_length" in config: + return config["model_max_length"] + # gptj + if "n_positions" in config: + return config["n_positions"] + # mpt + if "max_seq_len" in config: + return config["max_seq_len"] + # chatglm + if "max_sequence_length" in config: + return config["max_sequence_length"] + # whisper + if "max_length" in config: + return config["max_length"] + + print("Not found max seq length, setting to default 512") + return 512 + + class RayTextIteratorStreamer(TextStreamer): def __init__( self, diff --git a/pyproject.toml b/pyproject.toml index 5a8e89306..c7066b5f5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,13 @@ cpu = [ "oneccl_bind_pt==2.2.0" ] +vllm-cpu = [ + "pydantic==2.7.1", + "pydantic_core==2.18.2", + "transformers", + "torch" +] + gpu = [ "transformers>=4.35.0", "torch==2.1.0a0", diff --git a/vllm-ext/MANIFEST.in b/vllm-ext/MANIFEST.in new file mode 100644 index 000000000..740514b21 --- /dev/null +++ b/vllm-ext/MANIFEST.in @@ -0,0 +1,6 @@ +# graft vllm +# exclude all pyc and pyo +prune */__pycache__ +global-exclude *.pyc *.pyo *.pyd *.swp *.bak +# inlucde python files under vllm +recursive-include vllm * diff --git a/vllm-ext/README.md b/vllm-ext/README.md new file mode 100644 index 000000000..90617cfb8 --- /dev/null +++ b/vllm-ext/README.md @@ -0,0 +1,21 @@ +# Inference-Engine + +To further speed up quantized model inference on Intel CPU, we extend vLLM to run the model decoding in own own inference engine, which is based on [https://github.com/intel/neural-speed](neural-speed). +Neural Speed is an innovative library designed to support the efficient inference of large language models (LLMs) on Intel platforms through the state-of-the-art (SOTA) low-bit quantization powered by +[https://github.com/intel/neural-compressor](Intel Neural Compressor). The work is inspired by [https://github.com/ggerganov/llama.cpp](llama.cpp) and further optimized for Intel platforms with our +innovations in [https://arxiv.org/abs/2311.00502](NeurIPS' 2023). + +You need to first install llm-on-ray with "vllm-cpu" extra. + +```bash +pip install .[vllm-cpu] --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/ +``` + +Then, install the vLLM extension and the inference engine. +```bash +cd vllm-ext +pip install . + +``` + +You can check [../docs/vllm.md](vLLM Doc) for more details. diff --git a/vllm-ext/inference_engine/cpp/CMakeLists.txt b/vllm-ext/inference_engine/cpp/CMakeLists.txt new file mode 100644 index 000000000..4ec9857d3 --- /dev/null +++ b/vllm-ext/inference_engine/cpp/CMakeLists.txt @@ -0,0 +1,134 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason +project("inference_engine" C CXX) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +# +# Option list +# + +# general +option(IE_STATIC "inference_engine: static link libraries" OFF) +option(IE_NATIVE "inference_engine: enable -march=native flag" OFF) +option(IE_LTO "inference_engine: enable link time optimization" OFF) + +# debug +option(IE_ALL_WARNINGS "inference_engine: enable all compiler warnings" ON) +option(IE_ALL_WARNINGS_3RD_PARTY "inference_engine: enable all compiler warnings in 3rd party libs" OFF) +option(IE_GPROF "inference_engine: enable gprof" OFF) + +# sanitizers +option(IE_SANITIZE_THREAD "inference_engine: enable thread sanitizer" OFF) +option(IE_SANITIZE_ADDRESS "inference_engine: enable address sanitizer" OFF) +option(IE_SANITIZE_UNDEFINED "inference_engine: enable undefined sanitizer" OFF) + +# instruction set specific +option(IE_AVX "inference_engine: enable AVX" ON) +option(IE_AVX2 "inference_engine: enable AVX2" ON) +option(IE_F16C "inference_engine: enable F16C" ON) +option(IE_AVX512 "inference_engine: enable AVX512" ON) +option(IE_AVX512_VBMI "inference_engine: enable AVX512-VBMI" ON) +option(IE_AVX512_VNNI "inference_engine: enable AVX512-VNNI" ON) +option(IE_FMA "inference_engine: enable FMA" ON) +option(IE_AMX "inference_engine: enable AMX" ON) +option(IE_USE_OPENMP "inference_engine: use OpenMP thread pool." ON) + +option(IE_BUILD_TESTS "inference_engine: build tests" ${IE_STANDALONE}) +option(IE_BUILD_EXAMPLES "inference_engine: build examples" ${IE_STANDALONE}) +option(IE_USE_CLANG_TIDY "inference_engine: clang-tidy check" OFF) + + +if(IE_BUILD_TESTS) + add_compile_definitions(IE_BUILD_TESTS) +endif() +option(IE_PROFILING "inference_engine: use Profiling" OFF) +if (IE_PROFILING) + add_compile_definitions(IE_PERF) +endif() +option(IE_BEAM_SEARCH_VERBOSE "inference_engine: print beam search processing log" OFF) +option(IE_GELU_VEC "inference_engine: enable vec in gelu" ON) +if (IE_GELU_VEC) + add_compile_definitions(IE_GELU_USE_VEC) +endif() +option(IE_SIMD_VEC_DOT_F16 "inference_engine: enable vec_dot_fp16 SIMD optimization" ON) +option(BUILD_SHARED_LIBS "If build as shared libs" ON) + +if (IE_SIMD_VEC_DOT_F16) + add_compile_definitions(IE_SIMD_VEC_DOT_F16) +endif() + +if(IE_BUILD_TESTS) + enable_testing() +endif() + +set(CMAKE_INTERPROCEDURAL_OPTIMIZATION OFF) # default to false so that pybind11 will not try to use IPO +if (IE_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +if (NOT MSVC) + if (IE_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (IE_GPROF) + add_compile_options(-pg) + endif() + if (IE_NATIVE) + add_compile_options(-march=native) + endif() +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) + +if(IE_USE_OPENMP) + include(FindOpenMP) + # compile BesTLA's OMPTheading class, then it can be used in ne_layers + set(BTLA_ENABLE_OPENMP ON CACHE BOOL 
"BesTLA enable compiling OpenMP threading") + add_compile_definitions(IE_USE_OPENMP) +endif() + +add_subdirectory(bestla) + +include(cmake/ISA.cmake) +include(cmake/Common.cmake) +include(cmake/ClangTidy.cmake) + +set(COMMON_HEADER_DIRS ./) +include_directories(${COMMON_HEADER_DIRS}) + + +add_subdirectory(vectors) + +add_subdirectory(core) + +add_subdirectory(models) diff --git a/vllm-ext/inference_engine/cpp/bestla/CMakeLists.txt b/vllm-ext/inference_engine/cpp/bestla/CMakeLists.txt new file mode 100644 index 000000000..e9e35b2f3 --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/CMakeLists.txt @@ -0,0 +1,118 @@ +cmake_minimum_required(VERSION 3.12) + +project(bestla LANGUAGES CXX VERSION 0.1.0) +file(GLOB headers ${PROJECT_NAME}/*.h ${PROJECT_NAME}/*.hpp) +file(GLOB xbyak_headers ${PROJECT_NAME}/xbyak/*.h ${PROJECT_NAME}/xbyak/*.hpp) + +option(BTLA_ENABLE_OPENMP "Compile OpenMP thread pool if OMP can be found" OFF) + +option(BTLA_UT_ALL "Enable all unit tests" OFF) +option(BTLA_UT_DEBUG "Enable debug unit tests" OFF) +option(BTLA_UT_EPILOGUE "Enable unit test for epilogue" OFF) +option(BTLA_UT_PROLOGUE_A "Enable unit test for activation prologue" OFF) +option(BTLA_UT_PROLOGUE_B "Enable unit test for weight prologue" OFF) +option(BTLA_UT_GEMM "Enable unit test for micro gemm kernels" OFF) +option(BTLA_UT_WRAPPER "Enable unit test for parallel gemms" OFF) +option(BTLA_UT_PARALLEL "Enable unit test for parallel set" OFF) +option(BTLA_UT_KERNEL_JIT "Enable unit test for jit kernels" OFF) +option(BTLA_UT_KERNEL_INTRIN "Enable unit test for intrinsic kernels" OFF) +option(BTLA_UT_KERNEL_WRAPPER "Enable unit test for runtime ISA kernels" OFF) +option(BTLA_UT_NOASAN "Disable sanitize" OFF) +option(BTLA_UT_BENCHMARK "Benchmark ON may take a long time to finish all tests" OFF) +option(BTLA_UT_OPENMP "Use OpenMP for UT tests" OFF) + +add_library(${PROJECT_NAME} INTERFACE) +add_library(inference_engine::${PROJECT_NAME} ALIAS ${PROJECT_NAME}) +target_include_directories( + ${PROJECT_NAME} INTERFACE + "$" + "$" +) + + +if(BTLA_ENABLE_OPENMP) + message(STATUS "BesTLA enable OpenMP ThreadPool") + target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_ENABLE_OPENMP) +endif(BTLA_ENABLE_OPENMP) + +if(WIN32) + target_compile_definitions(${PROJECT_NAME} INTERFACE _CRT_SECURE_NO_WARNINGS NOMINMAX) +endif(WIN32) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +target_compile_features(${PROJECT_NAME} INTERFACE cxx_std_17) + +if(BTLA_UT_ALL) + set(BTLA_UT_EPILOGUE ON) + set(BTLA_UT_PROLOGUE_A ON) + set(BTLA_UT_PROLOGUE_B ON) + set(BTLA_UT_GEMM ON) + set(BTLA_UT_WRAPPER ON) + set(BTLA_UT_PARALLEL ON) + set(BTLA_UT_KERNEL_JIT ON) + set(BTLA_UT_KERNEL_INTRIN ON) + set(BTLA_UT_KERNEL_WRAPPER ON) +endif(BTLA_UT_ALL) + +set(UT_BUILD FALSE) +if(BTLA_UT_DEBUG OR BTLA_UT_PROLOGUE_A OR BTLA_UT_PROLOGUE_B OR BTLA_UT_EPILOGUE OR BTLA_UT_GEMM +OR BTLA_UT_WRAPPER OR BTLA_UT_PARALLEL OR BTLA_UT_KERNEL_JIT OR BTLA_UT_KERNEL_INTRIN +OR BTLA_UT_KERNEL_WRAPPER) + set(UT_BUILD TRUE) +endif() + +function(add_ut_flag UT_OPTION) + if(${${UT_OPTION}}) + # target_compile_definitions(${PROJECT_NAME}_ut PRIVATE ${UT_OPTION}) + add_compile_definitions(${UT_OPTION}) + endif() +endfunction() + +if(UT_BUILD) + file(GLOB srcs ${PROJECT_NAME}/ut/*.cc ${PROJECT_NAME}/ut/*.cpp) #compile everything even run parts of UTs + list(REMOVE_ITEM srcs ${CMAKE_CURRENT_SOURCE_DIR}/${PROJECT_NAME}/ut/bestla_benchmark.cpp) + file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) + include_directories(${PROJECT_NAME}) + 
add_executable(${PROJECT_NAME}_ut ${srcs} ${headers} ${ut_headers}) + if(BTLA_UT_OPENMP) + include(FindOpenMP) + target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) + target_link_libraries(${PROJECT_NAME}_ut PRIVATE OpenMP::OpenMP_CXX) + endif() + if(NOT WIN32) + if(NOT BTLA_UT_NOASAN) + target_compile_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address) + target_link_options(${PROJECT_NAME}_ut PRIVATE -fsanitize=address) + endif() + target_link_options(${PROJECT_NAME}_ut PRIVATE -lpthread) + endif() + + add_ut_flag(BTLA_UT_DEBUG) + add_ut_flag(BTLA_UT_EPILOGUE) + add_ut_flag(BTLA_UT_PROLOGUE_A) + add_ut_flag(BTLA_UT_PROLOGUE_B) + add_ut_flag(BTLA_UT_GEMM) + add_ut_flag(BTLA_UT_PARALLEL) + add_ut_flag(BTLA_UT_WRAPPER) + add_ut_flag(BTLA_UT_KERNEL_INTRIN) + add_ut_flag(BTLA_UT_KERNEL_JIT) + add_ut_flag(BTLA_UT_KERNEL_WRAPPER) + target_link_libraries(${PROJECT_NAME}_ut PRIVATE ${PROJECT_NAME}) +endif(UT_BUILD) + +if(BTLA_UT_BENCHMARK) + file(GLOB srcs ${PROJECT_NAME}/ut/bestla_benchmark.cpp) #compile everything even run parts of UTs + file(GLOB ut_headers ${PROJECT_NAME}/ut/*.h) + include_directories(${PROJECT_NAME}) + add_executable(${PROJECT_NAME}_benchmark ${srcs} ${headers} ${ut_headers}) + if(BTLA_UT_OPENMP) + include(FindOpenMP) + target_compile_definitions(${PROJECT_NAME} INTERFACE BTLA_USE_OPENMP) + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE OpenMP::OpenMP_CXX) + endif() + if(NOT WIN32) + target_link_options(${PROJECT_NAME}_benchmark PRIVATE -lpthread) + endif() + target_link_libraries(${PROJECT_NAME}_benchmark PRIVATE ${PROJECT_NAME}) +endif(BTLA_UT_BENCHMARK) diff --git a/vllm-ext/inference_engine/cpp/bestla/README.md b/vllm-ext/inference_engine/cpp/bestla/README.md new file mode 100644 index 000000000..8b46f5a9b --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/README.md @@ -0,0 +1,57 @@ +# BesTLA +BesTLA is a lightweight, header-only acceleration library for high-performance GEMM and related computations on Intel platform. Inspired by Cutlass, it provides high-level template class abstractions for various elements required for computation, and allows flexible kernel construction through template combinations to meet specific needs, maximizing the reuse of existing template classes. Users can also develop custom template classes to expand BesTLA’s computational capabilities. BesTLA includes several different types of template classes, specifically: + +- `Launcher`: Schedules computation-related template classes, allowing users to specify their own computation-related template classes, including GemmCore, Prologue, and Epilogue. +- `Parallel`: Specifies data splitting strategy for task distribution among different cores. BesTLA’s default Parallel template class adopts an L2-cache-fusion concept, i.e., each core tries to temporarily store the data it processes in its L2-cache during each round of gemm-tile computation. +- `GemmCore`: A computation-related template class that provides a micro-kernel for performing a tile gemm computation with a specific ISA. It is the most important template class in BesTLA. Currently, GemmCore supports the following ISAs: + - AVX2 + - AVX_VNNI + - AVX512F + - AVX512_VNNI + - AMX_BF16 + - AMX_INT8 + - AVX512_FP16 +- `Prologue`: A computation-related template class that preprocesses (such as data type conversion/padding) input data to meet GemmCore’s input data requirements. 
+- `Epilogue`: A computation-related template class that post-processes (such as eltwiseop-fusion) the results of gemm-core computations to expand BesTLA's application scenarios. +BesTLA allows users to configure a thread library for multi-core parallelism (e.g. OpenMP), which makes it easy to integrate BesTLA into their own projects. BesTLA also supports specifying the number of computing threads at runtime, making the allocation of computing resources more flexible. + +# Highlights +## Weight-only +BesTLA provides weight-only linear computational capabilities for LLM inference. We provide a series of Prologues that quantize/compress/serialize/deserialize fp32 weights in different ways. Specifically, the weight-only-quantization configs we support are given in the table below: + +| Weight dtype | Compute dtype | Scale dtype | algo | +| ---------------------- | :----------------: | :---------------: | :--------: | +| INT8 | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| INT4 (CLIP, FULLRANGE) | INT8 / BF16 / FP32 | BF16 / FP32 | sym / asym | +| FP8 (E4M3, E5M2) | BF16 / FP32 | FP32 / FP8 (E8M0) | sym | +| FP4 (E2M1) | BF16 / FP32 | BF16 / FP32 | sym | +| NF4 | BF16 / FP32 | BF16 / FP32 | sym | + +Config description of the table: +| Config | Description | +| ------------- | --------------------------------------------------- | +| Weight dtype | Data type of quantized weight | +| Compute dtype | Data type of BesTLA internal Gemm computation | +| Scale dtype | Data type of scales | +| algo | Quantization algorithm to use (symmetric/asymmetric) | + + +## Postop-fusion +BesTLA provides assembly-level postop-fusion through epilogue to minimize the overhead caused by data movement. Specifically, we support the following postop-fusions: + +- GELU +- SWISH +- RELU +- EXP +- TANH +## Compilation Requirements and Usage +Compile: + +- GCC version >=8.5.0 +- CMake version >=3.5 + +Usage: +```cmake +add_subdirectory(bestla) +target_link_libraries("${YOUR_PROJECT}" bestla::bestla) +``` diff --git a/vllm-ext/inference_engine/cpp/bestla/bestla/bestla.h b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla.h new file mode 100644 index 000000000..512d550fb --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla.h @@ -0,0 +1,92 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+#pragma once +#include +enum class BTLA_CODE { + Success = 0, + InvalidParam = 1, + InvalidISA = 2, + RuntimeError = 4, + NotSupport = 8, +}; +enum class BTLA_ISA : uint8_t { + NoSIMD = 0, + AVX, + AVX2, + AVX_VNNI, + AVX512F, + AVX512_VNNI, + AMX_BF16, + AMX_INT8, + AVX512_FP16, + AVX512_BF16, + ISA_COUNT, +}; +enum class BTLA_DTYPE : uint32_t { + EleBitsMask = 0xff, + EleBitsShift = 0, + EleBitsUndef = 0, + EleBits2 = 2, + EleBits3 = 3, + EleBits4 = 4, + EleBits8 = 8, + EleBits16 = 16, + EleBits32 = 32, + EleBits64 = 64, + TypeMask = 0xff00, + TypeShift = 8, + TypeFloat = 0 << TypeShift, + TypeInt = 1 << TypeShift, + SubTypeMask = 0xff0000, + SubTypeShift = 16, + SubType0 = 0 << SubTypeShift, + SubType1 = 1 << SubTypeShift, + SubType2 = 2 << SubTypeShift, + SubType3 = 3 << SubTypeShift, + SubType4 = 4 << SubTypeShift, + F64 = EleBits64 | TypeFloat, + F32 = EleBits32 | TypeFloat, + F16 = EleBits16 | TypeFloat, + BF16 = EleBits16 | TypeFloat | SubType1, + F8_E4M3 = EleBits8 | TypeFloat, + F8_E5M2 = EleBits8 | TypeFloat | SubType1, + F8_E3M4 = EleBits8 | TypeFloat | SubType2, + F8_E8M0 = EleBits8 | TypeFloat | SubType3, + DQ8_BNB = EleBits8 | TypeFloat | SubType4, + S8 = EleBits8 | TypeInt, + U8 = EleBits8 | TypeInt | SubType1, + S2_CLIP = EleBits2 | TypeInt, + S3_CLIP = EleBits3 | TypeInt, + S4_CLIP = EleBits4 | TypeInt, + F4_E2M1 = EleBits4 | TypeFloat, + F4_BNB = EleBits4 | TypeFloat | SubType1, + F4_NF4 = EleBits4 | TypeFloat | SubType2, + S32 = EleBits32 | TypeInt, + U32 = EleBits32 | TypeInt | SubType1, +}; + +enum class BTLA_ELTWISEOP { GELU, SWISH, TANH, EXP, LOW_PRECISION_EXP, RELU, LINEAR }; + +enum class BTLA_PROLOGUEB_IDS : uint32_t { + Undef = (uint32_t)-1, + Begin = 0, + NormalBegin = Begin, + WeightPack = NormalBegin, + NormalEnd, + KBlockBegin = NormalEnd, + WeightKBlockNInteger = KBlockBegin, + WeightKBlockNFloat, + KBlockEnd, + End, +}; diff --git a/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_device.h b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_device.h new file mode 100644 index 000000000..3cac2f397 --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_device.h @@ -0,0 +1,532 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
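As a reading aid for the `BTLA_DTYPE` encoding in `bestla.h` above: the low byte carries the element width in bits, the second byte the type class (float vs. int), and the third byte a sub-type tag. Below is a minimal standalone sketch of decoding those fields; it assumes `bestla.h` is on the include path, and the helper functions are illustrative only, not part of the BesTLA API.

```cpp
#include <cstdint>
#include <cstdio>
#include "bestla.h"  // the header introduced in this diff

// Extract the element width in bits (EleBitsMask = 0xff, EleBitsShift = 0).
static uint32_t ele_bits(BTLA_DTYPE dt) {
  return (static_cast<uint32_t>(dt) & static_cast<uint32_t>(BTLA_DTYPE::EleBitsMask)) >>
         static_cast<uint32_t>(BTLA_DTYPE::EleBitsShift);
}

// Check the type class in the second byte (TypeMask = 0xff00): TypeFloat vs. TypeInt.
static bool is_float(BTLA_DTYPE dt) {
  return (static_cast<uint32_t>(dt) & static_cast<uint32_t>(BTLA_DTYPE::TypeMask)) ==
         static_cast<uint32_t>(BTLA_DTYPE::TypeFloat);
}

int main() {
  // BF16 = EleBits16 | TypeFloat | SubType1 -> a 16-bit float variant.
  std::printf("BF16: %u bits, float=%d\n", ele_bits(BTLA_DTYPE::BF16), int(is_float(BTLA_DTYPE::BF16)));
  // S4_CLIP = EleBits4 | TypeInt -> a 4-bit integer type.
  std::printf("S4_CLIP: %u bits, float=%d\n", ele_bits(BTLA_DTYPE::S4_CLIP), int(is_float(BTLA_DTYPE::S4_CLIP)));
  return 0;
}
```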
+#pragma once +#include +#include +#include +#include "bestla.h" +#include "xbyak/xbyak_util.h" +#include "bestla_utils.h" +#ifdef _WIN32 +#include +#else +#include +#endif + +#define FIXED_CACHE_SIZE ((1 << 20) - (128 << 10)) +#define FIXED_CACHE 1 + +namespace bestla { + +namespace device { + +struct X64_ISA { + int64_t MMX : 1; // 0 + int64_t SSE : 1; // 1 + int64_t SSE2 : 1; // 2 + int64_t SSE3 : 1; // 3 + int64_t SSSE3 : 1; // 4 + int64_t SSE41 : 1; // 5 + int64_t SSE42 : 1; // 6 + int64_t AVX : 1; // 7 + int64_t F16C : 1; // 8 + int64_t FMA : 1; // 9 + int64_t AVX2 : 1; // 10 + int64_t AVX_VNNI : 1; // 11 + int64_t AVX_VNNI_INT8 : 1; // 12 + int64_t AVX_NE_CONVERT : 1; // 13 + int64_t AVX_IFMA : 1; // 14 + int64_t AVX512F : 1; // 15 + int64_t AVX512BW : 1; // 16 + int64_t AVX512CD : 1; // 17 + int64_t AVX512DQ : 1; // 18 + int64_t AVX512ER : 1; // 19 + int64_t AVX512IFMA52 : 1; // 20 + int64_t AVX512PF : 1; // 21 + int64_t AVX512VL : 1; // 22 + int64_t AVX512VPOPCNTDQ : 1; // 23 + int64_t AVX512_4FMAPS : 1; // 24 + int64_t AVX512_4VNNIW : 1; // 25 + int64_t AVX512_BF16 : 1; // 26 + int64_t AVX512_BITALG : 1; // 27 + int64_t AVX512_VBMI : 1; // 28 + int64_t AVX512_VBMI2 : 1; // 29 + int64_t AVX512_VNNI : 1; // 30 + int64_t AVX512_VP2INTERSECT : 1; // 31 + int64_t AVX512_FP16 : 1; // 32 + int64_t AMX_TILE : 1; // 33 + int64_t AMX_BF16 : 1; // 34 + int64_t AMX_INT8 : 1; // 35 + int64_t AMX_FP16 : 1; // 36 + int64_t AMX_COMPLEX : 1; // 37 + int64_t reserved : (64 - 38); +}; + +class AVX2_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 0; + static constexpr bool AVX512BW = 0; + static constexpr bool AVX512CD = 0; + static constexpr bool AVX512DQ = 0; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 0; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 0; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class AVX512_VNNI_Default { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static 
constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 0; + static constexpr bool AMX_BF16 = 0; + static constexpr bool AMX_INT8 = 0; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +class SapphireRapids { + public: + static constexpr bool MMX = 1; + static constexpr bool SSE = 1; + static constexpr bool SSE2 = 1; + static constexpr bool SSE3 = 1; + static constexpr bool SSSE3 = 1; + static constexpr bool SSE41 = 1; + static constexpr bool SSE42 = 1; + static constexpr bool AVX = 1; + static constexpr bool F16C = 1; + static constexpr bool FMA = 1; + static constexpr bool AVX2 = 1; + static constexpr bool AVX_VNNI = 0; + static constexpr bool AVX_VNNI_INT8 = 0; + static constexpr bool AVX_NE_CONVERT = 0; + static constexpr bool AVX_IFMA = 0; + static constexpr bool AVX512F = 1; + static constexpr bool AVX512BW = 1; + static constexpr bool AVX512CD = 1; + static constexpr bool AVX512DQ = 1; + static constexpr bool AVX512ER = 0; + static constexpr bool AVX512IFMA52 = 0; + static constexpr bool AVX512PF = 0; + static constexpr bool AVX512VL = 1; + static constexpr bool AVX512VPOPCNTDQ = 0; + static constexpr bool AVX512_4FMAPS = 0; + static constexpr bool AVX512_4VNNIW = 0; + static constexpr bool AVX512_BF16 = 0; + static constexpr bool AVX512_BITALG = 0; + static constexpr bool AVX512_VBMI = 0; + static constexpr bool AVX512_VBMI2 = 0; + static constexpr bool AVX512_VNNI = 1; + static constexpr bool AVX512_VP2INTERSECT = 0; + static constexpr bool AVX512_FP16 = 0; + static constexpr bool AMX_TILE = 1; + static constexpr bool AMX_BF16 = 1; + static constexpr bool AMX_INT8 = 1; + static constexpr bool AMX_FP16 = 0; + static constexpr bool AMX_COMPLEX = 0; +}; + +template +class isa_base { + public: + static bool constexpr avx = ISA_T >= BTLA_ISA::AVX; + static bool constexpr avx2 = ISA_T >= BTLA_ISA::AVX2; + static bool constexpr avx512f = ISA_T >= BTLA_ISA::AVX512F; + static bool constexpr avx512_vnni = ISA_T >= BTLA_ISA::AVX512_VNNI; + static bool constexpr avx512_fp16 = ISA_T >= BTLA_ISA::AVX512_FP16; + static bool constexpr amx_bf16 = ISA_T >= BTLA_ISA::AMX_BF16; + static bool constexpr amx_int8 = ISA_T >= BTLA_ISA::AMX_INT8; +}; + +class CpuDevice { + public: + inline int getThreads() { return numthreads; } + inline int getCores() { return numcores; } + inline uint32_t getL3CacheSize() { return L3Cache; } + inline uint32_t getL2CacheSize() { return L2Cache; } + inline uint32_t getL1CacheSize() { return L1Cache; } + inline uint32_t getL2CacheSize_E() { return E_L2Cache; } + inline uint32_t getL1CacheSize_E() { return E_L1Cache; } + inline bool AVX() { return mHasAVX; } + inline bool AVX2() { return mHasAVX2; } + inline bool AVX_VNNI() { return mHasAVX_VNNI; } + inline bool AVX512F() { return mHasAVX512F; } + inline bool 
AVX512_VNNI() { return mHasAVX512_VNNI; } + inline bool AMX_INT8() { return mHasAMX_INT8; } + inline bool AMX_BF16() { return mHasAMX_BF16; } + inline bool AVX512_BF16() { return mHasAVX512_BF16; } + inline bool AVX512_FP16() { return mHasAVX512_FP16; } + inline float* const getPE() { return PE; } + inline size_t getPcoreNum() { return P_core.size(); } + inline size_t getEcoreNum() { return E_core.size(); } + inline size_t getSMTcoreNum() { return SMT_core.size(); } + inline int* getPCores() { return P_core.data(); } + inline int* getECores() { return E_core.data(); } + inline int* getSMTCores() { return SMT_core.data(); } +#define ADD_FLAG(isa) mHas##isa = _cpu.has(_cpu.t##isa) + CpuDevice() { + static Xbyak::util::Cpu _cpu; + ADD_FLAG(AVX); + ADD_FLAG(AVX2); + ADD_FLAG(AVX512F); + ADD_FLAG(AVX512_VNNI); + ADD_FLAG(AVX_VNNI); + ADD_FLAG(AMX_BF16); + ADD_FLAG(AMX_INT8); + ADD_FLAG(AVX512_BF16); + ADD_FLAG(AVX512_FP16); + numcores = _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::CoreLevel); + if (mHasAMX_BF16 || mHasAMX_INT8) { + utils::request_perm_xtile_data(); + } + static bool p = false; + { + uint32_t tmp[4]; + _cpu.getCpuid(7, tmp); + if (tmp[3] & (1U << 15)) mHybrid = true; + if (p) printf("!!!Hybrid:%d\t%x\t%x\t%x\t%x!!!\n", mHybrid, tmp[0], tmp[1], tmp[2], tmp[3]); + } + int total_cores = numcores * _cpu.getNumCores(Xbyak::util::IntelCpuTopologyLevel::SmtLevel); + if (total_cores <= 16) mClient = true; + if (mHybrid) { + mClient = true; + std::vector core_type(total_cores), core_id(total_cores), L1(total_cores), L2(total_cores); + std::map core_id_count; + + { + // classify E-core / LPE-core and P-core / smt + std::vector thdset(total_cores); + for (size_t i = 0; i < total_cores; i++) { + thdset[i] = std::thread( + [&](int tidx) { + core_bond(tidx); + Xbyak::util::Cpu cpu; + L1[tidx] = cpu.getDataCacheSize(0); + L2[tidx] = cpu.getDataCacheSize(1); + if (isEcore(cpu)) + core_type[tidx] = 1; + else + core_type[tidx] = 2; + core_id[tidx] = getCoreId(cpu); + }, + int(i)); + } + for (size_t i = 0; i < total_cores; i++) { + thdset[i].join(); + core_id_count[core_id[i]] = core_id_count[core_id[i]] + 1; + } + if (p) { + for (int i = 0; i < total_cores; i++) printf("%d %d\n", core_type[i], core_id[i]); + for (auto& kv : core_id_count) printf("%d,%d\n", kv.first, kv.second); + } + for (int i = 0; i < total_cores; i++) { + if (core_type[i] == 2) { + if (core_id_count[core_id[i]] > 0) { + P_core.push_back(i); + core_id_count[core_id[i]] = 0; + } else { + SMT_core.push_back(i); + } + } else { + if (core_id_count[core_id[i]] == 4) E_core.push_back(i); + } + } + if (p) { + printf("Pcore:"); + for (auto& i : P_core) printf("%d,", i); + printf("\nEcore:"); + for (auto& i : E_core) printf("%d,", i); + printf("\nsmt:"); + for (auto& i : SMT_core) printf("%d,", i); + printf("\n"); + } + mHybrid = !(E_core.empty() || P_core.empty()); // in case of bond core by external + if (!E_core.empty()) { + E_L1Cache = L1[E_core[0]]; + E_L2Cache = L2[E_core[0]] / 4; + }; + if (!P_core.empty()) { + L1Cache = L1[P_core[0]]; + L2Cache = L2[P_core[0]]; + } + } + numcores = static_cast(P_core.size() + E_core.size()); + numthreads = static_cast(P_core.size() + E_core.size() + SMT_core.size()); + + { + // set PE + uint32_t tmp[4]; + _cpu.getCpuid(1, tmp); + if (p) printf("!!!\t%x\t%x\t%x\t%x!!!\n", tmp[0], tmp[1], tmp[2], tmp[3]); + const int famliy = (tmp[0] >> 8) & ((1u << 4) - 1); // cpu.extractBit(a[0], 8, 11); + const int extendedModel = (tmp[0] >> 16) & ((1u << 4) - 1); // cpu.extractBit(a[0], 16, 24); + { + 
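+    // The block below builds the per-ISA P-core/E-core performance ratio table (PE).
+    // Every entry defaults to 1.0f; for hybrid client parts identified by CPUID family 6
+    // and the extended model field (9 = Alder Lake, 10 = Meteor Lake, 11 = Raptor Lake),
+    // the AVX2, AVX_VNNI and scalar (NoSIMD) entries are overridden with hard-coded
+    // P-core/E-core throughput ratios used when splitting work across core types.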
for (int i = 0; i < int(BTLA_ISA::ISA_COUNT); i++) PE[i] = 1.0f; + // CPU identification refer to: https://en.wikichip.org/wiki/intel/cpuid + if (famliy == 6) switch (extendedModel) { + case 9: // ALD + PE[int(BTLA_ISA::AVX2)] = 3.0f; + PE[int(BTLA_ISA::AVX_VNNI)] = 5.0f; + PE[int(BTLA_ISA::NoSIMD)] = 3.5f; + break; + case 10: // MTL + PE[int(BTLA_ISA::AVX2)] = 2.2f; + PE[int(BTLA_ISA::AVX_VNNI)] = 3.0f; + PE[int(BTLA_ISA::NoSIMD)] = 3.0f; + break; + case 11: // RPL + PE[int(BTLA_ISA::AVX2)] = 1.8f; + PE[int(BTLA_ISA::AVX_VNNI)] = 2.6f; + PE[int(BTLA_ISA::NoSIMD)] = 3.0f; + break; + } + } + } + } else { + L1Cache = _cpu.getDataCacheSize(0); + L2Cache = _cpu.getDataCacheSize(1); + numthreads = numcores; + } + L3Cache = _cpu.getDataCacheSize(2); +#if FIXED_CACHE + L2Cache = L2Cache >= FIXED_CACHE_SIZE ? FIXED_CACHE_SIZE : L2Cache; + E_L2Cache = E_L2Cache >= FIXED_CACHE_SIZE ? FIXED_CACHE_SIZE : E_L2Cache; +#endif + } + + static CpuDevice* getInstance() { + static CpuDevice instance; + return &instance; + } + + void print() { + printf( + "AVX:%d AVX2:%d AVX512F:%d AVX_VNNI:%d AVX512_VNNI:%d AMX_INT8:%d AMX_BF16:%d AVX512_BF16:%d AVX512_FP16:%d\n", + mHasAVX, mHasAVX2, mHasAVX512F, mHasAVX_VNNI, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512_BF16, + mHasAVX512_FP16); + } +#undef ADD_FLAG + + static bool isEcore() { + Xbyak::util::Cpu cpu; + uint32_t tmp[4]; + cpu.getCpuid(0x1A, tmp); + int core_type = (tmp[0] >> 24) & ((1u << 8) - 1); // cpu.extractBit(a[0], 24, 31); + switch (core_type) { + case 32: + // printf("Atom\n"); + return true; // E-core or LPE-core + break; + case 64: + // printf("Core\n"); + return false; // P-core + break; + default: + // printf("No hyper\n"); + return false; + break; + } + return false; + } + + int getCoreId(Xbyak::util::Cpu& cpu) { + uint32_t tmp[4]; + cpu.getCpuidEx(0x1F, 1, tmp); // sub-leaf 1 is core domain + // printf("!!!%x\t%x\t%x\t%x!!!\n", tmp[0], tmp[1], tmp[2], tmp[3]); + if (tmp[0] != 0 && tmp[1] != 0) + return tmp[3] >> 3; // tmp[3] is APIC + else + return tmp[3]; + } + + bool isEcore(Xbyak::util::Cpu& cpu) { + uint32_t tmp[4]; + cpu.getCpuid(0x1A, tmp); + int core_type = (tmp[0] >> 24) & ((1u << 7) - 1); // cpu.extractBit(a[0], 24, 31); + switch (core_type) { + case 32: + // printf("Atom\n"); + return true; // E-core or LPE-core + break; + case 64: + // printf("Core\n"); + return false; // P-core + break; + default: + // printf("No hyper\n"); + return false; + break; + } + return false; + } + static void core_bond(int core) { +#ifdef _WIN32 + SetThreadAffinityMask(GetCurrentThread(), 1LL << core); +#else + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core, &cpuset); + int s = sched_setaffinity(0, sizeof(cpu_set_t), &cpuset); + if (s != 0) printf("ERROR\n"); +#endif + } + + static void core_bond(std::thread& thread, int core) { +#ifdef _WIN32 + HANDLE handle = thread.native_handle(); + SetThreadAffinityMask(handle, 1LL << core); +#else + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(core, &cpuset); + pthread_t pt = thread.native_handle(); + int s = pthread_setaffinity_np(pt, sizeof(cpuset), &cpuset); + if (s != 0) printf("ERROR\n"); +#endif + } + + bool isHybrid() { return mHybrid; } + bool isClient() { return mClient; } + + protected: + uint32_t L2Cache, L1Cache, L3Cache; + bool mHybrid = false, mClient = false; + bool mHasAVX2, mHasAVX_VNNI, mHasAVX, mHasAVX512_VNNI, mHasAMX_INT8, mHasAMX_BF16, mHasAVX512F, mHasAVX512_BF16, + mHasAVX512_FP16; + int numcores; + int numthreads; + std::vector P_core, E_core, SMT_core; + uint32_t 
E_L2Cache, E_L1Cache; + float PE[int(BTLA_ISA::ISA_COUNT)]; +}; + +#define GetCPUDevice() auto _cd = bestla::device::CpuDevice::getInstance(); + +class CpuRuntime { + public: + CpuRuntime() = default; + static CpuRuntime& getInstance(int thread) { + static std::map instances; + if (instances.count(thread) == 0) instances[thread] = CpuRuntime(thread); + return instances[thread]; + } + + inline float getPE(const BTLA_ISA isa) { + // printf("GET:%d\t%f\n",int(isa), *cur_PE); + return PE[int(isa)] * P_core_num / E_core_num; + } + + inline void adjustPE(const BTLA_ISA isa, const float PE_) { + // printf("Adjust:%d,%f\n",int(isa),PE_); + PE[int(isa)] = PE[int(isa)] * PE_ * 0.7 + PE[int(isa)] * 0.3; + } + + size_t mL2Cache, mL1Cache, mL2Cache_P = 0, mL1Cache_P = 0, mL2Cache_E = 0, mL1Cache_E = 0; + int P_core_num = 0, E_core_num = 0; + bool mHybrid = false; + + private: + CpuRuntime(int thread) { + GetCPUDevice(); + mL2Cache = _cd->getL2CacheSize(); + mL1Cache = _cd->getL1CacheSize(); + maxThreads = _cd->getThreads(); + mHybrid = false; + if (_cd->isClient() && thread > _cd->getPcoreNum()) { + if (thread > _cd->getPcoreNum() + _cd->getEcoreNum()) { + mL1Cache_P = mL1Cache / 2; + mL2Cache_P = mL2Cache / 2; + P_core_num = _cd->getPcoreNum(); + E_core_num = _cd->getEcoreNum(); + } else { + mL1Cache_P = mL1Cache; + mL2Cache_P = mL2Cache; + P_core_num = static_cast(_cd->getPcoreNum()); + E_core_num = thread - P_core_num; + } + if (_cd->isHybrid()) { + mL1Cache_E = _cd->getL1CacheSize_E(); + mL2Cache_E = _cd->getL2CacheSize_E(); + mHybrid = true; + memcpy(PE, _cd->getPE(), int(BTLA_ISA::ISA_COUNT) * sizeof(float)); + } + } + } + float PE[int(BTLA_ISA::ISA_COUNT)]; + int maxThreads; +}; +} // namespace device +} // namespace bestla diff --git a/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_epilogue.h b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_epilogue.h new file mode 100644 index 000000000..e49ad4386 --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_epilogue.h @@ -0,0 +1,340 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
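The epilogue header introduced below defines the functors that run after the GEMM micro-kernel: each exposes a static forward() taking the per-core accumulator tile (cacheptr, cachestep), the tile's position in the output (M_offset, N_offset), its shape (M, N), and a Param struct describing the destination. As a rough standalone sketch of that contract (WriteBackParam and tile_write_back are illustrative names, not taken from this patch):

  #include <cstring>

  // Illustrative stand-ins; the real code uses BTLA_CODE and ParamAccumulatorWriteBack.
  struct WriteBackParam {
    float* C;   // destination matrix
    int ldc;    // leading dimension (row stride) of C
  };

  // Copy an M x N accumulator tile, held row-major with stride `cachestep`,
  // into C at block offset (M_offset, N_offset).
  inline int tile_write_back(const float* cacheptr, int cachestep, int M_offset, int N_offset,
                             int M, int N, const WriteBackParam& p) {
    float* cptr = p.C + M_offset * p.ldc + N_offset;
    for (int i = 0; i < M; ++i)
      std::memcpy(cptr + i * p.ldc, cacheptr + i * cachestep, N * sizeof(float));
    return 0;  // success
  }

The real AccumulatorWriteBack additionally converts between the source and destination element types and dispatches to an ISA-specialized Memcpy2D kernel instead of a plain memcpy.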
+#pragma once +#include + +#include "bestla.h" +#include "bestla_jit.h" +#include "bestla_utils.h" +#include "kernel_wrapper.h" + +namespace bestla { +namespace epilogue { +namespace gemm { + +template +struct ParamAccumulatorWriteBack { + DT* C; + int ldc; + void* elt_const_v; +}; + +template +class AccumulatorWriteBack { + public: + using SType = _SRC_T; + using DType = _DST_T; + using Param = ParamAccumulatorWriteBack; + + static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + if constexpr (std::is_same_v<_SRC_T, DType>) { + if (cacheptr == cptr) { + return BTLA_CODE::Success; + } + } + return kernel::wrapper::Memcpy2D::template forward(cacheptr, cptr, M, N, cachestep, _param.ldc, + _param.elt_const_v); + } +}; + +template +class CustomAccumulatorWriteBackWithEltop { + public: + using Param = ParamAccumulatorWriteBack<_DST_T>; + static BTLA_CODE forward(const _SRC_T* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + if constexpr (std::is_same<_SRC_T, float>::value && std::is_same<_DST_T, float>::value) { + return kernel::wrapper::Memcpy2D::template forward1(cacheptr, cptr, M, N, cachestep, + _param.ldc, _param.elt_const_v); + } else { + assert(false); + } + } +}; +template +using AccumulatorWriteBackFp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackInt32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackBf16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16 = AccumulatorWriteBack; +template +using AccumulatorWriteBackBf16Fp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp16Fp32 = AccumulatorWriteBack; +template +using AccumulatorWriteBackFp32Bf16 = AccumulatorWriteBack; + +template +using AccumulatorWriteBackWithGeluFp32 = CustomAccumulatorWriteBackWithEltop; + +template +using AccumulatorWriteBackWithSwishFp32 = + CustomAccumulatorWriteBackWithEltop; + +template +struct ParamAlphaBetaProcess { + DT *C, *D; + int ldc, ldd; + float alpha, beta; +}; +template +class AlphaBetaProcessFp32 { + public: + using Param = ParamAlphaBetaProcess; + + static BTLA_CODE forward(const float* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto DOffset = M_offset * _param.ldd + N_offset; + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto dptr = _param.D + DOffset; + return kernel::wrapper::AlphaBetaF32F32::template forward(_param.alpha, cacheptr, cachestep, _param.beta, + dptr, _param.ldd, cptr, _param.ldc, M, N); + } +}; + +struct ParamCompFp32BlockEpilogue { + void* scales; + BTLA_DTYPE scaledtype; + int ldsb; + int8_t* zps = nullptr; + float* reduce = nullptr; + int ldra; +}; +template +class CompFp32BlockEpilogue { + public: + using Param = ParamCompFp32BlockEpilogue; + BTLA_CODE forward(const float* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + auto ret = BTLA_CODE::NotSupport; + if (_param.scaledtype == BTLA_DTYPE::F32) { + ret = 
kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + assert(ret == BTLA_CODE::Success); + if (_param.zps != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::forward_wei( + dstptr, cachestep, M, N, _param.zps + K_offset * _param.ldsb + N_offset, + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, _param.ldra, + _param.reduce + M_offset * _param.ldra + K_offset); + } + assert(ret == BTLA_CODE::Success); + return ret; + } else if (_param.scaledtype == BTLA_DTYPE::BF16) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + if (_param.zps != nullptr) { + assert(0); + } + assert(ret == BTLA_CODE::Success); + return ret; + } else if (_param.scaledtype == BTLA_DTYPE::F8_E8M0) { + ret = kernel::wrapper::CompFp32BlockScale::template forward( + reinterpret_cast(_param.scales) + K_offset * _param.ldsb + N_offset, srcptr, cachestep, dstptr, + cachestep, M, N); + if (_param.zps != nullptr) { + assert(0); + } + } else { + assert(0); + } + return BTLA_CODE::NotSupport; + } +}; + +struct ParamDequantInt32ToFp32 { + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; +}; +template +class DequantInt32ToFp32 { + public: + using Param = ParamDequantInt32ToFp32; + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, _param.ldsa, + _param.scalesB + N_offset); + } +}; + +struct ParamCompInt8BlockEpilogue { + void* scalesB; + BTLA_DTYPE scaleBdtype; + int ldsb; + float* scalesA; + int ldsa; + // optional if A asym + uint8_t* zpA = nullptr; + void* reduceB = nullptr; + BTLA_DTYPE reduceBdtype = BTLA_DTYPE::F32; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; +}; +template +class CompInt8BlockEpilogue { + public: + using Param = ParamCompInt8BlockEpilogue; + static BTLA_CODE forward(const int32_t* srcptr, float* dstptr, const int cachestep, const int M_offset, const int N_offset, + const int K_offset, const int M, const int N, const Param& _param, void* tmpcache, + size_t cachesize) { + BTLA_CODE ret = BTLA_CODE::NotSupport; + float* scab = nullptr; + size_t ScaleBTmpSize = N * sizeof(float); + size_t ReduceBTmpSize = N * sizeof(float); + assert(cachesize >= (ScaleBTmpSize + ReduceBTmpSize)); + if (_param.scaleBdtype == BTLA_DTYPE::BF16) { + auto scache = reinterpret_cast(tmpcache); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb, scache, 1, N, N, N, + false); + assert(ret == BTLA_CODE::Success); + scab = scache; + } else if (_param.scaleBdtype == BTLA_DTYPE::F32) { + scab = reinterpret_cast(_param.scalesB) + N_offset + K_offset * _param.ldsb; + } + float* redb = nullptr; + if (_param.reduceB) { + if (_param.reduceBdtype == BTLA_DTYPE::BF16) { + auto rcache = reinterpret_cast(reinterpret_cast(tmpcache) + ScaleBTmpSize); + ret = kernel::wrapper::Memcpy2DBf16CvtFp32::template forward( + reinterpret_cast(_param.reduceB) + N_offset + 
K_offset * _param.ldsb, rcache, 1, N, N, N, + false); + assert(ret == BTLA_CODE::Success); + redb = rcache; + } else if (_param.reduceBdtype == BTLA_DTYPE::F32) { + redb = reinterpret_cast(_param.reduceB) + N_offset + K_offset * _param.ldsb; + } + } + ret = kernel::wrapper::DequanS32Fp32::template forward( + srcptr, cachestep, reinterpret_cast(const_cast(srcptr)), cachestep, M, N, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, scab); + assert(ret == BTLA_CODE::Success); + ret = kernel::wrapper::AccumulateFp32::template forward(reinterpret_cast(srcptr), cachestep, + dstptr, cachestep, M, N); + assert(ret == BTLA_CODE::Success); + + if (_param.zpA == nullptr) { + if (_param.zpB == nullptr) { + return ret; + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + dstptr, cachestep, M, N, _param.zpB + N_offset + K_offset * _param.ldsb, scab, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa + K_offset); + } + } else { + if (_param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.scalesA + M_offset * _param.ldsa + K_offset, _param.ldsa, redb); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + dstptr, cachestep, M, N, _param.zpA + M_offset * _param.ldsa + K_offset, + _param.zpB + N_offset + K_offset * _param.ldsb, _param.scalesA + M_offset * _param.ldsa + K_offset, scab, + _param.ldsa, _param.K, _param.reduceA + M_offset * _param.ldsa + K_offset, redb); + } + } + return ret; + } +}; + +struct ParamZpDequantInt32ToFp32 { + // necessary + float* C; + int ldc; + int ldsa; + float* scalesA; + float* scalesB; + // optional if A asym + uint8_t* zpA = nullptr; + float* reduceB = nullptr; + // optional if B asym + int8_t* zpB = nullptr; + float* reduceA = nullptr; + int K = 1; +}; +template +class ZpDequantInt32ToFp32 { + public: + using Param = ParamZpDequantInt32ToFp32; + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + auto ret = kernel::wrapper::DequanS32Fp32::template forward(cacheptr, cachestep, cptr, _param.ldc, M, N, + _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.scalesB + N_offset); + if (ret != BTLA_CODE::Success) { + return ret; + } + if (_param.zpA == nullptr && _param.zpB == nullptr) { + return ret; + } else if (_param.zpA != nullptr && _param.zpB == nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_act( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.scalesA + M_offset * _param.ldsa, + _param.ldsa, _param.reduceB + N_offset); + } else if (_param.zpA == nullptr && _param.zpB != nullptr) { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_wei( + cptr, _param.ldc, M, N, _param.zpB + N_offset, _param.scalesB + N_offset, _param.ldsa, + _param.reduceA + M_offset * _param.ldsa); + } else { + ret = kernel::wrapper::RemoveZeroPointBias::template forward_both( + cptr, _param.ldc, M, N, _param.zpA + M_offset * _param.ldsa, _param.zpB + N_offset, + _param.scalesA + M_offset * _param.ldsa, _param.scalesB + N_offset, _param.ldsa, _param.K, + _param.reduceA + M_offset * _param.ldsa, _param.reduceB + N_offset); + } + return ret; + } +}; + +struct ParamAlphaBetaProcessS32U8 { + uint8_t* C; + int ldc; + 
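+  // alpha scales the int32 accumulator; scaleAcc, scaleC and zpC are the quantization
+  // parameters applied when the scaled result is written back as uint8.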
float alpha; + float scaleAcc, scaleC; + int zpC; +}; +template +class AlphaBetaProcessS32U8 { + public: + using Param = ParamAlphaBetaProcessS32U8; + static BTLA_CODE forward(const int32_t* cacheptr, const int cachestep, const int M_offset, const int N_offset, const int M, + const int N, const Param& _param, void* tmpcache, size_t cachesize) { + auto COffset = M_offset * _param.ldc + N_offset; + auto cptr = _param.C + COffset; + return kernel::wrapper::QuanOutS32U32::template forward(_param.alpha, cacheptr, cachestep, cptr, _param.ldc, + M, N, _param.scaleAcc, _param.scaleC, _param.zpC); + } +}; + +} // namespace gemm +} // namespace epilogue +} // namespace bestla diff --git a/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_gemm.h b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_gemm.h new file mode 100644 index 000000000..793db081c --- /dev/null +++ b/vllm-ext/inference_engine/cpp/bestla/bestla/bestla_gemm.h @@ -0,0 +1,4320 @@ +// Copyright (c) 2023 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once +#include + +#include "bestla_utils.h" +#include "bestla_jit.h" + +namespace bestla { +namespace gemm { +enum class CompType : uint16_t { + // base type, too many bits if reuse BTLA_DTYPE + tFP32 = 0, + tBF16 = 1, + tFP16 = 2, + tS8 = 3, + tU8 = 4, + tS32 = 5, + tS16 = 6, + MASK_A = 0xf, + SHIFT_A = 0, + MASK_B = 0xf0, + SHIFT_B = 4, + MASK_C = 0xf00, + SHIFT_C = 8, + COMP_FP32 = (tFP32 << SHIFT_A) | (tFP32 << SHIFT_B) | (tFP32 << SHIFT_C), + COMP_BF16_FP32 = (tBF16 << SHIFT_A) | (tBF16 << SHIFT_B) | (tFP32 << SHIFT_C), + COMP_FP16_FP16 = (tFP16 << SHIFT_A) | (tFP16 << SHIFT_B) | (tFP16 << SHIFT_C), + COMP_INT8_US_INT32 = (tU8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C), + COMP_INT8_UU_INT32 = (tU8 << SHIFT_A) | (tU8 << SHIFT_B) | (tS32 << SHIFT_C), + COMP_INT8_SS_INT32 = (tS8 << SHIFT_A) | (tS8 << SHIFT_B) | (tS32 << SHIFT_C), + COMP_INT8_SU_INT32 = (tS8 << SHIFT_A) | (tU8 << SHIFT_B) | (tS32 << SHIFT_C), + COMP_INT16_SS_INT32 = (tS16 << SHIFT_A) | (tS16 << SHIFT_B) | (tS32 << SHIFT_C), + COMP_INT8_US_FP32 = (tU8 << SHIFT_A) | (tS8 << SHIFT_B) | (tFP32 << SHIFT_C), + COMP_INT8_UU_FP32 = (tU8 << SHIFT_A) | (tU8 << SHIFT_B) | (tFP32 << SHIFT_C), + COMP_INT8_SS_FP32 = (tS8 << SHIFT_A) | (tS8 << SHIFT_B) | (tFP32 << SHIFT_C), + COMP_INT8_SU_FP32 = (tS8 << SHIFT_A) | (tU8 << SHIFT_B) | (tFP32 << SHIFT_C), +}; + +class CompTypeHelper { + public: + static inline uint64_t get_mask_val(CompType raw, CompType mask, CompType shift) { + return (static_cast(raw) & static_cast(mask)) >> static_cast(shift); + } + + static void parse_id(CompType id, uint64_t* vals) { + vals[0] = get_mask_val(id, CompType::MASK_A, CompType::SHIFT_A); + vals[1] = get_mask_val(id, CompType::MASK_B, CompType::SHIFT_B); + vals[2] = get_mask_val(id, CompType::MASK_C, CompType::SHIFT_C); + } + + static const char* to_str(CompType id) { + static char tmp[128]; + uint64_t vals[3]; + parse_id(id, vals); + sprintf(tmp, "A%d_B%d_C%d", static_cast(vals[0]), static_cast(vals[1]), 
static_cast(vals[2])); + return tmp; + } + + static inline uint64_t get_B(CompType id) { return get_mask_val(id, CompType::MASK_B, CompType::SHIFT_B); } + + static inline bool is_integer(CompType id) { + auto bt = get_B(id); + bool flag = false; + flag |= bt == static_cast(CompType::tS8); + flag |= bt == static_cast(CompType::tU8); + return flag; + } +}; + +class CoreAttr { + public: + // INT64=LSB|**8bits:NTile**||**8bits:PackRow**||**16bits:CompType**||**8bits:ISA**||**24bits:reversed**| + static uint64_t constexpr NTILE_MASK = 0xff, NTILE_SHIFT = 0, PACKROW_MASK = 0xff00, PACKROW_SHIFT = 8, + COMP_MASK = 0xffff0000, COMP_SHIFT = 16, ISA_MASK = 0xff00000000, ISA_SHIFT = 32; + + static inline uint64_t get_mask_val(uint64_t raw, uint64_t mask, uint64_t shift) { return (raw & mask) >> shift; } + + static constexpr uint64_t make_core_id(int NTile, int PackRow, CompType CompType, BTLA_ISA ISA) { + return (static_cast(NTile) << NTILE_SHIFT) | (static_cast(PackRow) << PACKROW_SHIFT) | + (static_cast(CompType) << COMP_SHIFT) | (static_cast(ISA) << ISA_SHIFT); + } + static void parse_id(uint64_t id, uint64_t* vals) { + vals[0] = get_mask_val(id, NTILE_MASK, NTILE_SHIFT); + vals[1] = get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT); + vals[2] = get_mask_val(id, COMP_MASK, COMP_SHIFT); + vals[3] = get_mask_val(id, ISA_MASK, ISA_SHIFT); + } + + static const char* to_str(uint64_t id) { + static char tmp[128]; + uint64_t vals[4]; + parse_id(id, vals); + sprintf(tmp, "N%d_PACK%d_COMP%d_ISA%d", static_cast(vals[0]), static_cast(vals[1]), + static_cast(vals[2]), static_cast(vals[3])); + return tmp; + } + + static inline int get_packrow(uint64_t id) { return static_cast(get_mask_val(id, PACKROW_MASK, PACKROW_SHIFT)); } + + static inline size_t get_bsize(uint64_t id) { + auto packrow = get_packrow(id); + return size_t(4 / packrow); + } + + static inline BTLA_ISA get_ISA(uint64_t id) { return static_cast(get_mask_val(id, ISA_MASK, ISA_SHIFT)); } + + static inline CompType get_comp(uint64_t id) { + return static_cast(get_mask_val(id, COMP_MASK, COMP_SHIFT)); + } +}; + +namespace code { + +template +class Avx2N8P1 : protected bestla::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX2; + static auto constexpr COMPUTE = CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
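+    // Remainder loop: the KUNROLL-unrolled loop above consumed the largest multiple of
+    // KUNROLL * KTILE; the compare below keeps this single-KTILE loop running until the
+    // full K extent has been accumulated.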
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fN16P1 : protected bestla::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512F; + static auto constexpr COMPUTE = CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512fp16N32P1 : protected bestla::xbyak::JitAvx512_fp16 { + public: + static int constexpr RegLen = 32, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512_FP16; + static auto constexpr COMPUTE = CompType::COMP_FP16_FP16; + typedef utils::fp16 AType; + typedef utils::fp16 BType; + typedef utils::fp16 CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastw(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastw(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ph(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512bf16N16P2 : protected bestla::xbyak::JitAvx512_bf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 2; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512_BF16; + static auto constexpr COMPUTE = CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + 
add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vdpbf16ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512_VNNI; + static auto constexpr COMPUTE = CompType::COMP_INT8_US_INT32; + typedef uint8_t AType; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration 
variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr KeepRegs = std::is_same_v ? 1 : 3; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - KeepRegs) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - KeepRegs); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX_VNNI; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_INT32 : CompType::COMP_INT8_SS_INT32; + using AType = AT; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + if (std::is_same_v) { + TmpRegCount = 2; + } + BRegCount = RegCount - ARegCount - CRegCount - TmpRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg + TmpRegCount <= RegCount); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, 
ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } else { + vpsignb(vreg_t(TmpReg), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg + imm), vreg_t(AReg + imm)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } else { + vmovups(vreg_t(TmpReg), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg), vreg_t(TmpReg), vreg_t(AReg + imm)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +using AvxvnniN8P4U8 = AvxvnniN8P4; + +template +using AvxvnniN8P4S8 = AvxvnniN8P4; + +template +class Avx2vnniN8P4 : protected bestla::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr KeepRegs = std::is_same_v ? 3 : 5; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - KeepRegs) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - KeepRegs); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX2; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_INT32 : CompType::COMP_INT8_SS_INT32; + using AType = AT; + typedef int8_t BType; + typedef int32_t CType; + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + const int16_t one = 1; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + private: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + protected: + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + if (std::is_same_v) { + TmpRegCount = 4; + } else { + TmpRegCount = 2; + } + BRegCount = RegCount - ARegCount - CRegCount - TmpRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg + TmpRegCount <= RegCount); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + vpbroadcastw(vreg_t(TmpReg + 0), ptr[parambase + OFFSET(one)]); + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + 
void generate_fma(int _mtile, int _kunroll) { + for (int kk = 0; kk < _kunroll; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(BReg + i), + vreg_t(TmpReg + 0)); + } else { + vpsignb(vreg_t(TmpReg + 3), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vpbroadcastd(vreg_t(AReg + imm), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg + imm), vreg_t(AReg + imm)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes], vreg_t(TmpReg + 0)); + } else { + vmovups(vreg_t(TmpReg + 3), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg + 3), vreg_t(TmpReg + 3), vreg_t(AReg + imm)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } + } else { + assert(0); + } + } + } + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +using Avx2vnniN8P4U8 = Avx2vnniN8P4; + +template +using Avx2vnniN8P4S8 = Avx2vnniN8P4; + +template +class Amxbf16N16P2 : protected bestla::xbyak::JitAmxbf16 { + public: + static int constexpr RegLen = 16, PackRow = 2; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 32; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AMX_BF16; + static auto constexpr COMPUTE = CompType::COMP_BF16_FP32; + typedef utils::bf16 AType; + typedef utils::bf16 BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + 
add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + for (int kk = 0; kk < kunrll; kk++) { + auto reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + tdpbf16ps(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, 
_mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; + +template +class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AMX_INT8; + static auto constexpr COMPUTE = + (std::is_same_v + ? std::is_same_v ? CompType::COMP_INT8_SS_INT32 : CompType::COMP_INT8_SU_INT32 + : std::is_same_v ? CompType::COMP_INT8_US_INT32 + : CompType::COMP_INT8_UU_INT32); + using AType = AT; + using BType = BT; + typedef int32_t CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 11, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + 
lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll) { + auto& reg_Bstride = reg_tmp1; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto reg_Atmp = reg_tmp2; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + int mtnum = _mtile / 16; + for (int mm = 0; mm < mtnum; mm++) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(CTile + mm * NRegs + i), ptr[reg_matCptr + reg_cstride + i * 64]); + } + if (mm != mtnum - 1) { + 
lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + lea(reg_matCptr, ptr[reg_matCptr + 8 * reg_cstride]); + } + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_tmp, dword[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(TmpReg + im * NRegs + j), ptr[reg_tmp + j * 64 + (i + im) * NTILE * 4]); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + outLocalLabel(); + } +}; +template +using Amxint8N16P4US = Amxint8N16P4; + +template +using Amxint8N16P4SS = Amxint8N16P4; + +class AmxConfigure : protected xbyak::JitAmxtile { + public: + typedef long long (*func_t)(tileconfig_t*); + + static void configure(int TILE_M, int TILE_N, int TILE_K, int elesize, int ANum, int BNum, int CNum) { + static AmxConfigure code; + tileconfig_t cfg; + std::memset(&cfg, 0, sizeof(cfg)); + configure_tiles(cfg, TILE_M, TILE_N, TILE_K, elesize, ANum, BNum, CNum); + code.mKernel(&cfg); + } + + protected: + AmxConfigure() { + generate_config(this); + mKernel = getCode(); + } + + func_t mKernel = nullptr; +}; + +namespace kblock { +// optimize for kblock gemm, each block size in k dimension has dequant operation +// all accumulators use fp32 dtype. +template +class Avx512fN16P1 : protected bestla::xbyak::JitAvx512f { + public: + static int constexpr RegLen = 16, PackRow = 1; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1) / NRegs : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 1; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512F; + static auto constexpr COMPUTE = CompType::COMP_FP32; + typedef float AType; + typedef float BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + int k; + int n; + int init; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_ret = rax; + Xbyak::Opmask msk_wr = k1; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - ARegCount - CRegCount; + if (BRegCount < NRegs) { + BRegCount = 0; + ARegCount = BRegCount + 1; + } + if (BRegCount > NRegs) { + BRegCount = NRegs; + } + CReg = 0; + BReg = CReg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg <= RegCount); + TmpRegCount = RegCount - TmpReg; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 10, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + mov(reg_tmp, reg_ksize); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kloop", T_NEAR); + L(".unkloop"); + generate_fma(_mtile, KUNROLL); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_iterk, KUNROLL * KTILE); + cmp(reg_iterk, reg_tmp); // k iteration variable + jb(".unkloop"); + cmp(reg_tmp, reg_ksize); + jge(".kend", T_NEAR); + L(".kloop"); + generate_fma(_mtile, 1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_iterk, 1 * KTILE); + 
cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + L(".kend"); + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile) { + for (int kk = 0; kk < _ktile; kk++) { + lea(reg_tmp1, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } else if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm += ARegCount) { + int mm_re = utils::remainsize(mm, _mtile, ARegCount); + for (int imm = 0; imm < mm_re; imm++) { + vbroadcastss(vreg_t(AReg + imm), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vfmadd231ps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg + imm), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + } + } + } else { + assert(0); + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CReg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CReg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class Avx512vnniN16P4 : protected bestla::xbyak::JitAvx512vnni { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 
(RegCount - 1 - NRegs) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 1); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX512_VNNI; + static auto constexpr COMPUTE = CompType::COMP_INT8_US_FP32; + typedef uint8_t AType; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + float kscale; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = NRegs; + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= 1); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vpxorq(Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j), Xbyak::Zmm(CReg + i * NRegs + j)); + } + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + 
padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(Xbyak::Zmm(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(TmpReg), Xbyak::Zmm(BReg + i)); + vmulps(Xbyak::Zmm(CReg + mm * NRegs + i), Xbyak::Zmm(AReg)); + vaddps(Xbyak::Zmm(CF32Reg + mm * NRegs + i), Xbyak::Zmm(CReg + mm * NRegs + i)); + } + } + } + + void generate_zp_correction(int _mtile) { + inLocalLabel(); + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + cmp(reg_tmp, 0); + je(".NOZP", T_NEAR); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + 
lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); + } + + vbroadcastss(vreg_t(TmpReg), ptr[parambase + OFFSET(kscale)]); + auto& reg_kscale = reg_tmp2; + + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); + vsubps(Xbyak::Zmm(CF32Reg + i * NRegs + j), Xbyak::Zmm(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + L(".NOZP"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; + +template +class AvxvnniN8P4 : protected bestla::xbyak::JitAvxvnni { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - 3) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - 3); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX_VNNI; + static auto constexpr COMPUTE = + std::is_same_v ? 
CompType::COMP_INT8_US_FP32 : CompType::COMP_INT8_SS_FP32; + using AType = AT; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + float kscale; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - CRegCount - CRegCount - ARegCount - 2; + if (BRegCount >= NRegs) { + BRegCount = NRegs; + } else { + BRegCount = 0; + } + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= 2); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, 
KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } else { + vmovups(vreg_t(TmpReg), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg), vreg_t(TmpReg), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } + } + } + } else { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg), vreg_t(BReg + i)); + } else { + vpsignb(vreg_t(TmpReg), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg)); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, 
ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmulps(vreg_t(AReg), vreg_t(TmpReg), vreg_t(BReg + i)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } else { + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmovups(vreg_t(AReg), ptr[reg_tmp2 + i * VecBytes]); + vmulps(vreg_t(AReg), vreg_t(TmpReg)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } + } + + void generate_zp_correction(int _mtile) { + inLocalLabel(); + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + cmp(reg_tmp, 0); + je(".NOZP", T_NEAR); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + + vbroadcastss(vreg_t(TmpReg), ptr[parambase + OFFSET(kscale)]); + auto& reg_kscale = reg_tmp4; + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_redB + i * VecBytes]); + } + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), vreg_t(BReg + j)); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } else { + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), ptr[reg_redB + j * VecBytes]); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + L(".NOZP"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + 
outLocalLabel(); + } +}; + +template +using AvxvnniN8P4U8 = kblock::AvxvnniN8P4; +template +using AvxvnniN8P4S8 = kblock::AvxvnniN8P4; + +template +class Avx2vnniN8P4 : protected bestla::xbyak::JitAvx2 { + public: + static int constexpr RegLen = 8, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr TmpReserve = std::is_same_v ? 2 : 4; + static int constexpr MRegs = _MTILE == 0 ? (RegCount - (TmpReserve + 1)) / (NRegs * 2) : _MTILE; + static_assert(NRegs * MRegs <= RegCount - (TmpReserve + 1)); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs, KTILE = 4; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AVX2; + static auto constexpr COMPUTE = + std::is_same_v ? CompType::COMP_INT8_US_FP32 : CompType::COMP_INT8_SS_FP32; + using AType = AT; + typedef int8_t BType; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + float kscale; + const uint16_t one = 1; + }; + typedef long long (*func_t)(params*); + + int CRegCount = 0, BRegCount = 0, ARegCount = 0, TmpRegCount = 0; + int CReg = 0, CF32Reg = 0, BReg = 0, AReg = 0, TmpReg = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CRegCount = MRegs * NRegs; + ARegCount = 1; + BRegCount = RegCount - CRegCount - CRegCount - ARegCount - TmpReserve; + if (BRegCount >= NRegs) { + BRegCount = NRegs; + } else { + BRegCount = 0; + } + CReg = 0; + CF32Reg = CReg + CRegCount; + BReg = CF32Reg + CRegCount; + AReg = BReg + BRegCount; + TmpReg = AReg + ARegCount; + assert(TmpReg < RegCount); + TmpRegCount = RegCount - TmpReg; + assert(TmpRegCount >= TmpReserve); + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + 
lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j), vreg_t(CReg + i * NRegs + j)); + } + } + vpbroadcastw(vreg_t(TmpReg + 0), ptr[parambase + OFFSET(one)]); + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int _ktile, Xbyak::Reg64& tmp) { + for (int kk = 0; kk < _ktile; kk++) { + lea(tmp, ptr[reg_matAptr + kk * AKStepSize]); + if (BRegCount == 0) { + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), + ptr[reg_matBptr + kk * BKStepSize + i * VecBytes], vreg_t(TmpReg + 0)); + } else { + vmovups(vreg_t(TmpReg + 3), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + vpsignb(vreg_t(TmpReg + 3), vreg_t(TmpReg + 3), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } else { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_matBptr + kk * BKStepSize + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vpbroadcastd(vreg_t(AReg), ptr[reg_tmp1]); + if constexpr (std::is_same_v) { + vpsignb(vreg_t(TmpReg + 2), vreg_t(AReg), vreg_t(AReg)); + } + add(reg_tmp1, reg_astride); + for (int i = 0; i < NRegs; i++) { + if constexpr (std::is_same_v) { + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(AReg), vreg_t(BReg + i), + vreg_t(TmpReg + 0)); + } else { + vpsignb(vreg_t(TmpReg + 3), vreg_t(BReg + i), vreg_t(AReg)); + vpdpbusds_(vreg_t(CReg + mm * NRegs + i), vreg_t(TmpReg + 1), vreg_t(TmpReg + 2), vreg_t(TmpReg + 3), + vreg_t(TmpReg + 0)); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".read", T_NEAR); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vxor(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j), vreg_t(CF32Reg + i * NRegs + j)); + } + } + jmp(".end", T_NEAR); + L(".read"); + 
mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(vreg_t(CF32Reg + i * NRegs + j), ptr[reg_matCptr + j * VecBytes]); + } + add(reg_matCptr, reg_cstride); + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmulps(vreg_t(AReg), vreg_t(TmpReg), vreg_t(BReg + i)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } else { + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(vreg_t(TmpReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(vreg_t(CReg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + vmovups(vreg_t(AReg), ptr[reg_tmp2 + i * VecBytes]); + vmulps(vreg_t(AReg), vreg_t(TmpReg)); + vmulps(vreg_t(CReg + mm * NRegs + i), vreg_t(AReg)); + vaddps(vreg_t(CF32Reg + mm * NRegs + i), vreg_t(CReg + mm * NRegs + i)); + } + } + } + } + + void generate_zp_correction(int _mtile) { + inLocalLabel(); + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + cmp(reg_tmp, 0); + je(".NOZP", T_NEAR); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + + vbroadcastss(vreg_t(TmpReg), ptr[parambase + OFFSET(kscale)]); + auto& reg_kscale = reg_tmp4; + if (BRegCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + vmovups(vreg_t(BReg + i), ptr[reg_redB + i * VecBytes]); + } + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), vreg_t(BReg + j)); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } else { + for (int i = 
0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(vreg_t(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(vreg_t(AReg), vreg_t(AReg)); + vbroadcastss(vreg_t(TmpReg + 1), ptr[reg_scaleA]); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg + 1)); + vmulps(vreg_t(AReg), vreg_t(AReg), vreg_t(TmpReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(vreg_t(CReg + j), vreg_t(AReg), ptr[reg_redB + j * VecBytes]); + vsubps(vreg_t(CF32Reg + i * NRegs + j), vreg_t(CReg + j)); + } + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + } + + L(".NOZP"); + outLocalLabel(); + } + + void write_back(int _mtile) { + inLocalLabel(); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(CType)]); + for (int i = 0; i < _mtile; i++) { + for (int j = 0; j < NRegs; j++) { + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(CF32Reg + i * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + outLocalLabel(); + } +}; +template <int N, int M> +using Avx2vnniN8P4U8 = kblock::Avx2vnniN8P4<uint8_t, N, M>; +template <int N, int M> +using Avx2vnniN8P4S8 = kblock::Avx2vnniN8P4<int8_t, N, M>; + +template <typename AT, typename BT, int _NTILE, int _MTILE = 0> +class Amxint8N16P4 : protected bestla::xbyak::JitAmxint8 { + public: + static int constexpr RegLen = 16, PackRow = 4; + static_assert(_NTILE % RegLen == 0); + static_assert(_MTILE % RegLen == 0); + static int constexpr NRegs = _NTILE / RegLen; + static int constexpr MRegs = _MTILE == 0 ? 1 : _MTILE / RegLen; + static_assert(NRegs * MRegs + 2 <= TileCount); + static int constexpr NTILE = RegLen * NRegs, MTILE = MRegs * RegLen, KTILE = 64; + static int constexpr KUNROLL = 2; + static auto constexpr ISA = BTLA_ISA::AMX_INT8; + static auto constexpr COMPUTE = (std::is_same_v<AT, int8_t> ? std::is_same_v<BT, int8_t> ? CompType::COMP_INT8_SS_FP32 + : CompType::COMP_INT8_SU_FP32 + : std::is_same_v<BT, int8_t> ?
CompType::COMP_INT8_US_FP32 + : CompType::COMP_INT8_UU_FP32); + using AType = AT; + using BType = BT; + typedef float CType; + + struct params { + AType* matA; + int astride; + BType* matB; + int bstride; + CType* matC; + int cstride; + uint8_t* zpA; + float* scaleA; + int ldsa; + float* scaleB; + float* reduceB; + int ldsb; + int k; + int n; + int kblock; + int init; + float kscale; + void* workspace; + }; + typedef long long (*func_t)(params*); + + int TmpRegCount = RegCount; + int TmpReg = 0; + int CTileCount = 0, ATileCount = 0, BTileCount = 0; + int CTile = 0, ATile = 0, BTile = 0; + static int constexpr BKStepSize = KTILE * NTILE * sizeof(BType); + static int constexpr AKStepSize = KTILE * sizeof(AType); + + void generate_code(int _mtile) { + assign_regs(); + reset(); + generate_mtile(_mtile); + ready(); + mKernel = getCode(); + } + func_t mKernel = nullptr; + + protected: + Xbyak::Reg64 parambase; + Xbyak::Reg64 reg_matAptr; + Xbyak::Reg64 reg_matBptr; + Xbyak::Reg64 reg_matCptr; + Xbyak::Reg64 reg_ksize; + Xbyak::Reg64 reg_nsize; + Xbyak::Reg64 reg_cstride; + Xbyak::Reg64 reg_astride; + Xbyak::Reg64 reg_iterk; + Xbyak::Reg64 reg_iterkb; + Xbyak::Reg64 reg_itern; + Xbyak::Reg64 reg_tmp; + Xbyak::Reg64 reg_tmp1; + Xbyak::Reg64 reg_tmp2; + Xbyak::Reg64 reg_tmp3; + Xbyak::Reg64 reg_tmp4; + Xbyak::Reg64 reg_ret = rax; + + void assign_regs() { + CTileCount = NRegs * MRegs; + auto tile_re = TileCount - CTileCount; + if (tile_re - 1 >= NRegs) { + BTileCount = NRegs; + ATileCount = tile_re - BTileCount; + } else if (tile_re - 1 >= MRegs) { + ATileCount = MRegs; + BTileCount = tile_re - ATileCount; + } else { + ATileCount = 1; + BTileCount = tile_re - ATileCount; + } + CTile = 0; + ATile = CTile + CTileCount; + BTile = ATile + ATileCount; + } + + void generate_mtile(int _mtile) { + inLocalLabel(); // use local label for multiple instance + Xbyak::util::StackFrame st(this, 1, 13, 16 * 10); + parambase = st.p[0]; + reg_matAptr = st.t[0]; + reg_matBptr = st.t[1]; + reg_matCptr = st.t[0]; + reg_ksize = st.t[2]; + reg_astride = st.t[3]; + reg_cstride = st.t[3]; + reg_iterk = st.t[4]; + reg_iterkb = st.t[12]; + reg_tmp = st.t[5]; + reg_tmp1 = st.t[6]; + reg_tmp2 = st.t[7]; + reg_tmp3 = st.t[10]; + reg_tmp4 = st.t[11]; + reg_nsize = st.t[8]; + reg_itern = st.t[9]; + reg_ret = rax; + + vreg_push(rsp); + + load32(reg_ksize, ptr[parambase + OFFSET(k)]); + load32(reg_nsize, ptr[parambase + OFFSET(n)]); + xor_(reg_itern, reg_itern); + L(".nloop"); + init_regs(_mtile); + mov(reg_matAptr, ptr[parambase + OFFSET(matA)]); + load32(reg_astride, ptr[parambase + OFFSET(astride)]); + mov(reg_matBptr, ptr[parambase + OFFSET(matB)]); + load32(reg_tmp, ptr[parambase + OFFSET(bstride)]); + imul(reg_tmp, reg_itern); + lea(reg_matBptr, ptr[reg_matBptr + reg_tmp]); + xor_(reg_iterk, reg_iterk); + generate_kloop(_mtile); + write_back(_mtile); + add(reg_itern, NTILE); + cmp(reg_itern, reg_nsize); + jb(".nloop"); + mov(reg_ret, 0); + vreg_pop(rsp); + + outLocalLabel(); // end of local label + } + + void generate_kloop(int _mtile) { + inLocalLabel(); + xor_(reg_iterkb, reg_iterkb); + L(".kloop"); + for (int i = 0; i < CTileCount; i++) { + tilezero(Xbyak::Tmm(CTile + i)); + } + xor_(reg_tmp2, reg_tmp2); + load32(reg_tmp3, ptr[parambase + OFFSET(kblock)]); + mov(reg_tmp, reg_tmp3); + padto_le(reg_tmp, KUNROLL * KTILE); + cmp(reg_tmp, 0); + jz(".kbloop", T_NEAR); + L(".unkbloop"); + generate_fma(_mtile, KUNROLL, reg_tmp1, reg_tmp4); + add(reg_matAptr, KUNROLL * AKStepSize); + add(reg_matBptr, KUNROLL * BKStepSize); + 
add(reg_tmp2, KUNROLL * KTILE); + cmp(reg_tmp2, reg_tmp); + jb(".unkbloop"); + cmp(reg_tmp, reg_tmp3); + jge(".kend", T_NEAR); + L(".kbloop"); + generate_fma(_mtile, 1, reg_tmp1, reg_tmp4); + add(reg_matAptr, 1 * AKStepSize); + add(reg_matBptr, 1 * BKStepSize); + add(reg_tmp2, 1 * KTILE); + cmp(reg_tmp2, reg_tmp3); + jb(".kbloop"); + L(".kend"); + add(reg_iterk, reg_tmp2); + generate_f32_accumulate(_mtile); + generate_zp_correction(_mtile); + inc(reg_iterkb); + cmp(reg_iterk, reg_ksize); // k iteration variable + jb(".kloop"); + + outLocalLabel(); + } + + void generate_fma(int _mtile, int kunrll, Xbyak::Reg64& tmpreg, Xbyak::Reg64& tmpreg2) { + auto& reg_Bstride = tmpreg2; + mov(reg_Bstride, NTILE * 4); + int mtiles = _mtile / RegLen; + + for (int kk = 0; kk < kunrll; kk++) { + auto reg_Atmp = tmpreg; + if (mtiles == 1) { + reg_Atmp = reg_matAptr; + } else { + mov(reg_Atmp, reg_matAptr); + } + if (BTileCount == NRegs) { + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile + i), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + } + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile + i)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } else { + if (ATileCount == mtiles) { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile + mm), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + for (int mm = 0; mm < mtiles; mm++) { + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile + mm), Xbyak::Tmm(BTile)); + } + } + } else { + for (int mm = 0; mm < mtiles; mm++) { + tileloadd(Xbyak::Tmm(ATile), ptr[reg_Atmp + reg_astride + kk * AKStepSize]); + for (int i = 0; i < NRegs; i++) { + tileloaddt1(Xbyak::Tmm(BTile), ptr[reg_matBptr + reg_Bstride + kk * BKStepSize + i * 64]); + _tdpb(Xbyak::Tmm(CTile + mm * NRegs + i), Xbyak::Tmm(ATile), Xbyak::Tmm(BTile)); + } + if (mm != mtiles - 1) { + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + lea(reg_Atmp, ptr[reg_Atmp + 8 * reg_astride]); + } + } + } + } + } + } + + void init_regs(int _mtile) { + inLocalLabel(); + load32(reg_tmp, ptr[parambase + OFFSET(init)]); + cmp(reg_tmp, 0); + je(".end", T_NEAR); + mov(reg_matCptr, ptr[parambase + OFFSET(matC)]); + load32(reg_cstride, ptr[parambase + OFFSET(cstride)]); + lea(reg_matCptr, ptr[reg_matCptr + reg_itern * sizeof(float)]); + int zunroll = TmpRegCount / NRegs; + for (int i = 0; i < _mtile; i += zunroll) { + int m_re = utils::remainsize(i, _mtile, zunroll); + for (int im = 0; im < m_re; im++) { + for (int j = 0; j < NRegs; j++) { + vxorps(vreg_t(TmpReg + im * NRegs + j), vreg_t(TmpReg + im * NRegs + j)); + vmovups(ptr[reg_matCptr + j * VecBytes], vreg_t(TmpReg + im * NRegs + j)); + } + add(reg_matCptr, reg_cstride); + } + } + L(".end"); + outLocalLabel(); + } + + void generate_f32_accumulate(int _mtile) { + mov(reg_tmp3, ptr[parambase + OFFSET(workspace)]); + mov(reg_tmp1, NTILE * 4); + for (int mm = 0; mm < MRegs; mm++) { + for (int i = 0; i < NRegs; i++) { + tilestored(ptr[reg_tmp3 + reg_tmp1 + i * 64 + mm * 16 * NTILE * 4], 
Xbyak::Tmm(CTile + mm * NRegs + i)); + } + } + load32(reg_tmp, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(scaleB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + + mov(reg_tmp, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(float)]); + load32(reg_tmp1, ptr[parambase + OFFSET(ldsa)]); + int BReg = TmpReg; + int AReg = BReg + NRegs; + int SAReg = AReg + 1; + int CReg = SAReg + 1; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_tmp2 + i * VecBytes]); + } + mov(reg_tmp2, ptr[parambase + OFFSET(matC)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + load32(reg_tmp4, dword[parambase + OFFSET(cstride)]); + for (int mm = 0; mm < _mtile; mm++) { + vbroadcastss(Xbyak::Zmm(SAReg), ptr[reg_tmp]); + lea(reg_tmp, ptr[reg_tmp + reg_tmp1 * sizeof(float)]); + for (int i = 0; i < NRegs; i++) { + vcvtdq2ps(Xbyak::Zmm(CReg + i), ptr[reg_tmp3 + i * 64 + mm * 4 * NTILE]); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(SAReg), Xbyak::Zmm(BReg + i)); + vmulps(Xbyak::Zmm(CReg + i), Xbyak::Zmm(AReg)); + vaddps(Xbyak::Zmm(CReg + i), ptr[reg_tmp2 + i * 64]); + vmovups(ptr[reg_tmp2 + i * 64], Xbyak::Zmm(CReg + i)); + } + add(reg_tmp2, reg_tmp4); + } + } + + void generate_zp_correction(int _mtile) { + inLocalLabel(); + mov(reg_tmp, ptr[parambase + OFFSET(zpA)]); + cmp(reg_tmp, 0); + je(".NOZP", T_NEAR); + lea(reg_tmp, ptr[reg_tmp + reg_iterkb * sizeof(AType)]); + auto& reg_zpA = reg_tmp; + + load32(reg_tmp1, ptr[parambase + OFFSET(ldsb)]); + imul(reg_tmp1, reg_iterkb); + mov(reg_tmp2, ptr[parambase + OFFSET(reduceB)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_tmp1 * sizeof(float)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + auto& reg_redB = reg_tmp2; + + mov(reg_tmp1, ptr[parambase + OFFSET(scaleA)]); + lea(reg_tmp1, ptr[reg_tmp1 + reg_iterkb * sizeof(float)]); + auto& reg_scaleA = reg_tmp1; + + load32(reg_tmp3, ptr[parambase + OFFSET(ldsa)]); + auto& reg_ldsa = reg_tmp3; + int BReg = TmpReg; + int AReg = BReg + NRegs; + int SReg = AReg + 1; + int CReg = SReg + 1; + int CF32Reg = CReg + NRegs; + for (int i = 0; i < NRegs; i++) { + vmovups(Xbyak::Zmm(BReg + i), ptr[reg_redB + i * VecBytes]); + } + + vbroadcastss(vreg_t(SReg), ptr[parambase + OFFSET(kscale)]); + mov(reg_tmp2, ptr[parambase + OFFSET(matC)]); + lea(reg_tmp2, ptr[reg_tmp2 + reg_itern * sizeof(float)]); + load32(reg_tmp4, dword[parambase + OFFSET(cstride)]); + + for (int i = 0; i < _mtile; i++) { + vpbroadcastb(Xbyak::Xmm(AReg), ptr[reg_zpA]); + vpmovzxbd(Xbyak::Zmm(AReg), Xbyak::Xmm(AReg)); + vcvtdq2ps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg)); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), zword_b[reg_scaleA]); + vmulps(Xbyak::Zmm(AReg), Xbyak::Zmm(AReg), vreg_t(SReg)); + for (int j = 0; j < NRegs; j++) { + vmulps(Xbyak::Zmm(CReg + j), Xbyak::Zmm(AReg), Xbyak::Zmm(BReg + j)); + vmovups(Xbyak::Zmm(CF32Reg + j), ptr[reg_tmp2 + j * 64]); + vsubps(Xbyak::Zmm(CF32Reg + j), Xbyak::Zmm(CReg + j)); + vmovups(ptr[reg_tmp2 + j * 64], Xbyak::Zmm(CF32Reg + j)); + } + add(reg_tmp2, reg_tmp4); + lea(reg_zpA, ptr[reg_zpA + reg_ldsa * sizeof(AType)]); + lea(reg_scaleA, ptr[reg_scaleA + reg_ldsa * sizeof(float)]); + } + L(".NOZP"); + outLocalLabel(); + } + + void write_back(int _mtile) { (void)(_mtile); } +}; +template +using Amxint8N16P4US = kblock::Amxint8N16P4; + +template +using Amxint8N16P4SS = kblock::Amxint8N16P4; +} // namespace kblock +} // namespace code 
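Reference note for the kblock kernels above: generate_f32_accumulate() and generate_zp_correction() implement, in vector registers, the per-k-block epilogue sketched below in scalar C++. The sketch is illustrative only; the function name kblock_epilogue, the dense m_tile x n_tile output layout, and the assumption that all pointers are already offset to the current k-block and N tile are editorial simplifications, not part of the kernels.

#include <cstdint>

// Scalar reference for the per-k-block epilogue (illustrative, not part of the kernels above).
// acc_i32 holds the raw int32 dot-product sums of one k-block, i.e. the contents of the
// CReg/CTile accumulators before vcvtdq2ps; all pointers are assumed pre-offset to the
// current k-block and N tile, and the output tile is assumed dense (cstride folded away).
inline void kblock_epilogue(float* c_f32, const int32_t* acc_i32, int m_tile, int n_tile,
                            const float* scale_a, int ldsa,  // per-row A scales, row stride ldsa
                            const float* scale_b,            // per-column B scales
                            const uint8_t* zp_a,             // per-row A zero points, may be null
                            const float* reduce_b,           // per-column sums of B for this k-block
                            float kscale) {
  for (int i = 0; i < m_tile; ++i) {
    const float sa = scale_a[i * ldsa];
    for (int j = 0; j < n_tile; ++j) {
      // generate_f32_accumulate: dequantize the int32 partial sum and accumulate into fp32 C.
      c_f32[i * n_tile + j] += static_cast<float>(acc_i32[i * n_tile + j]) * sa * scale_b[j];
      // generate_zp_correction: sum_k (a_k - zp) * b_k equals sum_k a_k * b_k - zp * sum_k b_k,
      // so the zero-point contribution is removed with the precomputed column sums reduceB,
      // scaled by scaleA and by the kscale factor from params.
      if (zp_a != nullptr) {
        c_f32[i * n_tile + j] -= static_cast<float>(zp_a[i * ldsa]) * sa * kscale * reduce_b[j];
      }
    }
  }
}

The AMX variant performs the same arithmetic but first spills the int32 tile accumulators to the workspace buffer with tilestored and then reads and updates matC directly, instead of keeping the fp32 tile in registers.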
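The calling convention shared by these classes is also visible above: generate_code(_mtile) JIT-compiles a kernel body for one M tile, and the resulting mKernel pointer is invoked with a filled-in params struct. A minimal driver sketch follows; the template arguments, namespace qualification, stride units, and buffer setup are assumptions for illustration, since the dispatch code that instantiates these kernels is not part of this hunk.

// Hypothetical driver for one of the kblock JIT kernels above.
using Kernel = kblock::Avx2vnniN8P4<uint8_t, 24>;  // assumed: AT = uint8_t, NTILE = 24, default MTILE

long long run_kblock_gemm(uint8_t* a, int astride_bytes, int8_t* packed_b, int bstride_bytes,
                          float* c, int cstride_bytes, uint8_t* zp_a, float* scale_a, int ldsa,
                          float* scale_b, float* reduce_b, int ldsb, int k, int n, int kblock_size) {
  Kernel ker;
  ker.generate_code(Kernel::MTILE);    // emit code handling MTILE rows of A per call
  Kernel::params p{};                  // 'one' keeps its in-class initializer of 1
  p.matA = a;         p.astride = astride_bytes;
  p.matB = packed_b;  p.bstride = bstride_bytes;
  p.matC = c;         p.cstride = cstride_bytes;
  p.zpA = zp_a;       p.scaleA = scale_a;   p.ldsa = ldsa;
  p.scaleB = scale_b; p.reduceB = reduce_b; p.ldsb = ldsb;
  p.k = k; p.n = n; p.kblock = kblock_size;
  p.init = 1;         // non-zero: zero the fp32 accumulators instead of reloading C
  p.kscale = 1.0f;    // extra scale applied to the zero-point correction term
  return ker.mKernel(&p);  // the generated code iterates the full n range internally
}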
+template