Skip to content

Conversation

@matthias-springer
Copy link
Member

#166513 broke the lowering of arith.select with unsupported FP4 types. For this op, it is fine to convert to i4.

@llvmbot
Copy link
Member

llvmbot commented Nov 6, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Matthias Springer (matthias-springer)

Changes

#166513 broke the lowering of arith.select with unsupported FP4 types. For this op, it is fine to convert to i4.


Full diff: https://github.com/llvm/llvm-project/pull/166692.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+8-5)
  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+4-1)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+4-2)
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index cad6cec761ab8..b8e3023b25569 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -86,7 +86,8 @@ class AttrConvertPassThrough {
 /// ArrayRef<NamedAttribute>.
 template <typename SourceOp, typename TargetOp,
           template <typename, typename> typename AttrConvert =
-              AttrConvertPassThrough>
+              AttrConvertPassThrough,
+          bool FailOnUnsupportedFP = true>
 class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 public:
   using ConvertOpToLLVMPattern<SourceOp>::ConvertOpToLLVMPattern;
@@ -123,11 +124,13 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
                                            "unsupported floating point type");
       return success();
     };
-    for (Value operand : op->getOperands())
-      if (failed(checkType(operand)))
+    if (FailOnUnsupportedFP) {
+      for (Value operand : op->getOperands())
+        if (failed(checkType(operand)))
+          return failure();
+      if (failed(checkType(op->getResult(0))))
         return failure();
-    if (failed(checkType(op->getResult(0))))
-      return failure();
+    }
 
     // Determine attributes for the target op
     AttrConvert<SourceOp, TargetOp> attrConvert(op);
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 03ed4d51cc744..55cffa1e22d77 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -139,7 +140,9 @@ using RemSIOpLowering =
 using RemUIOpLowering =
     VectorConvertToLLVMPattern<arith::RemUIOp, LLVM::URemOp>;
 using SelectOpLowering =
-    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp>;
+    VectorConvertToLLVMPattern<arith::SelectOp, LLVM::SelectOp,
+                               AttrConvertPassThrough,
+                               /*FailOnUnsupportedFP=*/false>;
 using ShLIOpLowering =
     VectorConvertToLLVMPattern<arith::ShLIOp, LLVM::ShlOp,
                                arith::AttrConvertOverflowToLLVM>;
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index b5dcb01d3dc6b..5f1ec66234df2 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -754,11 +754,13 @@ func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
 //       CHECK:   arith.addf {{.*}} : f4E2M1FN
 //       CHECK:   arith.addf {{.*}} : vector<4xf4E2M1FN>
 //       CHECK:   arith.addf {{.*}} : vector<8x4xf4E2M1FN>
-func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>) -> (f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>) {
+//       CHECK:   llvm.select {{.*}} : i1, i4
+func.func @unsupported_fp_type(%arg0: f4E2M1FN, %arg1: vector<4xf4E2M1FN>, %arg2: vector<8x4xf4E2M1FN>, %arg3: f4E2M1FN, %arg4: i1) {
   %0 = arith.addf %arg0, %arg0 : f4E2M1FN
   %1 = arith.addf %arg1, %arg1 : vector<4xf4E2M1FN>
   %2 = arith.addf %arg2, %arg2 : vector<8x4xf4E2M1FN>
-  return %0, %1, %2 : f4E2M1FN, vector<4xf4E2M1FN>, vector<8x4xf4E2M1FN>
+  %3 = arith.select %arg4, %arg0, %arg3 : f4E2M1FN
+  return
 }
 
 // -----

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fix_select branch from 48876ec to 51d3fca Compare November 6, 2025 03:26
template <typename, typename> typename AttrConvert =
AttrConvertPassThrough>
AttrConvertPassThrough,
bool FailOnUnsupportedFP = true>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you check other dialects to Make sure there aren't other operations that need unsupported floats to go through? I suspect vector.extract could be hitting similar issues

Also. arith.bitcast should allow the unsupported FP types.

More generally, shouldn't this flag be the other way around? Ops with float semantics should be marked as falling on unsupported floats, instead of denying this for all ops and having to find all the exceptions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I inverted the logic.

@matthias-springer matthias-springer force-pushed the users/matthias-springer/fix_select branch from 51d3fca to 63de470 Compare November 6, 2025 08:24
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hanhanW
Copy link
Contributor

hanhanW commented Nov 6, 2025

Thanks, I verified that it fixes the issue.

hanhanW added a commit to iree-org/iree that referenced this pull request Nov 6, 2025
Carrying revert:
- llvm/llvm-project@6c640b8, which should be fixed by llvm/llvm-project#166692

Signed-off-by: hanhanW <[email protected]>
hanhanW added a commit to iree-org/iree that referenced this pull request Nov 6, 2025
Carrying revert:
-
llvm/llvm-project@6c640b8,
which should be fixed by
llvm/llvm-project#166692

Apply fixes for
llvm/llvm-project@7557304

---------

Signed-off-by: hanhanW <[email protected]>
@hanhanW
Copy link
Contributor

hanhanW commented Nov 7, 2025

Can we merge it? I'd like to drop the revert in IREE. Thanks!

hanhanW added a commit to iree-org/iree that referenced this pull request Nov 7, 2025
@makslevental
Copy link
Contributor

I believe Matthias is OOO so I'll merge.

@makslevental makslevental merged commit 3740368 into main Nov 7, 2025
10 checks passed
@makslevental makslevental deleted the users/matthias-springer/fix_select branch November 7, 2025 17:59
vinay-deshmukh pushed a commit to vinay-deshmukh/llvm-project that referenced this pull request Nov 8, 2025
)

llvm#166513 broke the lowering of `arith.select` with unsupported FP4 types.
For this op, it is fine to convert to `i4`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants