diff --git a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp index 8fd914dd107ff..6d8706775758e 100644 --- a/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp +++ b/mlir/lib/Dialect/Complex/IR/ComplexOps.cpp @@ -100,7 +100,8 @@ LogicalResult BitcastOp::verify() { } if (isa(operandType) == isa(resultType)) { - return emitOpError("requires input or output is a complex type"); + return emitOpError( + "requires that either input or output has a complex type"); } if (isa(resultType)) @@ -125,8 +126,15 @@ struct MergeComplexBitcast final : OpRewritePattern { LogicalResult matchAndRewrite(BitcastOp op, PatternRewriter &rewriter) const override { if (auto defining = op.getOperand().getDefiningOp()) { - rewriter.replaceOpWithNewOp(op, op.getType(), - defining.getOperand()); + if (isa(op.getType()) || + isa(defining.getOperand().getType())) { + // complex.bitcast requires that input or output is complex. + rewriter.replaceOpWithNewOp(op, op.getType(), + defining.getOperand()); + } else { + rewriter.replaceOpWithNewOp(op, op.getType(), + defining.getOperand()); + } return success(); } @@ -155,24 +163,9 @@ struct MergeArithBitcast final : OpRewritePattern { } }; -struct ArithBitcast final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BitcastOp op, - PatternRewriter &rewriter) const override { - if (isa(op.getType()) || - isa(op.getOperand().getType())) - return failure(); - - rewriter.replaceOpWithNewOp(op, op.getType(), - op.getOperand()); - return success(); - } -}; - void BitcastOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add(context); + results.add(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Complex/invalid.mlir b/mlir/test/Dialect/Complex/invalid.mlir index 51b1b0fda202a..ba6995b727bc2 100644 --- a/mlir/test/Dialect/Complex/invalid.mlir +++ b/mlir/test/Dialect/Complex/invalid.mlir @@ -25,7 +25,7 @@ func.func @complex_constant_two_different_element_types() { // ----- func.func @complex_bitcast_i64(%arg0 : i64) { - // expected-error @+1 {{op requires input or output is a complex type}} + // expected-error @+1 {{op requires that either input or output has a complex type}} %0 = complex.bitcast %arg0: i64 to f64 return }