-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir] Handle NativeCodeCallVoid in result patterns. #65804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir Changesdiff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os << "\n";
- // Resolve each symbol for all range use so that we can loop over them.
- // We need an explicit cast to `SmallVector` to capture the cases where
- // `{0}` resolves to an `Operation::result_range` as well as cases that
- // are not iterable (e.g. vector that gets wrapped in additional braces by
- // RewriterGen).
- // TODO: Revisit the need for materializing a vector.
- os << symbolInfoMap.getAllRangeUse(
- val,
- "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
- " tblgen_repl_values.push_back(v);\n}\n",
- "\n");
+ if (resultTree.isNativeCodeCall() &&
+ resultTree.getNumReturnsOfNativeCode() == 0) {
+ os << val << ";\n";
+ } else {
+ // Resolve each symbol for all range use so that we can loop over them.
+ // We need an explicit cast to `SmallVector` to capture the cases where
+ // `{0}` resolves to an `Operation::result_range` as well as cases that
+ // are not iterable (e.g. vector that gets wrapped in additional braces by
+ // RewriterGen).
+ // TODO: Revisit the need for materializing a vector.
+ os << symbolInfoMap.getAllRangeUse(
+ val,
+ "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+ " tblgen_repl_values.push_back(v);\n}\n",
+ "\n");
+ }
}
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
}
|
|
The pattern looks wrong though, from PatternBase.td so in
|
|
Thanks for the clarification! My understanding is that since |
Currently NativeCodeCallVoid is not supported in the result patterns. For example, below code will fail to build with an error message `referencing unbound symbol`. ``` def Foo: NativeCodeCallVoid<"foo()">; def AddToAddV2 : Pattern< (TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1), [(TF_AddV2Op $arg0, $arg1), (Foo)]>; ``` MLIR tablegen-based pattern rewrites does not preserve attributes of the source op, with this change users could mannualy copy source attributes to the target op via NativeCodeCallVoid. This is a replacement reviews.llvm.org/D157032.
|
@llvm/pr-subscribers-mlir-core ChangesCurrently NativeCodeCallVoid is not supported in the result patterns. For example, below code will fail to build with an error message "referencing unbound symbol" def Foo: NativeCodeCallVoid<"foo()">; def AddToAddV2 : Pattern< MLIR tablegen-based pattern rewrites does not preserve attributes of the source op, with this change users could mannualy copy source attributes to the target op via NativeCodeCallVoid. This is a replacement of D157032.Full diff: https://github.com/llvm/llvm-project/pull/65804.diff 1 Files Affected:
diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp
index 6bb79fb4b4cbe67..bc2731df1850838 100644
--- a/mlir/tools/mlir-tblgen/RewriterGen.cpp
+++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp
@@ -1184,17 +1184,22 @@ void PatternEmitter::emitRewriteLogic() {
DagNode resultTree = pattern.getResultPattern(i);
auto val = handleResultPattern(resultTree, offsets[i], 0);
os << "\n";
- // Resolve each symbol for all range use so that we can loop over them.
- // We need an explicit cast to `SmallVector` to capture the cases where
- // `{0}` resolves to an `Operation::result_range` as well as cases that
- // are not iterable (e.g. vector that gets wrapped in additional braces by
- // RewriterGen).
- // TODO: Revisit the need for materializing a vector.
- os << symbolInfoMap.getAllRangeUse(
- val,
- "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
- " tblgen_repl_values.push_back(v);\n}\n",
- "\n");
+ if (resultTree.isNativeCodeCall() &&
+ resultTree.getNumReturnsOfNativeCode() == 0) {
+ os << val << ";\n";
+ } else {
+ // Resolve each symbol for all range use so that we can loop over them.
+ // We need an explicit cast to `SmallVector` to capture the cases where
+ // `{0}` resolves to an `Operation::result_range` as well as cases that
+ // are not iterable (e.g. vector that gets wrapped in additional braces by
+ // RewriterGen).
+ // TODO: Revisit the need for materializing a vector.
+ os << symbolInfoMap.getAllRangeUse(
+ val,
+ "for (auto v: ::llvm::SmallVector<::mlir::Value, 4>{ {0} }) {{\n"
+ " tblgen_repl_values.push_back(v);\n}\n",
+ "\n");
+ }
}
os << "\nrewriter.replaceOp(op0, tblgen_repl_values);\n";
}
|
|
Abandon this change in favor of #66959. |
Currently NativeCodeCallVoid is not supported in the result patterns. For example, below code will fail to build with an error message "referencing unbound symbol"
def Foo: NativeCodeCallVoid<"foo()">;
def AddToAddV2 : Pattern<
(TF_AddOp TF_NumberTensor:$arg0, TF_NumberTensor:$arg1),
[(TF_AddV2Op $arg0, $arg1), (Foo)]>;
MLIR tablegen-based pattern rewrites does not preserve attributes of the source op, with this change users could mannualy copy source attributes to the target op via NativeCodeCallVoid. This is a replacement of D157032.