From 7ec661414aac0aa6fb99d9d0fa0027f7a5ba6355 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Thu, 13 Apr 2023 16:27:21 -0400 Subject: [PATCH] Match decode + NT-GeMV + [ewise] pattern This PR uses FuseTIRByPattern to match the decode + NT-GeMV + optionally a trailing element-wise TIR function. E2E verified locally. The next step is to turn off NT-matmul and update the quantization encoding/decoding accordingly so that the quantization encoding func transposes the weights from T to N, and also update this pattern match function accordingly. --- build.py | 1 + web_llm/transform/__init__.py | 1 + web_llm/transform/decode_NT_matmul_ewise.py | 86 +++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 web_llm/transform/decode_NT_matmul_ewise.py diff --git a/build.py b/build.py index 9118e654..7fe0d890 100644 --- a/build.py +++ b/build.py @@ -117,6 +117,7 @@ def mod_transform_before_build( mod = relax.transform.FuseTIR()(mod) mod = web_llm.transform.GroupQuantize(group_size=32, sym=False)(mod) + mod = web_llm.transform.FuseDecodeNTMatmulEwise()(mod) mod = relax.transform.DeadCodeElimination(model_names)(mod) mod = relax.transform.LiftTransformParams()(mod) mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names) diff --git a/web_llm/transform/__init__.py b/web_llm/transform/__init__.py index cdc48121..34008954 100644 --- a/web_llm/transform/__init__.py +++ b/web_llm/transform/__init__.py @@ -1,3 +1,4 @@ from .dispatch_tir_operator import DispatchTIROperator from .quantization import GroupQuantize from .transpose_matmul import FuseTransposeMatmul +from .decode_NT_matmul_ewise import FuseDecodeNTMatmulEwise diff --git a/web_llm/transform/decode_NT_matmul_ewise.py b/web_llm/transform/decode_NT_matmul_ewise.py new file mode 100644 index 00000000..e54762f3 --- /dev/null +++ b/web_llm/transform/decode_NT_matmul_ewise.py @@ -0,0 +1,86 @@ +import tvm +from tvm import IRModule +from tvm import relax, tir +from tvm.relax.dpl.pattern import 
from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern, is_op, wildcard


def check_x_1dim(ctx: relax.transform.PatternCheckContext) -> bool:
    """Return True when the matched activation ``x`` is a single row (GeMV case).

    The fused decode + NT-matmul kernel targets GeMV, so the second-to-last
    dimension of ``x`` must be the static literal 1. A symbolic dimension
    (not a ``tir.IntImm``) fails the check.
    """
    # NOTE(review): assumes x.struct_info.shape has rank >= 2 — confirm for all callers.
    x = ctx.annotated_expr["x"]
    n = x.struct_info.shape[-2]
    return isinstance(n, tir.IntImm) and n.value == 1


def check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
    """Return True when the producer of ``w`` is a weight-decode TIR kernel."""
    call = ctx.annotated_expr["w"]
    # For relax.call_tir, args[0] is the GlobalVar of the called PrimFunc.
    gv = call.args[0]
    return gv.name_hint.startswith("decode")


def check_NT_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
    """Return True when the consumer is an NT-matmul (possibly already fused) kernel."""
    call = ctx.annotated_expr["NT_matmul"]
    gv = call.args[0]
    # str.startswith accepts a tuple of prefixes — one call instead of an `or` chain.
    return gv.name_hint.startswith(("NT_matmul", "fused_NT_matmul"))


def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool:
    """Combined predicate: 1-row activation, decode producer, NT-matmul consumer."""
    return check_x_1dim(ctx) and check_decoding(ctx) and check_NT_matmul(ctx)


def _decode_NT_matmul_pattern(trailing_ewise: bool):
    """Build the dataflow pattern for decode -> NT-matmul.

    Parameters
    ----------
    trailing_ewise : bool
        When True, the NT-matmul call_tir carries a third operand consumed by a
        trailing element-wise function fused into the same kernel.

    Returns
    -------
    (pattern, annotations, check) suitable for FuseOpsByPattern.
    """
    w_scaled = wildcard()
    scale_min = wildcard()
    x = wildcard()
    # decode kernel: (quantized weight, scale/min table) -> dequantized weight w.
    w = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern([w_scaled, scale_min]), add_constraint=False
    )
    matmul_args = [x, w]
    if trailing_ewise:
        # Extra operand for the fused trailing element-wise computation.
        matmul_args.append(wildcard())
    NT_matmul = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern(matmul_args), add_constraint=False
    )

    annotations = {
        "NT_matmul": NT_matmul,
        "w": w,
        "x": x,
        "w_scaled": w_scaled,
        "scale_min": scale_min,
    }

    return NT_matmul, annotations, pattern_check


def decode_NT_matmul_pattern():
    """Pattern for decode + NT-matmul with no trailing element-wise operand."""
    return _decode_NT_matmul_pattern(trailing_ewise=False)


def decode_NT_matmul_ewise_pattern():
    """Pattern for decode + NT-matmul followed by a fused element-wise function."""
    return _decode_NT_matmul_pattern(trailing_ewise=True)


@tvm.transform.module_pass(opt_level=0, name="FuseDecodeNTMatmulEwise")
class FuseDecodeNTMatmulEwise:
    """Fuse decode + NT-GeMV (+ optional trailing element-wise) into one TIR kernel."""

    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
        """Apply both fusion patterns, then lower the composites with FuseTIR.

        The two patterns cannot claim the same call site: the plain pattern
        matches exactly two call_tir operands, the ewise pattern exactly three.
        """
        mod = relax.transform.FuseOpsByPattern(
            [("decode_NT_matmul", *decode_NT_matmul_pattern())]
        )(mod)
        mod = relax.transform.FuseOpsByPattern(
            [("decode_NT_matmul_ewise", *decode_NT_matmul_ewise_pattern())]
        )(mod)
        # Merge each matched composite group into a single TIR PrimFunc.
        mod = relax.transform.FuseTIR()(mod)

        return mod