Skip to content

Commit 48876ec

Browse files
[mlir][arith] Fix arith.select lowering after #166513
1 parent 6c640b8 commit 48876ec

File tree

3 files changed

+16
-8
lines changed

3 files changed

+16
-8
lines changed

mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
8686
/// ArrayRef<NamedAttribute>.
8787
template <typename SourceOp, typename TargetOp,
8888
template <typename, typename> typename AttrConvert =
89-
AttrConvertPassThrough>
89+
AttrConvertPassThrough,
90+
bool FailOnUnsupportedFP = true>
9091
class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
9192
public:
9293
using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
123124
"unsupported floating point type");
124125
return success();
125126
};
126-
for (Value operand : op->getOperands())
127-
if (failed(checkType(operand)))
127+
if (FailOnUnsupportedFP) {
128+
for (Value operand : op->getOperands())
129+
if (failed(checkType(operand)))
130+
return failure();
131+
if (failed(checkType(op->getResult(0))))
128132
return failure();
129-
if (failed(checkType(op->getResult(0))))
130-
return failure();
133+
}
131134

132135
// Determine attributes for the target op
133136
AttrConvert<SourceOp, TargetOp> attrConvert(op);

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
1515
#include "mlir/Dialect/Arith/IR/Arith.h"
1616
#include "mlir/Dialect/Arith/Transforms/Passes.h"
17+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1718
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
1819
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1920
#include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
139140
using RemUIOpLowering =
140141
VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
141142
using SelectOpLowering =
142-
VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
143+
VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
144+
AttrConvertPassThrough,
145+
/*FailOnUnsupportedFP=*/false>;
143146
using ShLIOpLowering =
144147
VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
145148
arith::AttrConvertOverflowToLLVM>;

mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
754754
// CHECK: arith.addf {{.*}} : f4E2M1FN
755755
// CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN>
756756
// CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN>
757-
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
757+
// CHECK: llvm.select {{.*}} : i1, i4
758+
func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
758759
%0 = arith.addf %arg0, %arg0 : f4E2M1FN
759760
%1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
760761
%2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
761-
return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
762+
%3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
763+
return
762764
}
763765

764766
// -----

0 commit comments

Comments
 (0)