From 4dcd00d7e2e35cc651624ba0363f7882391d5e38 Mon Sep 17 00:00:00 2001 From: Jian Cai Date: Fri, 8 Sep 2023 14:28:36 -0700 Subject: [PATCH] Handle NativeCodeCallVoid in result patterns. Currently NativeCodeCallVoid is not supported in the result patterns. For example, below code will fail to build with an error message `referencing unbound symbol`. ``` def Foo: NativeCodeCallVoid<"foo()">; def AddToAddV2 : Pattern< (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1), [(TF_AddV2Op $arg0, $arg1), (Foo)]>; ``` MLIR tablegen-based pattern rewrites does not preserve attributes of the source op, with this change users could mannualy copy source attributes to the target op via NativeCodeCallVoid. This is a replacement reviews.llvm.org/D157032. --- mlir/tools/mlir-tblgen/RewriterGen.cpp | 27 +++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index 6bb79fb4b4cbe..bc2731df18508 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() { DagNode resultTree = pattern.getResultPattern(i); auto val = handleResultPattern(resultTree, offsets[i], 0); os << "\n"; - // Resolve each symbol for all range use so that we can loop over them. - // We need an explicit cast to `SmallVector` to capture the cases where - // `{0}` resolves to an `Operation::result_range` as well as cases that - // are not iterable (e.g. vector that gets wrapped in additional braces by - // RewriterGen). - // TODO: Revisit the need for materializing a vector. - os << symbolInfoMap.getAllRangeUse( - val, - "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" - " tblgen_repl_values.push_back(v);\n}\n", - "\n"); + if (resultTree.isNativeCodeCall() && + resultTree.getNumReturnsOfNativeCode() == 0) { + os << val << ";\n"; + } else { + // Resolve each symbol for all range use so that we can loop over them. + // We need an explicit cast to `SmallVector` to capture the cases where + // `{0}` resolves to an `Operation::result_range` as well as cases that + // are not iterable (e.g. vector that gets wrapped in additional braces by + // RewriterGen). + // TODO: Revisit the need for materializing a vector. + os << symbolInfoMap.getAllRangeUse( + val, + "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n" + " tblgen_repl_values.push_back(v);\n}\n", + "\n"); + } } os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n"; }