diff --git a/python/config.py.in b/python/config.py.in index 963ac498c..187945f71 100644 --- a/python/config.py.in +++ b/python/config.py.in @@ -5,7 +5,7 @@ llvm_obj_root = "@LLVM_BINARY_DIR@" llvm_lib_dir = "@LLVM_LIBRARY_DIR@" shlib_ext = "@LTDL_SHLIB_EXT@" gc_lib_dir = "@LLVM_LIBRARY_OUTPUT_INTDIR@" - +GC_ENABLE_DNNL_API ="@GC_ENABLE_DNNL_API@" in ["ON", "1"] if sys.platform.startswith("win32"): mlir_runner_utils_dir = os.path.normpath(os.path.join(llvm_obj_root, "bin")) diff --git a/python/gc_mlir/_mlir_libs/_site_initialize_0.py b/python/gc_mlir/_mlir_libs/_site_initialize_0.py index 3fba4fbdd..addd5c52b 100644 --- a/python/gc_mlir/_mlir_libs/_site_initialize_0.py +++ b/python/gc_mlir/_mlir_libs/_site_initialize_0.py @@ -5,6 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # # ===-----------------------------------------------------------------------===# +from gc_mlir.config import GC_ENABLE_DNNL_API def context_init_hook(context): @@ -13,11 +14,9 @@ def context_init_hook(context): register_cpuruntime_dialect(context) - try: + if GC_ENABLE_DNNL_API: from ._gc_mlir.onednn_graph import ( register_dialect as register_onednn_graph_dialect, ) register_onednn_graph_dialect(context) - except ModuleNotFoundError: - print("onednn_graph dialect not found") diff --git a/scripts/correctness.sh b/scripts/correctness.sh index c0ae008ce..509b24066 100755 --- a/scripts/correctness.sh +++ b/scripts/correctness.sh @@ -102,5 +102,8 @@ python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/reduce.mlir || F # mlir # python3 -m benchgc --verbose 0 --driver mlir --case ${CASE_DIR}/llama2.mlir || FAIL=1 +#mlp +python3 -m benchgc --verbose 1 --driver pattern --case mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=1x1 --act_type=noop --dtype=f32 + set +e exit $FAIL \ No newline at end of file diff --git a/test/benchgc/CMakeLists.txt b/test/benchgc/CMakeLists.txt index e50f35cf2..d31895de3 100644 --- a/test/benchgc/CMakeLists.txt +++ b/test/benchgc/CMakeLists.txt @@ -39,3 +39,4 @@ add_subdirectory("src/benchgc/mlir") add_subdirectory("src/benchgc/linalg") add_subdirectory("src/benchgc/tensor") add_subdirectory("src/benchgc/arith") +add_subdirectory("src/benchgc/pattern") diff --git a/test/benchgc/README.md b/test/benchgc/README.md index 77499c5fd..9f18cc398 100644 --- a/test/benchgc/README.md +++ b/test/benchgc/README.md @@ -2,32 +2,47 @@ ## Description -Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files based on the OneDNN graph dialect as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. +Benchgc is a tool used to verify the correctness and performance of graph compiler. Benchgc accepts MLIR files as test cases and prepares test data for them. For correctness verification, Benchgc will use PyTorch as a reference for comparison. ## Prerequisite * python >= 3.10 * torch >= 2.2 -* pybind11 +* Enable mlir python binding, Refer to [`python/README.md`](../../python/README.md) for detail -## Build and install +## Build +There are two ways for using benchgc + +* Build `.whl` and install benchgc ``` # Please execute at the top level of the project -mkdir -p build -cd build - +mkdir build && cd build cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON make -j benchgc - python -m pip install test/benchgc/dist/benchgc-*.whl ``` +* Run benchgc from source code + +``` +# Please execute at the top level of the project + +mkdir build && cd build +cmake .. -DMLIR_DIR=$MLIR_PATH -DGC_TEST_ENABLE=ON -DGC_ENABLE_BINDINGS_PYTHON=ON -DGC_BENCH_ENABLE=ON +make -j GcPythonModules +export PYTHONPATH=$(pwd)/python_packages/gc_mlir_core/:$(pwd)/../test/benchgc/src/ +``` + ## Synopsis ``` -python -m benchgc [OPTIONS] --driver [DRIVER] --case [CASE] +python -m benchgc [OPTIONS] --mode [MODE] --driver [DRIVER] --case [CASE] ``` -## Flags +## Common Options +### --mode [str] +* C : correctness testing (by default) +* P : performance testing + ### --driver [str] * linalg: test the single op in linalg dialect * mlir: upload a mlir file and run @@ -38,11 +53,25 @@ python -m benchgc [OPTIONS] --driver [DRIVER] --case [CASE] * if driver=pattern, please provide the pre-defined pattern name, such as mlp here * if driver is a dialect name, please provide the detail op name to start a single op test +### --entry [str] +* default : "entry" +* the entry name of the kernel of input mlir or generated mlir + ### --seed [int] * set the seed to generate the test data and reprodce the test ### --verbose [int] -* set the verbose level +* set the verbose level, default : 0 +* 0 : NO_VERBOSE +* 1 : MODULE_VERBOSE, print the module will be executed +* 2 : ARG_VERBOSE, + print arg information +* 3 : COMPARE_VERBOSE, + print threshold for comparison +* 4 : ERROR_OUTPUT_VERBOSE, + print all error data points if failed +* 5 : OUTPUT_VERBOSE, + print all result including passed tensor +* 6 : INPUT_VERBOSE, + print input torch tensors + +### --ir_printing (action=store_true) +* Print the ir during the pass-pipeline ### --md index:SHAPExTYPE * Describe the shape and data type for argument @@ -97,7 +126,28 @@ module { | Norm check | N | threshold | | Benchdnn driver | D | driver_name:dtype:case | +## Bench Options +### --bench_kind [str] +* py : use the MLIR Python API to invoke the kernel and use Python to calculate the time cost +* wrapper : modify MLIR by wrapping the kernel into a new method and calling the `nanoTime()` method before and after calling the kernel. Finally, calculate the difference as the time cost + +### --warm_up [int] +* warm-up times of the execution + +### --repeat [int] +* repeat times of the execution + +## Pattern Options +Each pattern has its own unique options. +### mlp +* `--batch_size`: the input +* `--hidden_size_list`: hidden_sizes of mlp, example: 32x16x64 +* `--has_bias`: if the matmul op has bias, example: 1x0 +* `--act_type`: choices=["noop", "relu"] +* `--dtype`: choices=["bf16", "f32"] + ## Example +### Correctness testing example ``` # single add op test # using the same data filling / compare strategy as the benchdnn primitive driver if not set @@ -254,4 +304,126 @@ p2p check: threshold: 0.0000000 (1, 0): ref: 25.1690636 res: 25.1690636 abs_diff: 0.0000000 rel_diff: 0.0000000 (1, 1): ref: -7.8600063 res: -7.8600044 abs_diff: 0.0000019 rel_diff: 0.0000002 FAIL: linalg.matmul_transpose_b +``` + +### Perf testing example +* single op example +``` +python3 -m benchgc --verbose 1 --mode P --driver linalg --case add --md 0:4x5xf32 --md 1:4x5xf32 --md 2:4x5xf32 + +module { + func.func @entry(%arg0: tensor<4x5xf32>, %arg1: tensor<4x5xf32>) -> tensor<4x5xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<4x5xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<4x5xf32>) -> tensor<4x5xf32> + %2 = linalg.add ins(%arg0, %arg1 : tensor<4x5xf32>, tensor<4x5xf32>) outs(%1 : tensor<4x5xf32>) -> tensor<4x5xf32> + return %2 : tensor<4x5xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "linalg", + "case": "add", + "md": [ + "0:4x5xf32", + "1:4x5xf32", + "2:4x5xf32" + ], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "ir_printing": false, + "cast": "cast_signed", + "dimension": null, + "dimensions": null, + "dilations": null, + "strides": null, + "bench_kind": "py", + "warm_up": 100, + "repeat": 100 + }, + "compile_cost(ms)": 37.72595152258873, + "execute_cost(ms)": 0.00022314488887786865 +} +``` + +* mlir example +``` +python3 -m benchgc --mode P --verbose 1 --driver mlir --case=./test.mlir --bench_kind wrapper --warm_up 50 --repeat 200 +\module { + func.func @entry(%arg0: tensor<512x128xf32>) -> tensor<512x128xf32> attributes {llvm.emit_c_interface} { + %cst = arith.constant 0.000000e+00 : f32 + %0 = tensor.empty() : tensor<512x128xf32> + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<512x128xf32>) -> tensor<512x128xf32> + %2 = linalg.abs ins(%arg0 : tensor<512x128xf32>) outs(%1 : tensor<512x128xf32>) -> tensor<512x128xf32> + return %2 : tensor<512x128xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "mlir", + "case": "/home/xurui/gc_v2/test.mlir", + "md": [], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "ir_printing": false, + "bench_kind": "wrapper", + "warm_up": 50, + "repeat": 200 + }, + "compile_cost(ms)": 70.6995539367199, + "execute_cost(ms)": 0.029325044999999984 +} +``` +* mlp example +``` +python3 -m benchgc --verbose 1 --mode P --driver pattern --case mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=0x0 --act_type=noop --dtype=f32 + +module { + func.func @entry(%arg0: tensor<32x32xf32>, %arg1: tensor<32x16xf32>, %arg2: tensor<16x64xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} { + %0 = tensor.empty() : tensor<32x16xf32> + %1 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x16xf32>) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> + %2 = tensor.empty() : tensor<32x64xf32> + %3 = linalg.matmul {cast = #linalg.type_fn} ins(%1, %arg2 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%2 : tensor<32x64xf32>) -> tensor<32x64xf32> + return %3 : tensor<32x64xf32> + } +} + +===========bench result=========== +{ + "args": { + "mode": "P", + "driver": "pattern", + "case": "mlp", + "md": [], + "fill": [], + "cmp": [], + "seed": 0, + "verbose": 1, + "entry": "entry", + "ir_printing": false, + "bench_kind": "py", + "warm_up": 100, + "repeat": 100, + "batch_size": 32, + "hidden_size_list": "32x16x64", + "has_bias": "0x0", + "act_type": "noop", + "dtype": "f32" + }, + "compile_cost(ms)": 109.86808314919472, + "execute_cost(ms)": 0.02944003790616989 +} + ``` \ No newline at end of file diff --git a/test/benchgc/src/benchgc/__main__.py b/test/benchgc/src/benchgc/__main__.py index 1470183a6..283fe9dc6 100644 --- a/test/benchgc/src/benchgc/__main__.py +++ b/test/benchgc/src/benchgc/__main__.py @@ -16,12 +16,12 @@ import argparse +import json import sys from typing import Dict, List import benchgc.mlir.util import benchgc.util -import gc_mlir.ir import runner import torch from benchgc.arg import ( @@ -31,11 +31,23 @@ set_default_fill, ) from benchgc.arg.arg import Arg +from benchgc.bench import mlir_wrapper_bench, py_timeit_bench from benchgc.mlir.arg import get_mlir_args +from benchgc.pattern import get_pattern_clz +from gc_mlir import ir from gc_mlir.graph_compiler import GraphCompiler -try: - parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") + +def add_common_options(parser: argparse.ArgumentParser): + """common options for benchgc""" + parser.add_argument( + "--mode", + required=False, + help="specify the test mode, C for correctness testing, P for performance testing", + choices=["C", "P"], + default="C", + type=str, + ) parser.add_argument( "--driver", required=False, @@ -97,176 +109,268 @@ benchgc.util.INPUT_VERBOSE, ], ) - parser.add_argument( - "--cast", - required=False, - default="cast_signed", - help="define attribute supported by linalg op such as matmul_transpose_b", - choices=["cast_signed", "cast_unsigned"], - type=str, - ) - # single dimension index - # linalg.softmax parser.add_argument( - "--dimension", + "--entry", required=False, - default=None, - help="define the dimension attribute in linalg op", - type=int, + default="entry", + help="the entry func name of a mlir", + type=str, ) - # multiple dimensions array - # linalg.broadcast / linalg.reduce parser.add_argument( - "--dimensions", - required=False, - default=None, - action="append", - help="define the dimensions attribute in linalg op", - type=int, + "--ir_printing", + action="store_true", + help="if we need print the ir during the pass-pipeline", ) - parser.add_argument( - "--dilations", - required=False, - default=None, - action="append", - help="define the dilations attribute in linalg op", - type=int, - ) - - parser.add_argument( - "--strides", - required=False, - default=None, - action="append", - help="define the strides attribute in linalg op", - type=int, - ) - flags = parser.parse_args() - benchgc.util.set_seed(flags.seed) + if parser.parse_known_args()[0].driver == "linalg": + parser.add_argument( + "--cast", + required=False, + default="cast_signed", + help="define attribute supported by linalg op such as matmul_transpose_b", + choices=["cast_signed", "cast_unsigned"], + type=str, + ) + + # single dimension index + # linalg.softmax + parser.add_argument( + "--dimension", + required=False, + default=None, + help="define the dimension attribute in linalg op", + type=int, + ) + + # multiple dimensions array + # linalg.broadcast / linalg.reduce + parser.add_argument( + "--dimensions", + required=False, + default=None, + action="append", + help="define the dimensions attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--dilations", + required=False, + default=None, + action="append", + help="define the dilations attribute in linalg op", + type=int, + ) + + parser.add_argument( + "--strides", + required=False, + default=None, + action="append", + help="define the strides attribute in linalg op", + type=int, + ) + + +def add_bench_options(parser: argparse.ArgumentParser): + """add options for bench mode""" + if parser.parse_known_args()[0].mode == "P": + parser.add_argument( + "--bench_kind", type=str, choices=["py", "wrapper"], default="py" + ) + parser.add_argument("--warm_up", type=int, default=100) + parser.add_argument("--repeat", type=int, default=100) + + + +def add_pattern_options(parser: argparse.ArgumentParser): + """add options for each pattern""" + if parser.parse_known_args()[0].driver == "pattern": + pattern_name = parser.parse_known_args()[0].case + get_pattern_clz(pattern_name).add_args(parser) + + +def get_module_and_args(flags): + args: List[Arg] = [] + if flags.driver in ["mlir", "pattern"]: + # we need to find all args by reading the entry function + with ir.Context() as ctx: + if flags.driver == "mlir": + with open(flags.case, "r") as mlir_file: + module = ir.Module.parse(mlir_file.read()) + elif flags.driver == "pattern": + pattern_clz = get_pattern_clz(flags.case) + module = pattern_clz(ctx, flags).ir_module + + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) + idx: int = 0 + # FIXME: only support RankTensorType now + for i in entry.type.inputs: + args.append(Arg(idx)) + args[-1].dtype = str(i.element_type) + args[-1].shape = list(i.shape) + args[-1].set_scalar() + idx += 1 + + for o in entry.type.results: + args.append(Arg(idx)) + args[-1].dtype = str(o.element_type) + args[-1].shape = list(o.shape) + args[-1].set_scalar() + idx += 1 + + elif flags.driver in ["linalg"]: + # all arg shape/dt should be provided in single op test + for i in range(len(flags.md)): + args.append(Arg(i)) + + for md in flags.md: + colon = md.find(":") + if colon == -1: + raise Exception("Wrong md format: %s", md) + idx = int(md[:colon]) + args[idx].set_md(md[colon + 1 :]) + + from .linalg import mlir_op + + mlir_func = mlir_op[flags.case] + module = mlir_func(flags, args) + else: + raise Exception(f"unsupported driver {flags.driver}") + for fill in flags.fill: + colon = fill.find(":") + if colon == -1: + raise Exception("Wrong fill format: %s", fill) + idx = int(fill[:colon]) + args[idx].set_fill(fill[colon + 1 :]) -except argparse.ArgumentError: - sys.stderr.write("Argument parse failed\n") - sys.exit(1) - -args: List[Arg] = [] - -if flags.driver == "mlir": - # we need to find all args by reading the entry function - with open(flags.case, "r") as mlir_file: - with gc_mlir.ir.Context() as ctx: - module = gc_mlir.ir.Module.parse(mlir_file.read()) - entry = benchgc.mlir.util.get_entry(module) - idx: int = 0 - # FIXME: only support RankTensorType now - for i in entry.type.inputs: - args.append(Arg(idx)) - args[-1].dtype = str(i.element_type) - args[-1].shape = list(i.shape) - args[-1].set_scalar() - idx += 1 - - for o in entry.type.results: - args.append(Arg(idx)) - args[-1].dtype = str(o.element_type) - args[-1].shape = list(o.shape) - args[-1].set_scalar() - idx += 1 -elif flags.driver in ["linalg"]: - # all arg shape/dt should be provided in single op test - for i in range(len(flags.md)): - args.append(Arg(i)) - - for md in flags.md: - colon = md.find(":") + for cmp in flags.cmp: + colon = cmp.find(":") if colon == -1: - raise Exception("Wrong md format: %s", md) - idx = int(md[:colon]) - args[idx].set_md(md[colon + 1 :]) - - from .linalg import mlir_op - - mlir_func = mlir_op[flags.case] - module = mlir_func(flags, args) -else: - raise Exception(f"unsupported driver {flags.driver}") - -for fill in flags.fill: - colon = fill.find(":") - if colon == -1: - raise Exception("Wrong fill format: %s", fill) - idx = int(fill[:colon]) - args[idx].set_fill(fill[colon + 1 :]) - -for cmp in flags.cmp: - colon = cmp.find(":") - if colon == -1: - raise Exception("Wrong cmp format: %s", cmp) - idx = int(cmp[:colon]) - args[idx].set_cmp(cmp[colon + 1 :]) - -entry = benchgc.mlir.util.get_entry(module) - -for i, arg in enumerate(args): - # use zero filling if the arg is return value - set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) - set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) - -for arg in args: - arg.print_verbose(flags.verbose) - -if flags.verbose >= benchgc.util.MODULE_VERBOSE: - print(module) - -ref_args: List[torch.Tensor] = [] -gc_args: List[torch.Tensor | int] = [] -ref_tensors: Dict[str, torch.Tensor] = {} -gc_tensors: Dict[str, torch.Tensor] = {} - -for i in range(len(args)): - tensor = fill_tensor(flags, args[i], i) - gc_tensors["%arg" + str(i)] = tensor - ref_tensors["%arg" + str(i)] = tensor.clone() - ref_args.append(ref_tensors["%arg" + str(i)]) - if args[i].scalar: - gc_args.append(tensor.data_ptr()) + raise Exception("Wrong cmp format: %s", cmp) + idx = int(cmp[:colon]) + args[idx].set_cmp(cmp[colon + 1 :]) + + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) + + for i, arg in enumerate(args): + # use zero filling if the arg is return value + set_default_fill(flags, arg, args, i >= len(entry.type.inputs)) + set_default_compare(flags, arg, args, i >= len(entry.type.inputs)) + + for arg in args: + arg.print_verbose(flags.verbose) + + if flags.verbose >= benchgc.util.MODULE_VERBOSE: + print(module) + return module, args + + +def correctness_testing(flags, module, args): + ref_args: List[torch.Tensor] = [] + gc_args: List[torch.Tensor | int] = [] + ref_tensors: Dict[str, torch.Tensor] = {} + gc_tensors: Dict[str, torch.Tensor] = {} + + for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + ref_tensors["%arg" + str(i)] = tensor.clone() + ref_args.append(ref_tensors["%arg" + str(i)]) + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + entry = benchgc.mlir.util.get_kernel_func_from_module(module, flags.entry) + # ref_out contains return value of the entry + ref_out = runner.ref_run(entry, ref_tensors) + + # we need to swap the result into the args if some arg is the return value + if ref_out is not None: + for i in range(len(ref_out)): + ref_args[0 - i - 1] = ref_out[0 - i - 1] + + mlir_args = get_mlir_args(gc_args) + passes = "any(gc-cpu-pipeline)" + + with module.context as ctx: + if flags.ir_printing: + ctx.enable_multithreading(False) + compiler = GraphCompiler(passes) + engine = compiler.compile_and_jit(module, flags.ir_printing) + engine.invoke(flags.entry, *mlir_args) + + fail, mistrust = False, False + for i in range(len(args)): + # gc_arg contains address for scalar value + # we need to find result by arg name + res = compare_tensor( + args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose + ) + fail = fail or (not res[0]) + if res[1] is not None: + mistrust = mistrust | res[1] + if fail: + print(f"FAIL: {flags.driver}.{flags.case}") + sys.exit(1) + elif mistrust: + print(f"MISTRUST: {flags.driver}.{flags.case}") else: - gc_args.append(tensor) - - -# ref_out contains return value of the entry -ref_out = runner.ref_run(entry, ref_tensors) - -# we need to swap the result into the args if some arg is the return value -if ref_out is not None: - for i in range(len(ref_out)): - ref_args[0 - i - 1] = ref_out[0 - i - 1] - -mlir_args = get_mlir_args(gc_args) -passes = "any(gc-cpu-pipeline)" - -with module.context: - compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module) - engine.invoke("entry", *mlir_args) - -fail, mistrust = False, False -for i in range(len(args)): - # gc_arg contains address for scalar value - # we need to find result by arg name - res = compare_tensor( - args[i], ref_args[i], gc_tensors["%arg" + str(i)], flags.verbose - ) - fail = fail or (not res[0]) - if res[1] is not None: - mistrust = mistrust | res[1] -if fail: - print(f"FAIL: {flags.driver}.{flags.case}") - sys.exit(1) -elif mistrust: - print(f"MISTRUST: {flags.driver}.{flags.case}") -else: - print(f"PASSED: {flags.driver}.{flags.case}") + print(f"PASSED: {flags.driver}.{flags.case}") + + +def performance_testing(flags, module, args): + gc_args: List[torch.Tensor | int] = [] + gc_tensors: Dict[str, torch.Tensor] = {} + for i in range(len(args)): + tensor = fill_tensor(flags, args[i], i) + gc_tensors["%arg" + str(i)] = tensor + if args[i].scalar: + gc_args.append(tensor.data_ptr()) + else: + gc_args.append(tensor) + + mlir_args = get_mlir_args(gc_args) + with module.context as ctx, ir.Location.unknown(): + if flags.ir_printing: + ctx.enable_multithreading(False) + bench_kind = py_timeit_bench if flags.bench_kind == "py" else mlir_wrapper_bench + execute_cost, compile_cost = bench_kind( + module, + flags.entry, + "any(gc-cpu-pipeline)", + mlir_args, + flags.ir_printing, + flags.repeat, + flags.warm_up, + ) + print("===========bench result===========") + json_res = json.dumps( + { + "args": vars(flags), + "compile_cost(ms)": compile_cost, + "execute_cost(ms)": execute_cost, + }, + indent=4, + ) + print(json_res) + + +if __name__ == "__main__": + arg_parser = argparse.ArgumentParser(prog="benchmark tool for graph compiler") + add_common_options(arg_parser) + add_bench_options(arg_parser) + add_pattern_options(arg_parser) + flags = arg_parser.parse_args() + benchgc.util.set_seed(flags.seed) + ir_module, module_args = get_module_and_args(flags) + if flags.mode == "C": + correctness_testing(flags, ir_module, module_args) + elif flags.mode == "P": + performance_testing(flags, ir_module, module_args) + else: + pass diff --git a/test/benchgc/src/benchgc/arg/__init__.py b/test/benchgc/src/benchgc/arg/__init__.py index a2134af2d..73c879e7b 100644 --- a/test/benchgc/src/benchgc/arg/__init__.py +++ b/test/benchgc/src/benchgc/arg/__init__.py @@ -27,6 +27,7 @@ import benchgc.util import torch from benchgc.arg.arg import Arg +from benchgc.pattern import get_pattern_clz onednn_module = { "binary": binary, @@ -53,6 +54,9 @@ def set_default_fill( if flags.driver + "." + flags.case in module.op: module.default_fill(flags, arg, arglist) return + elif flags.driver == "pattern": + get_pattern_clz(flags.case).default_fill(flags, arg, arglist) + return # use N(0, 1) as default arg.fill_type = "N" arg.fill_param = ["0", "1"] @@ -69,11 +73,14 @@ def set_default_compare( if flags.driver + "." + flags.case in module.op: module.default_compare(flags, arg, arglist) return + elif flags.driver == "pattern": + get_pattern_clz(flags.case).default_compare(flags, arg, arglist) + return dtype: torch.dtype = benchgc.util.get_dtype(arg.dtype) arg.cmp_type = "P" if dtype.is_floating_point: - arg.cmp_param = [str(torch.finfo(dtype).eps)] + arg.cmp_param = [str(1e-05)] else: arg.cmp_param = ["0"] if is_return: diff --git a/test/benchgc/src/benchgc/arith/__init__.py b/test/benchgc/src/benchgc/arith/__init__.py index a5f942a72..42d6dd0aa 100644 --- a/test/benchgc/src/benchgc/arith/__init__.py +++ b/test/benchgc/src/benchgc/arith/__init__.py @@ -18,21 +18,19 @@ import importlib from typing import Callable, Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[ - str, Callable[[argparse.Namespace, List[Arg], List[Arg]], gc_mlir.ir.Module] -] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg], List[Arg]], ir.Module]] = {} for dri in ["basic"]: mod = importlib.import_module(f"benchgc.arith.{dri}") diff --git a/test/benchgc/src/benchgc/arith/basic.py b/test/benchgc/src/benchgc/arith/basic.py index 7e4b17467..2da0aa022 100644 --- a/test/benchgc/src/benchgc/arith/basic.py +++ b/test/benchgc/src/benchgc/arith/basic.py @@ -18,20 +18,20 @@ import benchgc.util import gc_mlir._mlir_libs._mlir.ir -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_constant( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: value = op.attributes["value"] - if isinstance(value, gc_mlir._mlir_libs._mlir.ir.FloatAttr): + if isinstance(value, ir.FloatAttr): return ( torch.full(size=tuple(), fill_value=value.__float__(), dtype=torch.float), ) - elif isinstance(value, gc_mlir._mlir_libs._mlir.ir.DenseFPElementsAttr): + elif isinstance(value, ir.DenseFPElementsAttr): if value.is_splat: return ( torch.full( @@ -47,12 +47,12 @@ def ref_constant( def ref_mulf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (var[cache.opr[0]] * var[cache.opr[1]],) def ref_addf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (var[cache.opr[0]] + var[cache.opr[1]],) diff --git a/tools/bench.py b/test/benchgc/src/benchgc/bench.py similarity index 99% rename from tools/bench.py rename to test/benchgc/src/benchgc/bench.py index 7cf06d409..483ff0023 100644 --- a/tools/bench.py +++ b/test/benchgc/src/benchgc/bench.py @@ -18,16 +18,16 @@ import ctypes import random import timeit -from typing import List, Sequence, Tuple +from typing import List, Tuple import numpy as np -from gc_mlir import ir, runtime -from gc_mlir.graph_compiler import GraphCompiler -from utils import ( +from benchgc.mlir.util import ( emit_benchmark_wrapped_main_func, emit_nano_time, get_kernel_func_from_module, ) +from gc_mlir import ir, runtime +from gc_mlir.graph_compiler import GraphCompiler def py_timeit_bench( @@ -85,6 +85,7 @@ def mlir_wrapper_bench( ) total_time = 0 ns_to_ms_scale = 1e-6 + def run(engine_invoke, bench_func_name, *mlir_args): engine_invoke(bench_func_name, *mlir_args) diff --git a/test/benchgc/src/benchgc/linalg/__init__.py b/test/benchgc/src/benchgc/linalg/__init__.py index 331bd75dd..e75068c9a 100644 --- a/test/benchgc/src/benchgc/linalg/__init__.py +++ b/test/benchgc/src/benchgc/linalg/__init__.py @@ -18,19 +18,19 @@ import importlib from typing import Callable, Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], gc_mlir.ir.Module]] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, List[Arg]], ir.Module]] = {} for dri in [ "binary", diff --git a/test/benchgc/src/benchgc/linalg/binary.py b/test/benchgc/src/benchgc/linalg/binary.py index ed5d280a3..66f3a7abe 100644 --- a/test/benchgc/src/benchgc/linalg/binary.py +++ b/test/benchgc/src/benchgc/linalg/binary.py @@ -17,22 +17,23 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_add( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.add(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -42,13 +43,14 @@ def mlir_add(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_powf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.pow(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -58,13 +60,14 @@ def mlir_powf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_div( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.div(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -74,13 +77,14 @@ def mlir_div(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.max(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -90,13 +94,14 @@ def mlir_max(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.min(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -106,13 +111,14 @@ def mlir_min(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_mul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mul(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -122,13 +128,14 @@ def mlir_mul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_sub( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.sub(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_sub(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/conv.py b/test/benchgc/src/benchgc/linalg/conv.py index c8fc38efb..f1f8d7f97 100644 --- a/test/benchgc/src/benchgc/linalg/conv.py +++ b/test/benchgc/src/benchgc/linalg/conv.py @@ -17,19 +17,19 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_conv_1d_ncw_fcw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv1d( var[cache.opr[0]], @@ -40,10 +40,9 @@ def ref_conv_1d_ncw_fcw( ) -def mlir_conv_1d_ncw_fcw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -59,10 +58,10 @@ def mlir_conv_1d_ncw_fcw( def ref_conv_1d_nwc_wcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # src: nwc -> ncw # wei: wcf -> fcw @@ -80,10 +79,9 @@ def ref_conv_1d_nwc_wcf( ) -def mlir_conv_1d_nwc_wcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_nwc_wcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -99,10 +97,10 @@ def mlir_conv_1d_nwc_wcf( def ref_conv_1d_ncw_fcw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv1d( var[cache.opr[0]], @@ -113,10 +111,9 @@ def ref_conv_1d_ncw_fcw( ) -def mlir_conv_1d_ncw_fcw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_1d_ncw_fcw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -132,7 +129,7 @@ def mlir_conv_1d_ncw_fcw( def ref_conv_1d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv1d( @@ -144,8 +141,9 @@ def ref_conv_1d( ) -def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -159,10 +157,10 @@ def mlir_conv_1d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_conv_2d_nchw_fchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]], @@ -173,10 +171,9 @@ def ref_conv_2d_nchw_fchw( ) -def mlir_conv_2d_nchw_fchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nchw_fchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -192,10 +189,10 @@ def mlir_conv_2d_nchw_fchw( def ref_conv_2d_ngchw_fgchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] wei = var[cache.opr[1]] @@ -221,10 +218,9 @@ def ref_conv_2d_ngchw_fgchw( ) # split group axis from output channel -def mlir_conv_2d_ngchw_fgchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_ngchw_fgchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -240,10 +236,10 @@ def mlir_conv_2d_ngchw_fgchw( def ref_conv_2d_ngchw_gfchw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] wei = var[cache.opr[1]] @@ -267,10 +263,9 @@ def ref_conv_2d_ngchw_gfchw( ) # split group axis from output channel -def mlir_conv_2d_ngchw_gfchw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_ngchw_gfchw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -286,10 +281,10 @@ def mlir_conv_2d_ngchw_gfchw( def ref_conv_2d_nhwc_fhwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]].permute([0, 3, 1, 2]), @@ -302,10 +297,9 @@ def ref_conv_2d_nhwc_fhwc( ) -def mlir_conv_2d_nhwc_fhwc( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nhwc_fhwc(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -321,10 +315,10 @@ def mlir_conv_2d_nhwc_fhwc( def ref_conv_2d_nhwc_hwcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv2d( var[cache.opr[0]].permute([0, 3, 1, 2]), @@ -337,10 +331,9 @@ def ref_conv_2d_nhwc_hwcf( ) -def mlir_conv_2d_nhwc_hwcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_2d_nhwc_hwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -356,7 +349,7 @@ def mlir_conv_2d_nhwc_hwcf( def ref_conv_2d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv2d( @@ -368,8 +361,9 @@ def ref_conv_2d( ) -def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -383,10 +377,10 @@ def mlir_conv_2d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_conv_3d_ncdhw_fcdhw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv3d( var[cache.opr[0]], @@ -397,10 +391,9 @@ def ref_conv_3d_ncdhw_fcdhw( ) -def mlir_conv_3d_ncdhw_fcdhw( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_3d_ncdhw_fcdhw(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -416,10 +409,10 @@ def mlir_conv_3d_ncdhw_fcdhw( def ref_conv_3d_ndhwc_dhwcf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.conv3d( var[cache.opr[0]].permute([0, 4, 1, 2, 3]), @@ -432,10 +425,9 @@ def ref_conv_3d_ndhwc_dhwcf( ) -def mlir_conv_3d_ndhwc_dhwcf( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_conv_3d_ndhwc_dhwcf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -451,7 +443,7 @@ def mlir_conv_3d_ndhwc_dhwcf( def ref_conv_3d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.conv3d( @@ -463,8 +455,9 @@ def ref_conv_3d( ) -def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -478,10 +471,10 @@ def mlir_conv_3d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Modul def ref_depthwise_conv_1d_ncw_cw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv1d( @@ -496,8 +489,9 @@ def ref_depthwise_conv_1d_ncw_cw( def mlir_depthwise_conv_1d_ncw_cw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -513,10 +507,10 @@ def mlir_depthwise_conv_1d_ncw_cw( def ref_depthwise_conv_1d_nwc_wc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv1d( @@ -533,8 +527,9 @@ def ref_depthwise_conv_1d_nwc_wc( def mlir_depthwise_conv_1d_nwc_wc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -550,10 +545,10 @@ def mlir_depthwise_conv_1d_nwc_wc( def ref_depthwise_conv_1d_nwc_wcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] src = var[cache.opr[0]] groups: int = src.shape[-1] wei = var[cache.opr[1]] @@ -575,8 +570,9 @@ def ref_depthwise_conv_1d_nwc_wcm( def mlir_depthwise_conv_1d_nwc_wcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -592,10 +588,10 @@ def mlir_depthwise_conv_1d_nwc_wcm( def ref_depthwise_conv_2d_nchw_chw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv2d( @@ -610,8 +606,9 @@ def ref_depthwise_conv_2d_nchw_chw( def mlir_depthwise_conv_2d_nchw_chw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -627,10 +624,10 @@ def mlir_depthwise_conv_2d_nchw_chw( def ref_depthwise_conv_2d_nhwc_hwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv2d( @@ -647,8 +644,9 @@ def ref_depthwise_conv_2d_nhwc_hwc( def mlir_depthwise_conv_2d_nhwc_hwc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -664,10 +662,10 @@ def mlir_depthwise_conv_2d_nhwc_hwc( def ref_depthwise_conv_2d_nhwc_hwcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] wei = var[cache.opr[1]] dst = ( @@ -692,8 +690,9 @@ def ref_depthwise_conv_2d_nhwc_hwcm( def mlir_depthwise_conv_2d_nhwc_hwcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -709,10 +708,10 @@ def mlir_depthwise_conv_2d_nhwc_hwcm( def ref_depthwise_conv_3d_ncdhw_cdhw( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[1] return ( torch.conv3d( @@ -727,8 +726,9 @@ def ref_depthwise_conv_3d_ncdhw_cdhw( def mlir_depthwise_conv_3d_ncdhw_cdhw( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -744,10 +744,10 @@ def mlir_depthwise_conv_3d_ncdhw_cdhw( def ref_depthwise_conv_3d_ndhwc_dhwc( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] return ( torch.conv3d( @@ -764,8 +764,9 @@ def ref_depthwise_conv_3d_ndhwc_dhwc( def mlir_depthwise_conv_3d_ndhwc_dhwc( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -781,10 +782,10 @@ def mlir_depthwise_conv_3d_ndhwc_dhwc( def ref_depthwise_conv_3d_ndhwc_dhwcm( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] groups: int = var[cache.opr[0]].shape[-1] wei = var[cache.opr[1]] dst = ( @@ -818,8 +819,9 @@ def ref_depthwise_conv_3d_ndhwc_dhwcm( def mlir_depthwise_conv_3d_ndhwc_dhwcm( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/eltwise.py b/test/benchgc/src/benchgc/linalg/eltwise.py index 7ae9b31b7..760fcb0a1 100644 --- a/test/benchgc/src/benchgc/linalg/eltwise.py +++ b/test/benchgc/src/benchgc/linalg/eltwise.py @@ -17,22 +17,23 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_abs( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.abs(var[cache.opr[0]]),) -def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.abs(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -40,13 +41,14 @@ def mlir_abs(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_ceil( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.ceil(var[cache.opr[0]]),) -def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.ceil(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -54,13 +56,14 @@ def mlir_ceil(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_floor( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.floor(var[cache.opr[0]]),) -def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.floor(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -68,21 +71,23 @@ def mlir_floor(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_erf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.erf(var[cache.opr[0]]),) -def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_erf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.erf(arg0, outs=[args[1].get_zero_op(ctx)])], ) -def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.log(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -90,13 +95,14 @@ def mlir_log(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_log( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.log(var[cache.opr[0]]),) -def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -104,19 +110,20 @@ def mlir_negf(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_negf( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.neg(var[cache.opr[0]]),) def ref_exp( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.exp(var[cache.opr[0]]),) -def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.negf(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -124,7 +131,7 @@ def mlir_exp(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_round( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # torch.round is following the priciple "round half to even" # we need another implementation @@ -133,8 +140,9 @@ def ref_round( return (v + torch.where(var[cache.opr[0]] - v >= 0.5, 1, 0),) -def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.round(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -142,13 +150,14 @@ def mlir_round(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_rsqrt( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.rsqrt(var[cache.opr[0]]),) -def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.rsqrt(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -156,13 +165,14 @@ def mlir_rsqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_sqrt( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.sqrt(var[cache.opr[0]]),) -def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.sqrt(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -170,13 +180,14 @@ def mlir_sqrt(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_square( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.square(var[cache.opr[0]]),) -def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.square(arg0, outs=[args[1].get_zero_op(ctx)])], @@ -184,13 +195,14 @@ def mlir_square(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_tanh( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.tanh(var[cache.opr[0]]),) -def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_tanh(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [linalg.tanh(arg0, outs=[args[1].get_zero_op(ctx)])], diff --git a/test/benchgc/src/benchgc/linalg/generic.py b/test/benchgc/src/benchgc/linalg/generic.py index 67228ab47..6cfded39d 100644 --- a/test/benchgc/src/benchgc/linalg/generic.py +++ b/test/benchgc/src/benchgc/linalg/generic.py @@ -18,14 +18,14 @@ import benchgc.runner import benchgc.util -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def generic_loop( cache: MLIRCache, - op: gc_mlir.ir.OpView, + op: ir.OpView, depth: int, iterspace: Dict[str, Tuple[int, int, int]], affine_from: List[str], @@ -42,7 +42,7 @@ def generic_loop( # region cache cache.next.append(MLIRCache()) - block: gc_mlir.ir.Block = op.regions[0].blocks[0] + block: ir.Block = op.regions[0].blocks[0] if len(cache.next[0].next) == 0: # region->block cache cache.next[0].next.append(MLIRCache()) @@ -96,7 +96,7 @@ def generic_loop( def ref_generic( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: affine_from: List[str] = [] affine_to: List[List[str]] = [] @@ -110,7 +110,7 @@ def ref_generic( # TODO: support affine expression iterspace: Dict[str, Tuple[int, int, int]] = {} - operands: List[gc_mlir.ir.OpOperand] = list(op.operands) + operands: List[ir.OpOperand] = list(op.operands) loop_var: Dict[str, torch.Tensor] = {} for d in affine_from: @@ -142,7 +142,7 @@ def ref_generic( def reduce_loop( cache: MLIRCache, - op: gc_mlir.ir.OpView, + op: ir.OpView, depth: int, in_shape: List[int], var: Dict[str, torch.Tensor], @@ -155,7 +155,7 @@ def reduce_loop( # we need to execute the block here # we will need to read the block argument name and save it into the cache - block: gc_mlir.ir.Block = op.regions[0].blocks[0] + block: ir.Block = op.regions[0].blocks[0] if len(cache.next) == 0: # region cache @@ -180,7 +180,7 @@ def reduce_loop( # perform the yield operation result_tensor[tuple(out_idx)] = res[0] else: - dimensions: gc_mlir.ir.DenseI64ArrayAttr = op.attributes["dimensions"] + dimensions: ir.DenseI64ArrayAttr = op.attributes["dimensions"] reduce_axis: bool = depth in list(dimensions) for i in range(in_shape[depth]): @@ -214,7 +214,7 @@ def reduce_loop( def ref_reduce( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # create the buffer for result tensors tensors[cache.res[0]] = tensors[cache.opr[-1]].clone() diff --git a/test/benchgc/src/benchgc/linalg/matmul.py b/test/benchgc/src/benchgc/linalg/matmul.py index 9efde9612..16ad0519c 100644 --- a/test/benchgc/src/benchgc/linalg/matmul.py +++ b/test/benchgc/src/benchgc/linalg/matmul.py @@ -17,23 +17,24 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType def ref_batch_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.matmul(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -43,15 +44,16 @@ def mlir_batch_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_batch_matmul_transpose_a( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.bmm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) def mlir_batch_matmul_transpose_a( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -61,15 +63,16 @@ def mlir_batch_matmul_transpose_a( def ref_batch_matmul_transpose_b( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.bmm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) def mlir_batch_matmul_transpose_b( flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -79,7 +82,7 @@ def mlir_batch_matmul_transpose_b( def ref_batch_matvec( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # pytorch does not support bmv return ( @@ -87,8 +90,9 @@ def ref_batch_matvec( ) -def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -98,7 +102,7 @@ def mlir_batch_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_batch_mmt4d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # [B, m, k, m0, k0] -> [B, m, m0, k, k0] _src = var[cache.opr[0]].permute([0, 1, 3, 2, 4]).contiguous() @@ -124,8 +128,9 @@ def ref_batch_mmt4d( return (dst.transpose(2, 3).contiguous(),) -def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -135,7 +140,7 @@ def mlir_batch_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.M def ref_batch_reduce_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.addbmm( @@ -148,10 +153,9 @@ def ref_batch_reduce_matmul( ) -def mlir_batch_reduce_matmul( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_batch_reduce_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -161,15 +165,16 @@ def mlir_batch_reduce_matmul( def ref_batch_vecmat( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), ) -def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -179,13 +184,14 @@ def mlir_batch_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir. def ref_dot( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.dot(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -195,13 +201,14 @@ def mlir_dot(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_matmul( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -213,15 +220,14 @@ def mlir_matmul(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_matmul_transpose_a( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]].transpose(-1, -2), var[cache.opr[1]]),) -def mlir_matmul_transpose_a( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_matmul_transpose_a(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -233,15 +239,14 @@ def mlir_matmul_transpose_a( def ref_matmul_transpose_b( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mm(var[cache.opr[0]], var[cache.opr[1]].transpose(-1, -2)),) -def mlir_matmul_transpose_b( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_matmul_transpose_b(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -253,13 +258,14 @@ def mlir_matmul_transpose_b( def ref_matvec( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.mv(var[cache.opr[0]], var[cache.opr[1]]),) -def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -269,7 +275,7 @@ def mlir_matvec(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module def ref_mmt4d( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # [m, k, m0, k0] -> [m, m0, k, k0] _src = var[cache.opr[0]].permute([0, 2, 1, 3]).contiguous() @@ -289,8 +295,9 @@ def ref_mmt4d( return (dst.transpose(1, 2).contiguous(),) -def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -300,15 +307,16 @@ def mlir_mmt4d(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_vecmat( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.matmul(var[cache.opr[0]].unsqueeze(-2), var[cache.opr[1]]).squeeze(-2), ) -def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_vecmat(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/misc.py b/test/benchgc/src/benchgc/linalg/misc.py index cf672956c..05f8ebbbe 100644 --- a/test/benchgc/src/benchgc/linalg/misc.py +++ b/test/benchgc/src/benchgc/linalg/misc.py @@ -19,12 +19,11 @@ from typing import Dict, List, Tuple import benchgc.util -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache -from gc_mlir._mlir_libs._mlir.ir import DenseI64ArrayAttr +from gc_mlir import ir from gc_mlir.dialects import linalg from gc_mlir.dialects.linalg.opdsl.lang.comprehension import TypeFnType @@ -32,20 +31,21 @@ # 1. use to reshape to match ndim # 2. perform broadcast def ref_broadcast( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: dst_shape: List[int] = op.results[0].type.shape tmp_shape = copy.copy(dst_shape) - dimensions: DenseI64ArrayAttr = op.attributes["dimensions"] + dimensions: ir.DenseI64ArrayAttr = op.attributes["dimensions"] for d in dimensions: tmp_shape[d] = 1 return (var[cache.opr[0]].reshape(tmp_shape).broadcast_to(dst_shape).contiguous(),) -def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ @@ -57,13 +57,14 @@ def mlir_broadcast(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Mod def ref_fill( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return (torch.full(tuple(op.results[0].type.shape), var[cache.opr[0]]),) -def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ @@ -75,7 +76,7 @@ def mlir_fill(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: def ref_copy( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( var[cache.opr[0]] @@ -84,9 +85,10 @@ def ref_copy( ) -def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_copy(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ diff --git a/test/benchgc/src/benchgc/linalg/pool.py b/test/benchgc/src/benchgc/linalg/pool.py index 9779256df..755e4c76a 100644 --- a/test/benchgc/src/benchgc/linalg/pool.py +++ b/test/benchgc/src/benchgc/linalg/pool.py @@ -17,19 +17,19 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_pooling_nchw_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]], @@ -40,10 +40,9 @@ def ref_pooling_nchw_max( ) -def mlir_pooling_nchw_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nchw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -59,10 +58,10 @@ def mlir_pooling_nchw_max( def ref_pooling_nchw_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool2d or lp_pool2d with p = 1 does not support dilation @@ -83,10 +82,9 @@ def ref_pooling_nchw_sum( ) -def mlir_pooling_nchw_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nchw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -102,10 +100,10 @@ def mlir_pooling_nchw_sum( def ref_pooling_ncw_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]], @@ -116,10 +114,9 @@ def ref_pooling_ncw_max( ) -def mlir_pooling_ncw_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ncw_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -135,10 +132,10 @@ def mlir_pooling_ncw_max( def ref_pooling_ncw_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool1d or lp_pool1d with p = 1 does not support dilation @@ -159,10 +156,9 @@ def ref_pooling_ncw_sum( ) -def mlir_pooling_ncw_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ncw_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -178,10 +174,10 @@ def mlir_pooling_ncw_sum( def ref_pooling_ndhwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool3d( var[cache.opr[0]].permute([0, -1, 1, 2, 3]), @@ -194,10 +190,9 @@ def ref_pooling_ndhwc_max( ) -def mlir_pooling_ndhwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ndhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -213,10 +208,10 @@ def mlir_pooling_ndhwc_max( def ref_pooling_ndhwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool3d or lp_pool3d with p = 1 does not support dilation @@ -239,10 +234,9 @@ def ref_pooling_ndhwc_sum( ) -def mlir_pooling_ndhwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_ndhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -258,10 +252,10 @@ def mlir_pooling_ndhwc_sum( def ref_pooling_nhwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]].permute([0, -1, 1, 2]), @@ -274,10 +268,9 @@ def ref_pooling_nhwc_max( ) -def mlir_pooling_nhwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -293,10 +286,10 @@ def mlir_pooling_nhwc_max( def ref_pooling_nhwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool2d or lp_pool2d with p = 1 does not support dilation @@ -319,10 +312,9 @@ def ref_pooling_nhwc_sum( ) -def mlir_pooling_nhwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -338,10 +330,10 @@ def mlir_pooling_nhwc_sum( def ref_pooling_nhwc_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool2d( var[cache.opr[0]].permute([0, -1, 1, 2]).neg(), @@ -355,10 +347,9 @@ def ref_pooling_nhwc_min( ) -def mlir_pooling_nhwc_min( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nhwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -374,10 +365,10 @@ def mlir_pooling_nhwc_min( def ref_pooling_nwc_max( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]].permute([0, -1, 1]), @@ -390,10 +381,9 @@ def ref_pooling_nwc_max( ) -def mlir_pooling_nwc_max( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_max(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -409,10 +399,10 @@ def mlir_pooling_nwc_max( def ref_pooling_nwc_min( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] return ( torch.max_pool1d( var[cache.opr[0]].permute([0, -1, 1]).neg(), @@ -426,10 +416,9 @@ def ref_pooling_nwc_min( ) -def mlir_pooling_nwc_min( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_min(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ @@ -445,10 +434,10 @@ def mlir_pooling_nwc_min( def ref_pooling_nwc_sum( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - strides: gc_mlir.ir.DenseIntElementsAttr = op.attributes["strides"] - dilations: gc_mlir.ir.DenseIntElementsAttr = op.attributes["dilations"] + strides: ir.DenseIntElementsAttr = op.attributes["strides"] + dilations: ir.DenseIntElementsAttr = op.attributes["dilations"] # pytorch does not support pooling on sum # avg_pool3d or lp_pool3d with p = 1 does not support dilation @@ -471,10 +460,9 @@ def ref_pooling_nwc_sum( ) -def mlir_pooling_nwc_sum( - flags: argparse.Namespace, args: List[Arg] -) -> gc_mlir.ir.Module: +def mlir_pooling_nwc_sum(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0], args[1]), (args[2],), lambda ctx, arg0, arg1: [ diff --git a/test/benchgc/src/benchgc/linalg/softmax.py b/test/benchgc/src/benchgc/linalg/softmax.py index 20ed39fcb..e56376404 100644 --- a/test/benchgc/src/benchgc/linalg/softmax.py +++ b/test/benchgc/src/benchgc/linalg/softmax.py @@ -17,23 +17,24 @@ import argparse from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.module import init_module from benchgc.mlir.util import MLIRCache +from gc_mlir import ir from gc_mlir.dialects import linalg def ref_softmax( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: - dimension: gc_mlir.ir.IntegerAttr = op.attributes["dimension"] + dimension: ir.IntegerAttr = op.attributes["dimension"] return (torch.softmax(var[cache.opr[0]], dimension.value),) -def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> gc_mlir.ir.Module: +def mlir_softmax(flags: argparse.Namespace, args: List[Arg]) -> ir.Module: return init_module( + flags.entry, (args[0],), (args[1],), lambda ctx, arg0: [ diff --git a/test/benchgc/src/benchgc/mlir/arg.py b/test/benchgc/src/benchgc/mlir/arg.py index 364b9d92c..2fb0a9871 100644 --- a/test/benchgc/src/benchgc/mlir/arg.py +++ b/test/benchgc/src/benchgc/mlir/arg.py @@ -21,9 +21,9 @@ import gc_mlir.dialects.arith import gc_mlir.dialects.linalg import gc_mlir.dialects.tensor -import gc_mlir.ir import torch from benchgc.mlir.util import dtype_to_ctype, str_to_mlir_dtype, str_to_mlir_typed_attr +from gc_mlir import ir # scalar should give a address @@ -95,8 +95,7 @@ def set_md(self, md: str): self.dtype = splited[-1] self.shape = [] - for dim in splited[:-1]: - self.shape.append(int(dim)) + self.shape = [int(x) for x in splited[:-1]] self.set_scalar() def set_scalar(self): @@ -118,24 +117,18 @@ def nelem(self) -> int: ret = ret * dim return ret - def get_mlir_type(self, ctx: gc_mlir.ir.Context) -> gc_mlir.ir.Type: + def get_mlir_type(self, ctx: ir.Context) -> ir.Type: if self.scalar: return str_to_mlir_dtype(ctx, self.dtype) else: - return gc_mlir.ir.RankedTensorType.get( + return ir.RankedTensorType.get( self.shape, str_to_mlir_dtype(ctx, self.dtype) ) - def get_ranked_tensor_type( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.ir.RankedTensorType: - return gc_mlir.ir.RankedTensorType.get( - self.shape, str_to_mlir_dtype(ctx, self.dtype) - ) + def get_ranked_tensor_type(self, ctx: ir.Context) -> ir.RankedTensorType: + return ir.RankedTensorType.get(self.shape, str_to_mlir_dtype(ctx, self.dtype)) - def get_constant_op( - self, ctx: gc_mlir.ir.Context, cst: Any - ) -> gc_mlir.dialects.tensor.OpView: + def get_constant_op(self, ctx: ir.Context, cst: Any) -> ir.OpView: zero = gc_mlir.dialects.arith.ConstantOp( value=str_to_mlir_typed_attr(ctx, self.dtype, cst), result=str_to_mlir_dtype(ctx, self.dtype), @@ -152,21 +145,17 @@ def get_constant_op( ], ) - def get_zero_op(self, ctx: gc_mlir.ir.Context) -> gc_mlir.dialects.tensor.OpView: + def get_zero_op(self, ctx: ir.Context) -> ir.OpView: return self.get_constant_op(ctx, 0) - def get_max_value_op( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.dialects.tensor.OpView: + def get_max_value_op(self, ctx: ir.Context) -> ir.OpView: dtype = benchgc.util.get_dtype(self.dtype) if dtype.is_floating_point: return self.get_constant_op(ctx, torch.finfo(dtype).max) else: return self.get_constant_op(ctx, torch.iinfo(dtype).max) - def get_min_value_op( - self, ctx: gc_mlir.ir.Context - ) -> gc_mlir.dialects.tensor.OpView: + def get_min_value_op(self, ctx: ir.Context) -> ir.OpView: dtype = benchgc.util.get_dtype(self.dtype) if dtype.is_floating_point: return self.get_constant_op(ctx, torch.finfo(dtype).min) diff --git a/test/benchgc/src/benchgc/mlir/module.py b/test/benchgc/src/benchgc/mlir/module.py index 806c9d8b7..69dfd9a90 100644 --- a/test/benchgc/src/benchgc/mlir/module.py +++ b/test/benchgc/src/benchgc/mlir/module.py @@ -16,33 +16,33 @@ from typing import Callable, List, Tuple -import gc_mlir.dialects.tensor -import gc_mlir.ir from benchgc.mlir.arg import MLIRArg +from gc_mlir import ir from gc_mlir.dialects import func def init_module( + entry_name: str, inputs: Tuple[MLIRArg, ...], outputs: Tuple[MLIRArg, ...], op_func: Callable[ - [gc_mlir.ir.Context, Tuple[gc_mlir.ir.BlockArgument, ...]], - List[gc_mlir.ir.OpResult], + [ir.Context, Tuple[ir.BlockArgument, ...]], + List[ir.OpResult], ], -) -> gc_mlir.ir.Module: - with gc_mlir.ir.Context() as ctx, gc_mlir.ir.Location.unknown(): - module = gc_mlir.ir.Module.create() - with gc_mlir.ir.InsertionPoint(module.body): +) -> ir.Module: + with ir.Context() as ctx, ir.Location.unknown(): + module = ir.Module.create() + with ir.InsertionPoint(module.body): f = func.FuncOp( - name="entry", - type=gc_mlir.ir.FunctionType.get( + name=entry_name, + type=ir.FunctionType.get( inputs=[x.get_mlir_type(ctx) for x in inputs], results=[x.get_mlir_type(ctx) for x in outputs], ), ) - f.attributes["llvm.emit_c_interface"] = gc_mlir.ir.UnitAttr.get() + f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with gc_mlir.ir.InsertionPoint(f.add_entry_block()): + with ir.InsertionPoint(f.add_entry_block()): block_args = f.entry_block.arguments func.ReturnOp(op_func(ctx, *block_args)) return module diff --git a/test/benchgc/src/benchgc/mlir/util.py b/test/benchgc/src/benchgc/mlir/util.py index 24169bca1..ee6de870b 100644 --- a/test/benchgc/src/benchgc/mlir/util.py +++ b/test/benchgc/src/benchgc/mlir/util.py @@ -17,19 +17,9 @@ import ctypes from typing import Any, List -import gc_mlir.ir import torch -from gc_mlir.dialects import func - -# only python 3.11 support -# from typing import Self - - -def get_entry(module: gc_mlir.ir.Module, entry: str = '"entry"') -> func.FuncOp: - for op in module.operation.opview.regions[0].blocks[0].operations: - if str(op.name) == entry: - return op - raise Exception(f"entry function {entry} is not found at the top level") +from gc_mlir import ir +from gc_mlir.dialects import arith, func, memref # calling python binding consumes a lot of time e.g. get_name() @@ -72,40 +62,87 @@ def dtype_to_ctype(dtype: torch.dtype): raise ValueError(f"Unsupported torch dtype: {dtype}") -def str_to_mlir_dtype(ctx: gc_mlir.ir.Context, dtype: str) -> gc_mlir.ir.Type: +def str_to_mlir_dtype(ctx: ir.Context, dtype: str) -> ir.Type: if dtype == "f32": - return gc_mlir.ir.F32Type.get(ctx) + return ir.F32Type.get(ctx) elif dtype == "f64": - return gc_mlir.ir.F64Type.get(ctx) + return ir.F64Type.get(ctx) elif dtype == "f16": - return gc_mlir.ir.F16Type.get(ctx) + return ir.F16Type.get(ctx) elif dtype == "bf16": - return gc_mlir.ir.BF16Type.get(ctx) + return ir.BF16Type.get(ctx) elif dtype == "u8": - return gc_mlir.ir.IntegerType.get_unsigned(8, ctx) + return ir.IntegerType.get_unsigned(8, ctx) elif dtype == "s8": - return gc_mlir.ir.IntegerType.get_signed(8, ctx) + return ir.IntegerType.get_signed(8, ctx) elif dtype == "boolean": - return gc_mlir.ir.IntegerType.get_unsigned(1, ctx) + return ir.IntegerType.get_unsigned(1, ctx) elif dtype == "f8_e4m3": - return gc_mlir.ir.Float8E4M3FNType.get(ctx) + return ir.Float8E4M3FNType.get(ctx) elif dtype == "f8_e5m2": - return gc_mlir.ir.Float8E5M2Type.get(ctx) + return ir.Float8E5M2Type.get(ctx) elif dtype == "s32": - return gc_mlir.ir.IntegerType.get_signed(32, ctx) + return ir.IntegerType.get_signed(32, ctx) else: raise Exception(f"data type not support: {dtype}") -def str_to_mlir_typed_attr( - ctx: gc_mlir.ir.Context, dtype: str, value: Any -) -> gc_mlir.ir.Attribute: +def str_to_mlir_typed_attr(ctx: ir.Context, dtype: str, value: Any) -> ir.Attribute: mlir_dtype = str_to_mlir_dtype(ctx, dtype) if dtype in ["f32", "f64", "bf16", "f16", "f8_e4m3", "f8_e5m2"]: - return gc_mlir.ir.FloatAttr.get(mlir_dtype, value) + return ir.FloatAttr.get(mlir_dtype, value) elif dtype in ["u8", "s8", "s32"]: - return gc_mlir.ir.IntegerAttr.get(mlir_dtype, value) + return ir.IntegerAttr.get(mlir_dtype, value) elif dtype == "boolean": - return gc_mlir.ir.BoolAttr.get(value) + return ir.BoolAttr.get(value) else: raise Exception(f"data type not support: {dtype}") + + +def emit_nano_time() -> func.FuncOp: + """Emit a nanoTime function that returns the current time in nanoseconds.""" + nanoTime = func.FuncOp( + "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" + ) + nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + return nanoTime + + +def emit_benchmark_wrapped_main_func( + kernel_func: func.FuncOp, timer_func: func.FuncOp +) -> func.FuncOp: + """Emit a wrapped main function that calls the kernel function and records the time taken.""" + memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) + wrapped_func_name = "wrapped_main" + assert wrapped_func_name != str( + kernel_func.name + ), "wrapped function name should be different from kernel function name" + wrapped_func = func.FuncOp( + wrapped_func_name, + ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), + visibility="public", + ) + wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(wrapped_func.add_entry_block()): + timer_buffer = wrapped_func.arguments[0] + start = func.CallOp(timer_func, []) + call_op = func.CallOp( + kernel_func, + list(wrapped_func.arguments[1:]), + ) + end = func.CallOp(timer_func, []) + time_taken = arith.SubIOp(end, start) + zero = arith.ConstantOp.create_index(0) + memref.StoreOp(time_taken, timer_buffer, [zero]) + func.ReturnOp(call_op.results) + return wrapped_func + + +def get_kernel_func_from_module( + module: ir.Module, func_name: str = "entry" +) -> func.FuncOp: + """Get the func op by the name from a module""" + for f in module.operation.regions[0].blocks[0].operations: + if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: + return f + raise ValueError("can not find the entry function") diff --git a/test/benchgc/src/benchgc/pattern/CMakeLists.txt b/test/benchgc/src/benchgc/pattern/CMakeLists.txt new file mode 100644 index 000000000..51683ac19 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/CMakeLists.txt @@ -0,0 +1,22 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + +file(GLOB PYTHON_SCRIPTS "*.py") +foreach(PY_SCRIPT ${PYTHON_SCRIPTS}) + configure_file(${PY_SCRIPT} ${CMAKE_BINARY_DIR}/test/benchgc/src/benchgc/pattern/ COPYONLY) +endforeach() \ No newline at end of file diff --git a/test/benchgc/src/benchgc/pattern/__init__.py b/test/benchgc/src/benchgc/pattern/__init__.py new file mode 100644 index 000000000..b0f5cbce0 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/__init__.py @@ -0,0 +1,26 @@ +################################################################################ +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################ + +from .base import Pattern +from .mlp import MLP + +__all__ = ["Pattern", "MLP", "get_pattern_clz"] + + +def get_pattern_clz(name: str): + """Function getting pattern class by name.""" + clz = {"mlp": MLP}[name] + return clz diff --git a/test/benchgc/src/benchgc/pattern/base.py b/test/benchgc/src/benchgc/pattern/base.py new file mode 100644 index 000000000..42527efc3 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/base.py @@ -0,0 +1,43 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +import argparse +from abc import ABC, abstractmethod + +from gc_mlir import ir + + +class Pattern(ABC): + """Abstract class for pattern.""" + + @staticmethod + @abstractmethod + def add_args(parser: argparse.ArgumentParser): + """Add arguments to parser""" + + @abstractmethod + def handle_args(self, args: argparse.Namespace): + """Get and handle the args""" + + def __init__(self, ctx: ir.Context, flags: argparse.Namespace): + self.main_entry = flags.entry + self.handle_args(flags) + self.ir_module = self.init_module(ctx) + + @abstractmethod + def init_module(self, ctx: ir.Context) -> ir.Module: + """Create MLIR moudule by args""" diff --git a/tools/drivers.py b/test/benchgc/src/benchgc/pattern/mlp.py similarity index 63% rename from tools/drivers.py rename to test/benchgc/src/benchgc/pattern/mlp.py index a9bdc95d0..1f663b66a 100644 --- a/tools/drivers.py +++ b/test/benchgc/src/benchgc/pattern/mlp.py @@ -16,91 +16,19 @@ ################################################################################ import argparse -from abc import ABC, abstractmethod +from re import U from typing import List -import numpy as np +from benchgc.arg.arg import Arg +from benchgc.mlir.util import str_to_mlir_dtype +from benchgc.util import to_bool_list, to_int_list from gc_mlir import ir from gc_mlir.dialects import arith, func, linalg, tensor -from gc_mlir.ir import BF16Type, FloatAttr -from utils import ( - STR_TO_MLIR_TYPE, - get_default_passes, - get_kernel_func_from_module, - make_mlir_ndarray, - to_bool_list, - to_int_list, -) +from .base import Pattern -class Driver(ABC): - """Abstract class for driver.""" - @staticmethod - @abstractmethod - def add_args(parser: argparse.ArgumentParser): - """Add arguments to parser""" - pass - - @abstractmethod - def handle_args(self, args: argparse.Namespace): - """Get and handle the args""" - pass - - def __init__(self, ctx: ir.Context, args: argparse.Namespace): - self.main_entry = "main_entry" - self.handle_args(args) - self.ir_module = self.init_module(ctx) - - @abstractmethod - def init_module(self, ctx: ir.Context) -> ir.Module: - """Create MLIR moudule by args""" - pass - - @abstractmethod - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - """Create numpy arg for entry function""" - pass - - def get_passes(self) -> str: - """Get pass pipeline""" - return get_default_passes() - - -class LoadMLIR(Driver): - @staticmethod - def add_args(parser: argparse.ArgumentParser): - parser.add_argument("--path", type=str, required=True) - parser.add_argument("--entry", type=str, default="main_entry") - - def handle_args(self, args: argparse.Namespace): - self.path = args.path - self.main_entry = args.entry - - def _get_mlir(self): - with open(self.path, "r") as file: - content = file.read() - return content - - def init_module(self, ctx: ir.Context) -> ir.Module: - module = ir.Module.parse(self._get_mlir(), ctx) - bench_func = get_kernel_func_from_module(module, self.main_entry) - bench_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - return module - - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) - np_args = [] - for arg in bench_func.arguments: - np_args.append(make_mlir_ndarray(arg.type)) - - if not disable_results_to_params: - for res in bench_func.type.results: - np_args.append(make_mlir_ndarray(res)) - - return np_args - -class MLP(Driver): +class MLP(Pattern): @staticmethod def add_args(parser: argparse.ArgumentParser): parser.add_argument("--batch_size", type=int, default=1) @@ -142,7 +70,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: with ctx, ir.Location.unknown(): layers = len(self.hidden_size_list) - 1 module = ir.Module.create() - dtype = STR_TO_MLIR_TYPE(self.dtype, ctx) + dtype = str_to_mlir_dtype(ctx, self.dtype) src = ir.RankedTensorType.get( [self.batch_size, self.hidden_size_list[0]], dtype ) @@ -177,9 +105,13 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ), ) f.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + with ir.InsertionPoint(f.add_entry_block()): data = f.entry_block.arguments[0] bias_idx = len(weights) + 1 + zero = arith.ConstantOp( + value=ir.FloatAttr.get(dtype, 0.0), result=dtype + ).result for i in range(layers): weight = f.entry_block.arguments[i + 1] if self.has_bias[i]: @@ -191,10 +123,10 @@ def init_module(self, ctx: ir.Context) -> ir.Module: self.batch_size, self.hidden_size_list[i + 1], ] - - data = linalg.matmul( - data, weight, outs=[tensor.EmptyOp(layer_out_shape, dtype)] + out = linalg.fill( + zero, outs=[tensor.EmptyOp(layer_out_shape, dtype)] ) + data = linalg.matmul(data, weight, outs=[out]) if bias: broadcast_bias = linalg.broadcast( bias, @@ -208,7 +140,7 @@ def init_module(self, ctx: ir.Context) -> ir.Module: ) if self.act_type == "relu": - element = FloatAttr.get(dtype, 0) + element = ir.FloatAttr.get(dtype, 0) tensor_type = ir.RankedTensorType.get( layer_out_shape, dtype ) @@ -220,13 +152,24 @@ def init_module(self, ctx: ir.Context) -> ir.Module: func.ReturnOp([data]) return module - def prepare_np_args(self, disable_results_to_params: False) -> List[np.ndarray]: - bench_func = get_kernel_func_from_module(self.ir_module, self.main_entry) - np_args = [] - for arg in bench_func.arguments: - np_args.append(make_mlir_ndarray(arg.type)) - - if not disable_results_to_params: - for res in bench_func.type.results: - np_args.append(make_mlir_ndarray(res)) - return np_args + def default_fill( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], + ): + arg.fill_type = "U" + arg.fill_param = ["0", "1"] + + def default_compare( + flags: argparse.Namespace, + arg: Arg, + arglist: List[Arg], + ): + arg.cmp_type = "P" + if arg.dtype == "f32": + arg.cmp_param = [str(1e-5)] + elif arg.dtype == "bf16": + arg.cmp_param = [str(5e-2)] + else: + raise Exception("Unsupported dtype for mlp pattern") + arg.cmp_param.append("100.0") diff --git a/test/benchgc/src/benchgc/pattern/util.py b/test/benchgc/src/benchgc/pattern/util.py new file mode 100644 index 000000000..62bf74ca5 --- /dev/null +++ b/test/benchgc/src/benchgc/pattern/util.py @@ -0,0 +1,48 @@ +################################################################################ +# Copyright (C) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + +from typing import List + + +def to_int_list(s: str) -> List[int]: + """ + Parsing the cmd for list of int values + + Args: + s (str): int values in cmd, example: 2x3x4 + + Returns: + List[int]: int values in list, example: [2, 3, 4] + """ + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_list(s: str) -> List[bool]: + """ + Parsing the cmd for list of bool values + + Args: + s (str): bools in cmd, example: 1x0x1 + + Returns: + List[bool]: bools in list, example: [True, False, True] + """ + if not s or len(s) == 0: + return [] + return [bool(int(i)) for i in s.strip().split("x")] diff --git a/test/benchgc/src/benchgc/runner.py b/test/benchgc/src/benchgc/runner.py index 80178baa8..1f4e18e37 100644 --- a/test/benchgc/src/benchgc/runner.py +++ b/test/benchgc/src/benchgc/runner.py @@ -19,16 +19,16 @@ import gc_mlir._mlir_libs import gc_mlir.dialects import gc_mlir.dialects.func -import gc_mlir.ir import torch from benchgc.arith import ref_op as arith_ref_op from benchgc.linalg import ref_op as linalg_ref_op from benchgc.mlir.util import MLIRCache from benchgc.tensor import ref_op as tensor_ref_op +from gc_mlir import ir def dfs_op( - cache: MLIRCache, op: gc_mlir.ir.OpView, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: dialect_call: str = str(op.name) @@ -65,7 +65,7 @@ def dfs_op( def dfs_region( - cache: MLIRCache, region: gc_mlir.ir.Region, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, region: ir.Region, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: build_cache = len(cache.next) == 0 for i in range(len(region.blocks)): @@ -82,7 +82,7 @@ def dfs_region( def dfs_block( - cache: MLIRCache, block: gc_mlir.ir.Block, tensors: Dict[str, torch.Tensor] + cache: MLIRCache, block: ir.Block, tensors: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: build_cache = len(cache.next) == 0 for i in range(len(block.operations)): diff --git a/test/benchgc/src/benchgc/tensor/__init__.py b/test/benchgc/src/benchgc/tensor/__init__.py index 2f8bc98a4..9bfd9b09d 100644 --- a/test/benchgc/src/benchgc/tensor/__init__.py +++ b/test/benchgc/src/benchgc/tensor/__init__.py @@ -18,21 +18,19 @@ import importlib from typing import Callable, Dict, Tuple -import gc_mlir.ir import torch from benchgc.arg import Arg from benchgc.mlir.util import MLIRCache +from gc_mlir import ir ref_op: Dict[ str, Callable[ - [MLIRCache, gc_mlir.ir.OpView, Dict[str, torch.Tensor]], + [MLIRCache, ir.OpView, Dict[str, torch.Tensor]], Tuple[torch.Tensor, ...], ], ] = {} -mlir_op: Dict[ - str, Callable[[argparse.Namespace, Dict[str, Arg]], gc_mlir.ir.Module] -] = {} +mlir_op: Dict[str, Callable[[argparse.Namespace, Dict[str, Arg]], ir.Module]] = {} for dri in ["basic", "shape"]: mod = importlib.import_module(f"benchgc.tensor.{dri}") diff --git a/test/benchgc/src/benchgc/tensor/basic.py b/test/benchgc/src/benchgc/tensor/basic.py index eb56aafbc..a424a7bb2 100644 --- a/test/benchgc/src/benchgc/tensor/basic.py +++ b/test/benchgc/src/benchgc/tensor/basic.py @@ -17,13 +17,13 @@ from typing import Dict, Tuple import benchgc.util -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_empty( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: return ( torch.zeros( diff --git a/test/benchgc/src/benchgc/tensor/shape.py b/test/benchgc/src/benchgc/tensor/shape.py index 18d9fbb2c..25fc20e53 100644 --- a/test/benchgc/src/benchgc/tensor/shape.py +++ b/test/benchgc/src/benchgc/tensor/shape.py @@ -16,16 +16,16 @@ from typing import Dict, List, Tuple -import gc_mlir.ir import torch from benchgc.mlir.util import MLIRCache +from gc_mlir import ir def ref_collapse_shape( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # permute axis and do reshape - reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + reassociation: ir.ArrayAttr = op.attributes["reassociation"] permutation: List[int] = [] shape: List[int] = [] for outdim in reassociation: @@ -43,10 +43,10 @@ def ref_collapse_shape( def ref_expand_shape( - cache: MLIRCache, op: gc_mlir.ir.OpView, var: Dict[str, torch.Tensor] + cache: MLIRCache, op: ir.OpView, var: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, ...]: # permute axis and do reshape - reassociation: gc_mlir.ir.ArrayAttr = op.attributes["reassociation"] + reassociation: ir.ArrayAttr = op.attributes["reassociation"] permutation: List[int] = [0] * len(op.result.type.shape) shape: List[int] = [] diff --git a/test/benchgc/src/benchgc/util.py b/test/benchgc/src/benchgc/util.py index de275f0fd..ae0b4228e 100644 --- a/test/benchgc/src/benchgc/util.py +++ b/test/benchgc/src/benchgc/util.py @@ -332,3 +332,33 @@ def p2p( def nelem(shape: List[int]) -> int: return reduce(operator.mul, shape) + + +def to_int_list(s: str) -> List[int]: + """ + Parsing the cmd for list of int values + + Args: + s (str): int values in cmd, example: 2x3x4 + + Returns: + List[int]: int values in list, example: [2, 3, 4] + """ + if not s or len(s) == 0: + return [] + return [int(i) for i in s.strip().split("x")] + + +def to_bool_list(s: str) -> List[bool]: + """ + Parsing the cmd for list of bool values + + Args: + s (str): bools in cmd, example: 1x0x1 + + Returns: + List[bool]: bools in list, example: [True, False, True] + """ + if not s or len(s) == 0: + return [] + return [bool(int(i)) for i in s.strip().split("x")] diff --git a/tools/README.md b/tools/README.md deleted file mode 100644 index 931a1b906..000000000 --- a/tools/README.md +++ /dev/null @@ -1,92 +0,0 @@ -# Python Tools -## Pre-requisites -### Enable python binding -* Enable MLIR python binding, [README](https://github.com/intel/graph-compiler/blob/main/python/README.md) -### Set env -* **PYTHONPATH**=*${BUILD_DIR}*/python_packages/gc_mlir_core -* **LD_PRELOAD**=path/to/libiomp5.so - - -## Bench -The tool has two different ways to calculate the time cost, and more experiments are needed to test which one is more stable and accurate. Currently, users can choose which way to use through options -* Use the MLIR Python API to invoke the kernel and use Python to calculate the time cost -* Modify MLIR by wrapping the kernel into a new method and calling the `nanoTime()` method before and after calling the kernel. Finally, calculate the difference as the time cost -``` - func.func private @nanoTime() -> i64 attributes {llvm.emit_c_interface} - func.func public @wrapped_main(%arg0: memref<1xi64>, %arg1: tensor<128x512xbf16>, %arg2: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { - %0 = call @nanoTime() : () -> i64 - %1 = call @main_entry(%arg1, %arg2) : (tensor<128x512xbf16>, tensor<512x256xbf16>) -> tensor<128x256xbf16> - %2 = call @nanoTime() : () -> i64 - %3 = arith.subi %2, %0 : i64 - %c0 = arith.constant 0 : index - memref.store %3, %arg0[%c0] : memref<1xi64> - return %1 : tensor<128x256xbf16> - } -} -``` - -### Examples: -``` -# simple version -python3 ./tools/main.py --driver=load_mlir --path=./tools/workloads/test.mlir - -# complex version -python3 ./tools/main.py --type=bench --bench_kind=py --driver=load_mlir --path=./tools/workloads/test.mlir --warm_up=200 --repeat=200 --print_ir --entry=main_entry -``` - -``` -# result example -===========bench result=========== -{ - "args": { - "type": "bench", - "driver": "load_mlir", - "path": "./tools/workloads/test.mlir", - "entry": "main_entry", - "bench_kind": "py", - "print_ir": false, - "warm_up": 20, - "repeat": 100 - }, - "compile_cost(ms)": 25.58841183781624, - "execute_cost(ms)": 1.7501823976635933 -} -``` - -### Common Options -* `--driver`: the pattern to bench, currently support `mlp` and `load_mlir` -* `--bench_kind`: `py` or `wrapper`, different evaluation implementation of the benchmark -* `--warm_up`: warm-up times of the execution -* `--repeat`: repeat times of the execution -* `--print_ir`: print the ir before execution -* `--disable_results_to_params`: do not use this when using the default pipeline (gc-cpu-pipeline) - -### Driver Specific Options -* load_mlir - * `--path`: the mlir file path - * `--entry`: the name of entry func -``` -python3 ./tools/main.py --driver=load_mlir --path=./tools/workloads/test.mlir -``` - - -* mlp - * `--batch_size`: the input - * `--hidden_size_list`: hidden_sizes of mlp, example: 32x16x64 - * `--has_bias`: if the matmul op has bias, example: 1x0 - * `--act_type`: choices=["noop", "relu", "sigmoid"] - * `--dtype`: choices=["bf16", "f32"] -``` -python3 ./tools/main.py --driver=mlp --batch_size=32 --hidden_size_list=32x16x64 --has_bias=0x0 --act_type=noop --dtype=f32 - -===========bench func name: main_entry =========== -module { - func.func @main_entry(%arg0: tensor<32x32xf32>, %arg1: tensor<32x16xf32>, %arg2: tensor<16x64xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} { - %0 = tensor.empty() : tensor<32x16xf32> - %1 = linalg.matmul {cast = #linalg.type_fn} ins(%arg0, %arg1 : tensor<32x32xf32>, tensor<32x16xf32>) outs(%0 : tensor<32x16xf32>) -> tensor<32x16xf32> - %2 = tensor.empty() : tensor<32x64xf32> - %3 = linalg.matmul {cast = #linalg.type_fn} ins(%1, %arg2 : tensor<32x16xf32>, tensor<16x64xf32>) outs(%2 : tensor<32x64xf32>) -> tensor<32x64xf32> - return %3 : tensor<32x64xf32> - } -} -``` \ No newline at end of file diff --git a/tools/example/simple_test.py b/tools/example/simple_test.py deleted file mode 100644 index 81baa9085..000000000 --- a/tools/example/simple_test.py +++ /dev/null @@ -1,74 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import os -import sys - -import numpy as np -from gc_mlir import ir -from gc_mlir.graph_compiler import GraphCompiler -from numpy.testing import assert_allclose - -project_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -if project_dir not in sys.path: - sys.path.insert(0, project_dir) - -import ml_dtypes -import torch -from utils import get_mlir_args - -# an example of simple validation -if __name__ == "__main__": - with ir.Context() as ctx: - ctx.enable_multithreading(False) - module = ir.Module.parse( - """ - module { - func.func @main_entry(%arg0: tensor<10x10xbf16>, %arg1: tensor<10x10xbf16>) -> tensor<10x10xbf16> attributes {llvm.emit_c_interface} { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<10x10xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<10x10xbf16>) -> tensor<10x10xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<10x10xbf16>, tensor<10x10xbf16>) outs(%1 : tensor<10x10xbf16>) -> tensor<10x10xbf16> - return %2 : tensor<10x10xbf16> - } - } - """ - ) - torch_arg0 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) - torch_arg1 = torch.full((10, 10), 1.0, dtype=torch.bfloat16) - ref_res = torch.matmul(torch_arg0, torch_arg1) - - np_arg0 = torch_arg0.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - np_arg1 = torch_arg1.view(dtype=torch.uint16).numpy().view(ml_dtypes.bfloat16) - gc_res = np.zeros((10, 10), dtype=ml_dtypes.bfloat16) - - entry = "main_entry" - mlir_args = get_mlir_args(module, entry, [np_arg0, np_arg1, gc_res]) - passes = "any(gc-cpu-pipeline)" - - # just run - compiler = GraphCompiler(passes) - engine = compiler.compile_and_jit(module, ir_printing=True) - engine.invoke(entry, *mlir_args) - - print(gc_res) - assert_allclose( - gc_res.astype(np.float32), - ref_res.to(torch.float32).numpy(), - rtol=1e-5, - atol=0, - ) diff --git a/tools/main.py b/tools/main.py deleted file mode 100644 index 9c08c6913..000000000 --- a/tools/main.py +++ /dev/null @@ -1,102 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import argparse -import json -import numpy as np -from bench import ( - mlir_wrapper_bench, - py_timeit_bench, -) -from drivers import MLP, LoadMLIR -from gc_mlir import ir -from utils import get_mlir_args - - -def get_driver_clz(diver_str: str): - """Function getting driver class by name.""" - clz = {"mlp": MLP, "load_mlir": LoadMLIR}[diver_str] - return clz - - -def add_driver_args(arg_parser: argparse.ArgumentParser): - """Function adding args for different driver.""" - driver = arg_parser.parse_known_args()[0].driver - get_driver_clz(driver).add_args(arg_parser) - - -def do_bench(args: argparse.Namespace): - """Function benching mlir""" - with ir.Context() as ctx, ir.Location.unknown(): - driver_clz = get_driver_clz(args.driver) - driver = driver_clz(ctx, args) - if args.print_ir: - ctx.enable_multithreading(False) - np_args = driver.prepare_np_args(args.disable_results_to_params) - - # TODO need data filling - # for test, fill all data with 1 - for np_arg in np_args: - np.ndarray.fill(np_arg, 1) - - mlir_args = get_mlir_args( - driver.ir_module, driver.main_entry, np_args, args.disable_results_to_params - ) - - print("===========bench func name: ", driver.main_entry, "===========") - print(driver.ir_module) - bench_kind = py_timeit_bench if args.bench_kind == "py" else mlir_wrapper_bench - execute_cost, compile_cost = bench_kind( - driver.ir_module, - driver.main_entry, - driver.get_passes(), - mlir_args, - args.print_ir, - args.repeat, - args.warm_up, - ) - print("===========bench result===========") - json_res = json.dumps( - { - "args": vars(args), - "compile_cost(ms)": compile_cost, - "execute_cost(ms)": execute_cost, - }, - indent=4, - ) - print(json_res) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--type", type=str, choices=["bench"], default="bench") - parser.add_argument( - "--driver", type=str, choices=["load_mlir", "mlp"], required=True - ) - add_driver_args(parser) - parser.add_argument( - "--bench_kind", type=str, choices=["py", "wrapper"], default="py" - ) - parser.add_argument("-p", "--print_ir", action="store_true") - parser.add_argument( - "--disable_results_to_params", action="store_true", default=False - ) - - parser.add_argument("--warm_up", type=int, default=100) - parser.add_argument("--repeat", type=int, default=100) - - do_bench(parser.parse_args()) diff --git a/tools/utils.py b/tools/utils.py deleted file mode 100644 index fc01fd208..000000000 --- a/tools/utils.py +++ /dev/null @@ -1,191 +0,0 @@ -################################################################################ -# Copyright (C) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -import ctypes -from typing import List - -import ml_dtypes -import numpy as np -from gc_mlir import ir -from gc_mlir.dialects import arith, func, memref -from gc_mlir.runtime.np_to_memref import ( - BF16, - get_ranked_memref_descriptor, - make_nd_memref_descriptor, -) - -MLIR_TYPE_TO_NUMPY_TYPE = { - "bf16": ml_dtypes.bfloat16, - "f32": np.float32, - "f64": np.float64, - "i8": np.int8, - "i32": np.int32, - "i64": np.int64, -} - -MLIR_TYPE_TO_C_TYPE = { - "f32": ctypes.c_float, - "f64": ctypes.c_double, - "i32": ctypes.c_int, - "i8": ctypes.c_byte, - "bf16": BF16, -} - - -def STR_TO_MLIR_TYPE(type: str, ctx: ir.Context): - type_map = { - "f32": ir.F32Type.get(ctx), - "f64": ir.F64Type.get(ctx), - "bf16": ir.BF16Type.get(ctx), - "i32": ir.IntegerType.get_signed(32, ctx), - "i8": ir.IntegerType.get_signed(8, ctx), - } - return type_map[type] - - -def emit_nano_time() -> func.FuncOp: - """Emit a nanoTime function that returns the current time in nanoseconds.""" - nanoTime = func.FuncOp( - "nanoTime", ([], [ir.IntegerType.get_signless(64)]), visibility="private" - ) - nanoTime.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - return nanoTime - - -def emit_benchmark_wrapped_main_func( - kernel_func: func.FuncOp, timer_func: func.FuncOp -) -> func.FuncOp: - """Emit a wrapped main function that calls the kernel function and records the time taken.""" - memref_of_i64_type = ir.MemRefType.get([1], ir.IntegerType.get_signless(64)) - wrapped_func_name = "wrapped_main" - assert wrapped_func_name != str( - kernel_func.name - ), "wrapped function name should be different from kernel function name" - wrapped_func = func.FuncOp( - wrapped_func_name, - ([memref_of_i64_type] + kernel_func.arguments.types, kernel_func.type.results), - visibility="public", - ) - wrapped_func.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - with ir.InsertionPoint(wrapped_func.add_entry_block()): - timer_buffer = wrapped_func.arguments[0] - start = func.CallOp(timer_func, []) - call_op = func.CallOp( - kernel_func, - list(wrapped_func.arguments[1:]), - ) - end = func.CallOp(timer_func, []) - time_taken = arith.SubIOp(end, start) - zero = arith.ConstantOp.create_index(0) - memref.StoreOp(time_taken, timer_buffer, [zero]) - func.ReturnOp(call_op.results) - return wrapped_func - - -def get_mlir_args( - module: ir.Module, - entry: str, - np_args: List[np.ndarray], - disable_results_to_params=False, -): - """Convert numpy arrays to MLIR args and return a list of pointers to them""" - f = get_kernel_func_from_module(module, entry) - compiled_func_args = [] - if disable_results_to_params: - assert len(np_args) == len(f.arguments), "input args mismatch" - for res in f.type.results: - compiled_func_args.append( - ctypes.pointer( - ctypes.pointer( - make_nd_memref_descriptor( - len(res.shape), MLIR_TYPE_TO_C_TYPE[str(res.element_type)] - )() - ) - ) - ) - for arg in np_args: - compiled_func_args.append( - ctypes.pointer(ctypes.pointer(get_ranked_memref_descriptor(arg))) - ) - return compiled_func_args - - -def make_mlir_ndarray(mlir_type): - """create numpy ndarray from mlir type""" - return np.zeros( - mlir_type.shape, MLIR_TYPE_TO_NUMPY_TYPE[str(mlir_type.element_type)] - ) - - -def get_kernel_func_from_module( - module: ir.Module, func_name: str = "main_entry" -) -> func.FuncOp: - """Get the func op by the name from a module""" - assert ( - len(module.operation.regions) == 1 - ), "Expected kernel module to have only one region" - assert ( - len(module.operation.regions[0].blocks) == 1 - ), "Expected kernel module to have only one block" - for f in module.operation.regions[0].blocks[0].operations: - if type(f) is func.FuncOp and str(f.name).strip('"') == func_name: - return f - raise ValueError("can not find the entry function") - - -def get_default_passes(): - passes = """ - any(gc-cpu-pipeline) - """ - return passes - - -def to_int_list(s: str) -> List[int]: - """ - Parsing the cmd for list of int values - - Args: - s (str): int values in cmd, example: 2x3x4 - - Returns: - List[int]: int values in list, example: [2, 3, 4] - """ - if not s or len(s) == 0: - return [] - return [int(i) for i in s.strip().split("x")] - - -def to_bool_list(s: str) -> List[bool]: - """ - Parsing the cmd for list of bool values - - Args: - s (str): bools in cmd, example: 1x0x1 - - Returns: - List[bool]: bools in list, example: [True, False, True] - """ - if not s or len(s) == 0: - return [] - return [bool(int(i)) for i in s.strip().split("x")] - - -def load_mlir_from_path(path: str) -> str: - """Load MLIR content from path""" - with open(path, "r") as file: - content = file.read() - return content diff --git a/tools/workloads/test.mlir b/tools/workloads/test.mlir deleted file mode 100644 index 9170ec61a..000000000 --- a/tools/workloads/test.mlir +++ /dev/null @@ -1,8 +0,0 @@ - -func.func @main_entry(%arg0: tensor<128x512xbf16>, %arg1: tensor<512x256xbf16>) -> tensor<128x256xbf16> attributes {llvm.emit_c_interface} { - %cst = arith.constant 0.000000e+00 : bf16 - %0 = tensor.empty() : tensor<128x256xbf16> - %1 = linalg.fill ins(%cst : bf16) outs(%0 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - %2 = linalg.matmul ins(%arg0, %arg1 : tensor<128x512xbf16>, tensor<512x256xbf16>) outs(%1 : tensor<128x256xbf16>) -> tensor<128x256xbf16> - return %2 : tensor<128x256xbf16> -} \ No newline at end of file