diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h index cad6cec761ab8..e7ab63abfeaa1 100644 --- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h @@ -86,7 +86,8 @@ class AttrConvertPassThrough { /// ArrayRef. template typename AttrConvert = - AttrConvertPassThrough> + AttrConvertPassThrough, + bool FailOnUnsupportedFP = false> class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern { "unsupported floating point type"); return success(); }; - for (Value operand : op->getOperands()) - if (failed(checkType(operand))) + if (FailOnUnsupportedFP) { + for (Value operand : op->getOperands()) + if (failed(checkType(operand))) + return failure(); + if (failed(checkType(op->getResult(0)))) return failure(); - if (failed(checkType(op->getResult(0)))) - return failure(); + } // Determine attributes for the target op AttrConvert attrConvert(op); diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp index 03ed4d51cc744..b6099902cc337 100644 --- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp +++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp @@ -36,20 +36,23 @@ namespace { /// attribute. template typename AttrConvert = - AttrConvertPassThrough> + AttrConvertPassThrough, + bool FailOnUnsupportedFP = false> struct ConstrainedVectorConvertToLLVMPattern - : public VectorConvertToLLVMPattern { - using VectorConvertToLLVMPattern::VectorConvertToLLVMPattern; + : public VectorConvertToLLVMPattern { + using VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::VectorConvertToLLVMPattern; LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (Constrained != static_cast(op.getRoundingModeAttr())) return failure(); - return VectorConvertToLLVMPattern::matchAndRewrite(op, adaptor, - rewriter); + return VectorConvertToLLVMPattern< + SourceOp, TargetOp, AttrConvert, + FailOnUnsupportedFP>::matchAndRewrite(op, adaptor, rewriter); } }; @@ -78,7 +81,8 @@ struct IdentityBitcastLowering final using AddFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using AddIOpLowering = VectorConvertToLLVMPattern; @@ -87,53 +91,67 @@ using BitcastOpLowering = VectorConvertToLLVMPattern; using DivFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using DivSIOpLowering = VectorConvertToLLVMPattern; using DivUIOpLowering = VectorConvertToLLVMPattern; -using ExtFOpLowering = VectorConvertToLLVMPattern; +using ExtFOpLowering = VectorConvertToLLVMPattern; using ExtSIOpLowering = VectorConvertToLLVMPattern; using ExtUIOpLowering = VectorConvertToLLVMPattern; using FPToSIOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using FPToUIOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using MaximumFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxNumFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MaxSIOpLowering = VectorConvertToLLVMPattern; using MaxUIOpLowering = VectorConvertToLLVMPattern; using MinimumFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinNumFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MinSIOpLowering = VectorConvertToLLVMPattern; using MinUIOpLowering = VectorConvertToLLVMPattern; using MulFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using MulIOpLowering = VectorConvertToLLVMPattern; using NegFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using OrIOpLowering = VectorConvertToLLVMPattern; using RemFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using RemSIOpLowering = VectorConvertToLLVMPattern; using RemUIOpLowering = @@ -151,21 +169,25 @@ using SIToFPOpLowering = VectorConvertToLLVMPattern; using SubFOpLowering = VectorConvertToLLVMPattern; + arith::AttrConvertFastMathToLLVM, + /*FailOnUnsupportedFP=*/true>; using SubIOpLowering = VectorConvertToLLVMPattern; using TruncFOpLowering = ConstrainedVectorConvertToLLVMPattern; + false, AttrConvertPassThrough, + /*FailOnUnsupportedFP=*/true>; using ConstrainedTruncFOpLowering = ConstrainedVectorConvertToLLVMPattern< arith::TruncFOp, LLVM::ConstrainedFPTruncIntr, true, - arith::AttrConverterConstrainedFPToLLVM>; + arith::AttrConverterConstrainedFPToLLVM, /*FailOnUnsupportedFP=*/true>; using TruncIOpLowering = VectorConvertToLLVMPattern; using UIToFPOpLowering = - VectorConvertToLLVMPattern; + VectorConvertToLLVMPattern; using XOrIOpLowering = VectorConvertToLLVMPattern; //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir index b5dcb01d3dc6b..5f1ec66234df2 100644 --- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir @@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref) -> memref { // CHECK: arith.addf {{.*}} : f4E2M1FN // CHECK: arith.addf {{.*}} : vector<4xf4E2M1FN> // CHECK: arith.addf {{.*}} : vector<8x4xf4E2M1FN> -func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) { +// CHECK: llvm.select {{.*}} : i1, i4 +func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) { %0 = arith.addf %arg0, %arg0 : f4E2M1FN %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN> %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN> - return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN> + %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN + return } // -----