From a1ccb7fba6f6a8121b9e23031de2bfd3dd8637e4 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 15 Sep 2025 14:19:34 -0400 Subject: [PATCH 1/5] adding pass and updating unroll rule --- src/kirin/dialects/scf/unroll.py | 2 +- src/kirin/passes/aggressive/scf_unroll.py | 32 ++++++++++++++++++++ test/passes/test_unroll_scf.py | 37 +++++++++++++++++++++++ 3 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 src/kirin/passes/aggressive/scf_unroll.py create mode 100644 test/passes/test_unroll_scf.py diff --git a/src/kirin/dialects/scf/unroll.py b/src/kirin/dialects/scf/unroll.py index 6791d320d..09e2ef794 100644 --- a/src/kirin/dialects/scf/unroll.py +++ b/src/kirin/dialects/scf/unroll.py @@ -88,4 +88,4 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: for result, output in zip(node.results, loop_vars): result.replace_by(output) node.delete() - return RewriteResult(has_done_something=True, terminated=True) + return RewriteResult(has_done_something=True) diff --git a/src/kirin/passes/aggressive/scf_unroll.py b/src/kirin/passes/aggressive/scf_unroll.py new file mode 100644 index 000000000..16a3aa18a --- /dev/null +++ b/src/kirin/passes/aggressive/scf_unroll.py @@ -0,0 +1,32 @@ +from dataclasses import field, dataclass + +from kirin import ir, rewrite +from kirin.passes import Pass +from kirin.rewrite import abc +from kirin.passes.typeinfer import TypeInfer +from kirin.dialects.scf.unroll import ForLoop, PickIfElse + +from ..fold import Fold + + +@dataclass +class UnrollScf(Pass): + """This pass can be used to unroll scf.For loops and inline/expand scf.IfElse when + the input are known at compile time. + + """ + + typeinfer: TypeInfer = field(init=False) + fold: Fold = field(init=False) + + def __post_init__(self): + self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) + self.fold = Fold(self.dialects, no_raise=self.no_raise) + + def unsafe_run(self, mt: ir.Method): + result = abc.RewriteResult() + result = rewrite.Walk(PickIfElse()).rewrite(mt.code).join(result) + result = rewrite.Walk(ForLoop()).rewrite(mt.code).join(result) + result = self.typeinfer(mt).join(result) + result = self.fold(mt).join(result) + return result diff --git a/test/passes/test_unroll_scf.py b/test/passes/test_unroll_scf.py new file mode 100644 index 000000000..148ad4c10 --- /dev/null +++ b/test/passes/test_unroll_scf.py @@ -0,0 +1,37 @@ +from kirin.prelude import structural +from kirin.dialects import py, func +from kirin.passes.aggressive.scf_unroll import UnrollScf + + +def test_unroll_scf(): + + @structural + def main(r: list[int]): + for i in range(4): + tmp = r[-1] + if i < 2: + tmp += i * 2 + else: + for j in range(4): + if i > j: + tmp += i + j + else: + tmp += i - j + + r.append(tmp) + + return r + + UnrollScf(structural).fixpoint(main) + + num_adds = 0 + num_calls = 0 + + for op in main.callable_region.walk(): + if isinstance(op, py.Add): + num_adds += 1 + elif isinstance(op, func.Call): + num_calls += 1 + + assert num_adds == 10 + assert num_calls == 4 From dca626f11e64531c41d520b7c71c532abf155194 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 15 Sep 2025 14:25:41 -0400 Subject: [PATCH 2/5] renaming and reexporting pass --- src/kirin/passes/aggressive/__init__.py | 1 + src/kirin/passes/aggressive/{scf_unroll.py => unroll.py} | 5 +++++ test/passes/test_unroll_scf.py | 2 +- 3 files changed, 7 insertions(+), 1 deletion(-) rename src/kirin/passes/aggressive/{scf_unroll.py => unroll.py} (87%) diff --git a/src/kirin/passes/aggressive/__init__.py b/src/kirin/passes/aggressive/__init__.py index a370b7b67..94e584626 100644 --- a/src/kirin/passes/aggressive/__init__.py +++ b/src/kirin/passes/aggressive/__init__.py @@ -1 +1,2 @@ from .fold import Fold as Fold +from .unroll import UnrollScf as UnrollScf diff --git a/src/kirin/passes/aggressive/scf_unroll.py b/src/kirin/passes/aggressive/unroll.py similarity index 87% rename from src/kirin/passes/aggressive/scf_unroll.py rename to src/kirin/passes/aggressive/unroll.py index 16a3aa18a..e2aed6688 100644 --- a/src/kirin/passes/aggressive/scf_unroll.py +++ b/src/kirin/passes/aggressive/unroll.py @@ -14,6 +14,11 @@ class UnrollScf(Pass): """This pass can be used to unroll scf.For loops and inline/expand scf.IfElse when the input are known at compile time. + usage: + UnrollScf(dialects).fixpoint(method) + + Note: This pass should be used in a fixpoint manner, to unroll nested scf nodes. + """ typeinfer: TypeInfer = field(init=False) diff --git a/test/passes/test_unroll_scf.py b/test/passes/test_unroll_scf.py index 148ad4c10..7dd099a67 100644 --- a/test/passes/test_unroll_scf.py +++ b/test/passes/test_unroll_scf.py @@ -1,6 +1,6 @@ from kirin.prelude import structural from kirin.dialects import py, func -from kirin.passes.aggressive.scf_unroll import UnrollScf +from kirin.passes.aggressive import UnrollScf def test_unroll_scf(): From 9f559873ec96df8bc2389841f6ffc7730ca7d8b5 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 15 Sep 2025 14:28:34 -0400 Subject: [PATCH 3/5] simplifying imports --- src/kirin/passes/aggressive/unroll.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/kirin/passes/aggressive/unroll.py b/src/kirin/passes/aggressive/unroll.py index e2aed6688..39ecfc894 100644 --- a/src/kirin/passes/aggressive/unroll.py +++ b/src/kirin/passes/aggressive/unroll.py @@ -1,13 +1,11 @@ from dataclasses import field, dataclass -from kirin import ir, rewrite -from kirin.passes import Pass -from kirin.rewrite import abc -from kirin.passes.typeinfer import TypeInfer +from kirin.ir import Method +from kirin.passes import Fold, Pass, TypeInfer +from kirin.rewrite import Walk +from kirin.rewrite.abc import RewriteResult from kirin.dialects.scf.unroll import ForLoop, PickIfElse -from ..fold import Fold - @dataclass class UnrollScf(Pass): @@ -28,10 +26,10 @@ def __post_init__(self): self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise) self.fold = Fold(self.dialects, no_raise=self.no_raise) - def unsafe_run(self, mt: ir.Method): - result = abc.RewriteResult() - result = rewrite.Walk(PickIfElse()).rewrite(mt.code).join(result) - result = rewrite.Walk(ForLoop()).rewrite(mt.code).join(result) + def unsafe_run(self, mt: Method): + result = RewriteResult() + result = Walk(PickIfElse()).rewrite(mt.code).join(result) + result = Walk(ForLoop()).rewrite(mt.code).join(result) result = self.typeinfer(mt).join(result) result = self.fold(mt).join(result) return result From 4f137382e6cf12487c740734bc7396134958cc69 Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 15 Sep 2025 14:31:57 -0400 Subject: [PATCH 4/5] making test more complicated --- test/passes/test_unroll_scf.py | 36 ++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/test/passes/test_unroll_scf.py b/test/passes/test_unroll_scf.py index 7dd099a67..461610f91 100644 --- a/test/passes/test_unroll_scf.py +++ b/test/passes/test_unroll_scf.py @@ -4,26 +4,28 @@ def test_unroll_scf(): - @structural - def main(r: list[int]): - for i in range(4): - tmp = r[-1] - if i < 2: - tmp += i * 2 - else: - for j in range(4): - if i > j: - tmp += i + j - else: - tmp += i - j - - r.append(tmp) - + def main(r: list[int], cond: bool): + if cond: + for i in range(4): + tmp = r[-1] + if i < 2: + tmp += i * 2 + else: + for j in range(4): + if i > j: + tmp += i + j + else: + tmp += i - j + + r.append(tmp) + else: + for i in range(4): + r.append(i) return r UnrollScf(structural).fixpoint(main) - + main.print() num_adds = 0 num_calls = 0 @@ -34,4 +36,4 @@ def main(r: list[int]): num_calls += 1 assert num_adds == 10 - assert num_calls == 4 + assert num_calls == 8 From 074df6a0da41faa5be5b243d725c05541becd2ed Mon Sep 17 00:00:00 2001 From: Phillip Weinberg Date: Mon, 15 Sep 2025 14:32:27 -0400 Subject: [PATCH 5/5] removing print --- test/passes/test_unroll_scf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/passes/test_unroll_scf.py b/test/passes/test_unroll_scf.py index 461610f91..9442236d0 100644 --- a/test/passes/test_unroll_scf.py +++ b/test/passes/test_unroll_scf.py @@ -25,7 +25,7 @@ def main(r: list[int], cond: bool): return r UnrollScf(structural).fixpoint(main) - main.print() + num_adds = 0 num_calls = 0