Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,13 @@ OpFoldResult arith::MulFOp::fold(FoldAdaptor adaptor) {
if (matchPattern(adaptor.getRhs(), m_OneFloat()))
return getLhs();

if (arith::bitEnumContainsAll(getFastmath(), arith::FastMathFlags::nnan |
arith::FastMathFlags::nsz)) {
Comment on lines +1285 to +1286
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doesn't it need also ninf? inf * 0 -> Nan

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried to check this with Alive: https://alive2.llvm.org/ce/z/wvNkdy

Copy link
Member

@kuhar kuhar Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's because nnan applies to the result as well:

nnan
No NaNs - Allow optimizations to assume the arguments and result are not NaN.

// mulf(x, 0) -> 0
if (matchPattern(adaptor.getRhs(), m_AnyZeroFloat()))
return getRhs();
}

return constFoldBinaryOp<FloatAttr>(
adaptor.getOperands(),
[](const APFloat &a, const APFloat &b) { return a * b; });
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,18 @@ func.func @test_mulf1(%arg0 : f32, %arg1 : f32) -> (f32) {
return %2 : f32
}

// CHECK-LABEL: @test_mulf2(
func.func @test_mulf2(%arg0 : f32) -> (f32, f32) {
// CHECK-DAG: %[[C0:.+]] = arith.constant 0.000000e+00 : f32
// CHECK-DAG: %[[C0n:.+]] = arith.constant -0.000000e+00 : f32
// CHECK-NEXT: return %[[C0]], %[[C0n]]
%c0 = arith.constant 0.0 : f32
%c0n = arith.constant -0.0 : f32
%0 = arith.mulf %c0, %arg0 fastmath<nnan,nsz> : f32
%1 = arith.mulf %c0n, %arg0 fastmath<nnan,nsz> : f32
return %0, %1 : f32, f32
}

// -----

// CHECK-LABEL: @test_divf(
Expand Down