@@ -26,6 +26,31 @@ using namespace mlir::python;
2626
2727namespace {
2828
29+ class PyPatternRewriter {
30+ public:
31+ PyPatternRewriter (MlirPatternRewriter rewriter)
32+ : rewriter(rewriter), base(mlirPatternRewriterAsBase(rewriter)),
33+ ctx (PyMlirContext::forContext(mlirRewriterBaseGetContext(base))) {}
34+
35+ PyInsertionPoint getInsertionPoint () const {
36+ MlirBlock block = mlirRewriterBaseGetInsertionBlock (base);
37+ MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion (base);
38+
39+ if (mlirOperationIsNull (op)) {
40+ MlirOperation owner = mlirBlockGetParentOperation (block);
41+ auto parent = PyOperation::forOperation (ctx, owner);
42+ return PyInsertionPoint (PyBlock (parent, block));
43+ }
44+
45+ return PyInsertionPoint (PyOperation::forOperation (ctx, op));
46+ }
47+
48+ private:
49+ MlirPatternRewriter rewriter [[maybe_unused]];
50+ MlirRewriterBase base;
51+ PyMlirContextRef ctx;
52+ };
53+
2954#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
3055static nb::object objectFromPDLValue (MlirPDLValue value) {
3156 if (MlirValue v = mlirPDLValueAsValue (value); !mlirValueIsNull (v))
@@ -84,7 +109,8 @@ class PyPDLPatternModule {
84109 void *userData) -> MlirLogicalResult {
85110 nb::handle f = nb::handle (static_cast <PyObject *>(userData));
86111 return logicalResultFromObject (
87- f (rewriter, results, objectsFromPDLValues (nValues, values)));
112+ f (PyPatternRewriter (rewriter), results,
113+ objectsFromPDLValues (nValues, values)));
88114 },
89115 fn.ptr ());
90116 }
@@ -98,7 +124,8 @@ class PyPDLPatternModule {
98124 void *userData) -> MlirLogicalResult {
99125 nb::handle f = nb::handle (static_cast <PyObject *>(userData));
100126 return logicalResultFromObject (
101- f (rewriter, results, objectsFromPDLValues (nValues, values)));
127+ f (PyPatternRewriter (rewriter), results,
128+ objectsFromPDLValues (nValues, values)));
102129 },
103130 fn.ptr ());
104131 }
@@ -143,21 +170,8 @@ class PyFrozenRewritePatternSet {
143170
144171// / Create the `mlir.rewrite` here.
145172void mlir::python::populateRewriteSubmodule (nb::module_ &m) {
146- nb::class_<MlirPatternRewriter>(m, " PatternRewriter" )
147- .def (" ip" , [](MlirPatternRewriter rewriter) {
148- MlirRewriterBase base = mlirPatternRewriterAsBase (rewriter);
149- MlirBlock block = mlirRewriterBaseGetInsertionBlock (base);
150- MlirOperation op = mlirRewriterBaseGetOperationAfterInsertion (base);
151- MlirOperation owner = mlirBlockGetParentOperation (block);
152- auto ctx = PyMlirContext::forContext (mlirRewriterBaseGetContext (base))
153- ->getRef ();
154- if (mlirOperationIsNull (op)) {
155- auto parent = PyOperation::forOperation (ctx, owner);
156- return PyInsertionPoint (PyBlock (parent, block));
157- }
158-
159- return PyInsertionPoint (*PyOperation::forOperation (ctx, op).get ());
160- });
173+ nb::class_<PyPatternRewriter>(m, " PyPatternRewriter" )
174+ .def (" ip" , &PyPatternRewriter::getInsertionPoint);
161175 // ----------------------------------------------------------------------------
162176 // Mapping of the PDLResultList and PDLModule
163177 // ----------------------------------------------------------------------------
0 commit comments