1 change: 1 addition & 0 deletions build.py
@@ -117,6 +117,7 @@ def mod_transform_before_build(
    mod = relax.transform.FuseTIR()(mod)

    mod = web_llm.transform.GroupQuantize(group_size=32, sym=False)(mod)
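    # Fuse the decode (dequantize) kernels emitted by GroupQuantize into
    # the NT_matmul kernels that consume them.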
    mod = web_llm.transform.FuseDecodeNTMatmulEwise()(mod)
    mod = relax.transform.DeadCodeElimination(model_names)(mod)
    mod = relax.transform.LiftTransformParams()(mod)
    mod_transform, mod_deploy = utils.split_transform_deploy_mod(mod, model_names)
1 change: 1 addition & 0 deletions web_llm/transform/__init__.py
@@ -1,3 +1,4 @@
from .dispatch_tir_operator import DispatchTIROperator
from .quantization import GroupQuantize
from .transpose_matmul import FuseTransposeMatmul
from .decode_NT_matmul_ewise import FuseDecodeNTMatmulEwise
86 changes: 86 additions & 0 deletions web_llm/transform/decode_NT_matmul_ewise.py
@@ -0,0 +1,86 @@
import tvm
from tvm import IRModule
from tvm import relax, tir
from tvm.relax.dpl.pattern import is_op, wildcard
from tvm.relax.dpl.pattern import GlobalVarPattern, TuplePattern


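# Match only when the activation has a single row (x.shape[-2] == 1),
# i.e. a vector-matrix product as in single-token decoding.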
def check_x_1dim(ctx: relax.transform.PatternCheckContext) -> bool:
    x = ctx.annotated_expr["x"]
    n = x.struct_info.shape[-2]
    return isinstance(n, tir.IntImm) and n.value == 1


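# The weight must be produced by a call into a TIR dequantization
# function (a PrimFunc whose name starts with "decode").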
def check_decoding(ctx: relax.transform.PatternCheckContext) -> bool:
    call = ctx.annotated_expr["w"]
    gv = call.args[0]
    return gv.name_hint.startswith("decode")


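# The consumer must be an NT_matmul kernel, possibly one that already
# has elementwise ops fused into it ("fused_NT_matmul...").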
def check_NT_matmul(ctx: relax.transform.PatternCheckContext) -> bool:
    call = ctx.annotated_expr["NT_matmul"]
    gv = call.args[0]
    return gv.name_hint.startswith("NT_matmul") or gv.name_hint.startswith("fused_NT_matmul")


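# Rewrite a match only if all three structural checks hold.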
def pattern_check(ctx: relax.transform.PatternCheckContext) -> bool:
    return check_x_1dim(ctx) and check_decoding(ctx) and check_NT_matmul(ctx)


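# Pattern: w = call_tir(decode, (w_scaled, scale_min)) feeding
# call_tir(NT_matmul, (x, w)).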
def decode_NT_matmul_pattern():
    w_scaled = wildcard()
    scale_min = wildcard()
    x = wildcard()
    w = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern([w_scaled, scale_min]), add_constraint=False
    )
    NT_matmul = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern([x, w]), add_constraint=False
    )

    annotations = {
        "NT_matmul": NT_matmul,
        "w": w,
        "x": x,
        "w_scaled": w_scaled,
        "scale_min": scale_min,
    }

    return NT_matmul, annotations, pattern_check


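# Same pattern with an extra operand y, covering NT_matmul kernels that
# carry an elementwise epilogue.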
def decode_NT_matmul_ewise_pattern():
    w_scaled = wildcard()
    scale_min = wildcard()
    x = wildcard()
    y = wildcard()
    w = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern([w_scaled, scale_min]), add_constraint=False
    )
    NT_matmul_ewise = is_op("relax.call_tir")(
        GlobalVarPattern(), TuplePattern([x, w, y]), add_constraint=False
    )

    annotations = {
        "NT_matmul": NT_matmul_ewise,
        "w": w,
        "x": x,
        "w_scaled": w_scaled,
        "scale_min": scale_min,
    }

    return NT_matmul_ewise, annotations, pattern_check


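# Fuse decode into its consuming NT_matmul (and NT_matmul + elementwise)
# kernels, then merge each fused group into a single TIR function.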
@tvm.transform.module_pass(opt_level=0, name="FuseDecodeNTMatmulEwise")
class FuseDecodeNTMatmulEwise:
    def transform_module(self, mod: IRModule, ctx: tvm.transform.PassContext) -> IRModule:
        mod = relax.transform.FuseOpsByPattern(
            [("decode_NT_matmul", *decode_NT_matmul_pattern())]
        )(mod)
        mod = relax.transform.FuseOpsByPattern(
            [("decode_NT_matmul_ewise", *decode_NT_matmul_ewise_pattern())]
        )(mod)
        mod = relax.transform.FuseTIR()(mod)

        return mod
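
For context, a minimal usage sketch of the new pass, assuming mod is a relax IRModule that has already been group-quantized as in build.py above (the snippet mirrors the diff and is illustrative, not part of this PR):

import web_llm

# mod: a relax.IRModule that has been through
# web_llm.transform.GroupQuantize, so its weights flow through TIR
# "decode" kernels that this pass fuses into the NT_matmul calls
# consuming them.
mod = web_llm.transform.FuseDecodeNTMatmulEwise()(mod)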