-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[NFC][mlir][vector] Handle potential static cast assertion. #152957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
In FoldArithToVectorOuterProduct pattern, static cast to vector type causes assertion when a scalar type was encountered. It seems the author meant to have a dyn_cast instead. This NFC patch handles it by using dyn_cast.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Md Asghar Ahmad Shahid (shahidact) ChangesIn FoldArithToVectorOuterProduct pattern, static cast to vector type causes assertion when a scalar type was encountered. It seems the author meant to have a dyn_cast instead. This NFC patch handles it by using dyn_cast. Full diff: https://github.com/llvm/llvm-project/pull/152957.diff 1 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index 2269a40ec8ef1..023c4da7dffdf 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -2274,7 +2274,7 @@ struct FoldArithToVectorOuterProduct : public OpRewritePattern<MulOpType> {
LogicalResult matchAndRewrite(MulOpType mulOp,
PatternRewriter &rewriter) const override {
- auto resType = llvm::cast<VectorType>(mulOp.getResult().getType());
+ auto resType = llvm::dyn_cast<VectorType>(mulOp.getResult().getType());
if (!resType)
return failure();
if (resType.getRank() != 2)
|
Thanks, do you have a repro that we could use as a test for this PR? |
Yes, I do have a test case, but I could not reproduce it using transform dialect while it is reproducible using a downstream pass. |
Looks like a good fix, but a test would be extra nice. Not possible to add one in the same file as the original PR? Wondering if a version of which acts on tensor type instead of vector type would be sufficient? |
I encountered the assert with above test case using the downstream vectorizer pass. I could not reproduce the assert with above transform schedule. Hence, I did not add the test case. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok. This is clearly a bug, you're hitting it with a cast, it is resolved with a dyn_cast. So from my end it's not worth investing too much effort into getting a test. LGTM; thanks for the fix!
I am not sure why we wouldn't need a test actually? Downstream often ncounter issues which are specific to their setup (they may be mis-using APIs, or have more test coverage), so the usual requirements is to have a test to provide coverage upstream. |
I'm not very familiar with the transform dialect but my guess is that with your script it's not entering the block of the linalg.generic. I have added a test in #154434 that does hit the assert (before the fix in this PR). |
Bug introduced in llvm/llvm-project#93664 The bug was fixed in llvm/llvm-project#152957 But there was no test. This PR adds a test that hits the assertion failure if the fix is reverted (if I change dyn_cast to cast).
In FoldArithToVectorOuterProduct pattern, static cast to vector type causes assertion when a scalar type was encountered. It seems the author meant to have a dyn_cast instead.
This NFC patch handles it by using dyn_cast.