diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 7cfd6d3a98df8..898d76ce8d9b5 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) { if (matchPattern(adaptor.getRhs(), m_OneFloat())) return getLhs(); + if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan | + arith::FastMathFlags::nsz)) { + // mulf(x, 0) -> 0 + if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat())) + return getRhs(); + } + return constFoldBinaryOp( adaptor.getOperands(), [](const APFloat &a, const APFloat &b) { return a * b; }); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index ca3de3a2d7703..2fe0995c9d4df 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) { return %2 : f32 } +// CHECK-LABEL: @test_mulf2( +func.func @test_mulf2(%arg0 : f32) -> (f32, f32) { + // CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32 + // CHECK-NEXT: return %[[C0]], %[[C0n]] + %c0 = arith.constant 0.0 : f32 + %c0n = arith.constant -0.0 : f32 + %0 = arith.mulf %c0, %arg0 fastmath : f32 + %1 = arith.mulf %c0n, %arg0 fastmath : f32 + return %0, %1 : f32, f32 +} + // ----- // CHECK-LABEL: @test_divf(