Skip to content

Conversation

matthias-springer
Copy link
Member

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).

When two `complex.bitcast` ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an `arith.bitcast` should be generated. Otherwise, the generated `complex.bitcast` op is invalid.

Also remove a pattern that convertes non-complex -> non-complex `complex.bitcast` ops to `arith.bitcast`. Such `complex.bitcast` ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with `-debug` (which will should intermediate IR that does not verify) or with `MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS` (llvm#74270).
@llvmbot llvmbot added mlir mlir:complex MLIR complex dialect labels Dec 4, 2023
@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir-complex

Author: Matthias Springer (matthias-springer)

Changes

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).


Full diff: https://github.com/llvm/llvm-project/pull/74271.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19)
  • (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }

@llvmbot
Copy link
Member

llvmbot commented Dec 4, 2023

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

When two complex.bitcast ops are folded and the resulting bitcast is a non-complex -> non-complex bitcast, an arith.bitcast should be generated. Otherwise, the generated complex.bitcast op is invalid.

Also remove a pattern that convertes non-complex -> non-complex complex.bitcast ops to arith.bitcast. Such complex.bitcast ops are invalid and should not appear in the input.

Note: This bug can only be triggered by running with -debug (which will should intermediate IR that does not verify) or with MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS (#74270).


Full diff: https://github.com/llvm/llvm-project/pull/74271.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Complex/IR/ComplexOps.cpp (+12-19)
  • (modified) mlir/test/Dialect/Complex/invalid.mlir (+1-1)
diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
index 8fd914dd107ff..6d8706775758e 100644
--- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
+++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp
@@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() {
   }
 
   if (isa<ComplexType>(operandType) == isa<ComplexType>(resultType)) {
-    return emitOpError("requires input or output is a complex type");
+    return emitOpError(
+        "requires that either input or output has a complex type");
   }
 
   if (isa<ComplexType>(resultType))
@@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern<BitcastOp> {
   LogicalResult matchAndRewrite(BitcastOp op,
                                 PatternRewriter &rewriter) const override {
     if (auto defining = op.getOperand().getDefiningOp<BitcastOp>()) {
-      rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
-                                             defining.getOperand());
+      if (isa<ComplexType>(op.getType()) ||
+          isa<ComplexType>(defining.getOperand().getType())) {
+        // complex.bitcast requires that input or output is complex.
+        rewriter.replaceOpWithNewOp<BitcastOp>(op, op.getType(),
+                                               defining.getOperand());
+      } else {
+        rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
+                                                      defining.getOperand());
+      }
       return success();
     }
 
@@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern<arith::BitcastOp> {
   }
 };
 
-struct ArithBitcast final : OpRewritePattern<BitcastOp> {
-  using OpRewritePattern<complex::BitcastOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(BitcastOp op,
-                                PatternRewriter &rewriter) const override {
-    if (isa<ComplexType>(op.getType()) ||
-        isa<ComplexType>(op.getOperand().getType()))
-      return failure();
-
-    rewriter.replaceOpWithNewOp<arith::BitcastOp>(op, op.getType(),
-                                                  op.getOperand());
-    return success();
-  }
-};
-
 void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<ArithBitcast, MergeComplexBitcast, MergeArithBitcast>(context);
+  results.add<MergeComplexBitcast, MergeArithBitcast>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir
index 51b1b0fda202a..ba6995b727bc2 100644
--- a/mlir/test/Dialect/Complex/invalid.mlir
+++ b/mlir/test/Dialect/Complex/invalid.mlir
@@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() {
 // -----
 
 func.func @complex_bitcast_i64(%arg0 : i64) {
-  // expected-error @+1 {{op requires input or output is a complex type}}
+  // expected-error @+1 {{op requires that either input or output has a complex type}}
   %0 = complex.bitcast %arg0: i64 to f64
   return
 }

@matthias-springer
Copy link
Member Author

Should we fix "bugs" like this one? Is it actually bug? I think there is at the moment no requirement that the IR has to verify after each pattern application.

I was looking into this because I had to debug a pass that applies multiple patterns and I wanted to see how an op was getting simplified. So I was running with -debug. Then I saw invalid IR half way through the process and I thought "if the IR already broken at this point, I don't even have to look further".

@matthias-springer matthias-springer merged commit 192439d into llvm:main Dec 5, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:complex MLIR complex dialect mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants