-
Notifications
You must be signed in to change notification settings - Fork 15.1k
[mlir][arith] Fix arith.select lowering after #166513
#166692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Matthias Springer (matthias-springer) Changes#166513 broke the lowering of Full diff: https://github.com/llvm/llvm-project/pull/166692.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
/// ArrayRef<NamedAttribute>.
template <typename SourceOp, typename TargetOp,
template <typename, typename> typename AttrConvert =
- AttrConvertPassThrough>
+ AttrConvertPassThrough,
+ bool FailOnUnsupportedFP = true>
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
public:
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
"unsupported floating point type");
return success();
};
- for (Value operand : op->getOperands())
- if (failed(checkType(operand)))
+ if (FailOnUnsupportedFP) {
+ for (Value operand : op->getOperands())
+ if (failed(checkType(operand)))
+ return failure();
+ if (failed(checkType(op->getResult(0))))
return failure();
- if (failed(checkType(op->getResult(0))))
- return failure();
+ }
// Determine attributes for the target op
AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
using RemUIOpLowering =
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
using SelectOpLowering =
- VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+ VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+ AttrConvertPassThrough,
+ /*FailOnUnsupportedFP=*/false>;
using ShLIOpLowering =
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
// CHECK: arith.addf {{.*}} : f4E2M1FN
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+// CHECK: llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
- return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+ %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+ return
}
// -----
|
48876ec to
51d3fca
Compare
| template <typename, typename> typename AttrConvert = | ||
| AttrConvertPassThrough> | ||
| AttrConvertPassThrough, | ||
| bool FailOnUnsupportedFP = true> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you check other dialects to Make sure there aren't other operations that need unsupported floats to go through? I suspect vector.extract could be hitting similar issues
Also. arith.bitcast should allow the unsupported FP types.
More generally, shouldn't this flag be the other way around? Ops with float semantics should be marked as falling on unsupported floats, instead of denying this for all ops and having to find all the exceptions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I inverted the logic.
51d3fca to
63de470
Compare
krzysz00
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
Thanks, I verified that it fixes the issue. |
Carrying revert: - llvm/llvm-project@6c640b8, which should be fixed by llvm/llvm-project#166692 Signed-off-by: hanhanW <[email protected]>
Carrying revert: - llvm/llvm-project@6c640b8, which should be fixed by llvm/llvm-project#166692 Apply fixes for llvm/llvm-project@7557304 --------- Signed-off-by: hanhanW <[email protected]>
|
Can we merge it? I'd like to drop the revert in IREE. Thanks! |
Carrying revert: - iree-org/llvm-project@0ac8fc1, which should be fixed by llvm/llvm-project#166692 Signed-off-by: hanhanW <[email protected]>
|
I believe Matthias is OOO so I'll merge. |
) llvm#166513 broke the lowering of `arith.select` with unsupported FP4 types. For this op, it is fine to convert to `i4`.
#166513 broke the lowering of
arith.selectwith unsupported FP4 types. For this op, it is fine to convert toi4.