diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td index 4d415aeac8f58..48346abd84285 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.td @@ -64,4 +64,12 @@ def MathExpandOpsPass : Pass<"math-expand-ops"> { ]; } +def MathSincosFusionPass : Pass<"math-sincos-fusion"> { + let summary = "Fuse sin and cos operations."; + let description = [{ + Fuse sin and cos operations into a sincos operation. + }]; + let dependentDialects = ["math::MathDialect"]; +} + #endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt index ff62b515533c3..8899c3a1d1a42 100644 --- a/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Math/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRMathTransforms ExpandOps.cpp ExtendToSupportedTypes.cpp PolynomialApproximation.cpp + SincosFusion.cpp UpliftToFMA.cpp ADDITIONAL_HEADER_DIRS diff --git a/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp new file mode 100644 index 0000000000000..69407df201cfa --- /dev/null +++ b/mlir/lib/Dialect/Math/Transforms/SincosFusion.cpp @@ -0,0 +1,80 @@ +//===- SincosFusion.cpp - Fuse sin/cos into sincos -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Math/Transforms/Passes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace mlir::math; + +namespace { + +/// Fuse a math.sin and math.cos in the same block that use the same operand and +/// have identical fastmath flags into a single math.sincos. +struct SincosFusionPattern : OpRewritePattern { + using Base::Base; + + LogicalResult matchAndRewrite(math::SinOp sinOp, + PatternRewriter &rewriter) const override { + Value operand = sinOp.getOperand(); + mlir::arith::FastMathFlags sinFastMathFlags = sinOp.getFastmath(); + + math::CosOp cosOp = nullptr; + sinOp->getBlock()->walk([&](math::CosOp op) { + if (op.getOperand() == operand && op.getFastmath() == sinFastMathFlags) { + cosOp = op; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + + if (!cosOp) + return failure(); + + Operation *firstOp = sinOp->isBeforeInBlock(cosOp) ? sinOp.getOperation() + : cosOp.getOperation(); + rewriter.setInsertionPoint(firstOp); + + Type elemType = sinOp.getType(); + auto sincos = math::SincosOp::create(rewriter, firstOp->getLoc(), + TypeRange{elemType, elemType}, operand, + sinOp.getFastmathAttr()); + + rewriter.replaceOp(sinOp, sincos.getSin()); + rewriter.replaceOp(cosOp, sincos.getCos()); + return success(); + } +}; + +} // namespace + +namespace mlir::math { +#define GEN_PASS_DEF_MATHSINCOSFUSIONPASS +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" +} // namespace mlir::math + +namespace { + +struct MathSincosFusionPass final + : math::impl::MathSincosFusionPassBase { + using MathSincosFusionPassBase::MathSincosFusionPassBase; + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + + GreedyRewriteConfig config; + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) + return signalPassFailure(); + } +}; + +} // namespace diff --git a/mlir/test/Dialect/Math/sincos-fusion.mlir b/mlir/test/Dialect/Math/sincos-fusion.mlir new file mode 100644 index 0000000000000..29fb9f12475b8 --- /dev/null +++ b/mlir/test/Dialect/Math/sincos-fusion.mlir @@ -0,0 +1,86 @@ +// RUN: mlir-opt -math-sincos-fusion %s | FileCheck %s + +// CHECK-LABEL: func.func @sincos_fusion( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32 +// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32 +// CHECK: } +func.func @sincos_fusion(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) { + %0 = math.sin %arg0 : f32 + %1 = math.cos %arg0 : f32 + + %2 = math.cos %arg1 : f32 + %3 = math.sin %arg1 : f32 + + func.return %0, %1, %2, %3 : f32, f32, f32, f32 +} + +func.func private @sink(%arg0 : f32) + +// CHECK: func.func private @sink(f32) +// CHECK-LABEL: func.func @sincos_ensure_ssa_dominance( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: f32) -> (f32, f32, f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] : f32 +// CHECK: call @sink(%[[VAL_0]]) : (f32) -> () +// CHECK: %[[VAL_2:.*]], %[[VAL_3:.*]] = math.sincos %[[ARG1]] : f32 +// CHECK: call @sink(%[[VAL_3]]) : (f32) -> () +// CHECK: return %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_2]] : f32, f32, f32, f32 +// CHECK: } +func.func @sincos_ensure_ssa_dominance(%arg0 : f32, %arg1 : f32) -> (f32, f32, f32, f32) { + %0 = math.sin %arg0 : f32 + func.call @sink(%0) : (f32) -> () + %1 = math.cos %arg0 : f32 + %2 = math.cos %arg1 : f32 + func.call @sink(%2) : (f32) -> () + %3 = math.sin %arg1 : f32 + func.return %0, %1, %2, %3 : f32, f32, f32, f32 +} + +// CHECK-LABEL: func.func @sincos_fusion_no_match_fmf( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) { +// CHECK: %[[VAL_0:.*]] = math.sin %[[ARG0]] fastmath : f32 +// CHECK: %[[VAL_1:.*]] = math.cos %[[ARG0]] : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32 +// CHECK: } +func.func @sincos_fusion_no_match_fmf(%arg0 : f32) -> (f32, f32) { + %0 = math.sin %arg0 fastmath : f32 + %1 = math.cos %arg0 : f32 + func.return %0, %1 : f32, f32 +} + +// CHECK-LABEL: func.func @sincos_no_fusion_different_block( +// CHECK-SAME: %[[ARG0:.*]]: f32, +// CHECK-SAME: %[[ARG1:.*]]: i1) -> f32 { +// CHECK: %[[VAL_0:.*]] = scf.if %[[ARG1]] -> (f32) { +// CHECK: %[[VAL_1:.*]] = math.sin %[[ARG0]] : f32 +// CHECK: scf.yield %[[VAL_1]] : f32 +// CHECK: } else { +// CHECK: %[[VAL_2:.*]] = math.cos %[[ARG0]] : f32 +// CHECK: scf.yield %[[VAL_2]] : f32 +// CHECK: } +// CHECK: return %[[VAL_0]] : f32 +// CHECK: } +func.func @sincos_no_fusion_different_block(%arg0 : f32, %flag : i1) -> f32 { + %0 = scf.if %flag -> f32 { + %s = math.sin %arg0 : f32 + scf.yield %s : f32 + } else { + %c = math.cos %arg0 : f32 + scf.yield %c : f32 + } + func.return %0 : f32 +} + +// CHECK-LABEL: func.func @sincos_fusion_preserve_fastmath( +// CHECK-SAME: %[[ARG0:.*]]: f32) -> (f32, f32) { +// CHECK: %[[VAL_0:.*]], %[[VAL_1:.*]] = math.sincos %[[ARG0]] fastmath : f32 +// CHECK: return %[[VAL_0]], %[[VAL_1]] : f32, f32 +// CHECK: } +func.func @sincos_fusion_preserve_fastmath(%arg0 : f32) -> (f32, f32) { + %0 = math.sin %arg0 fastmath : f32 + %1 = math.cos %arg0 fastmath : f32 + func.return %0, %1 : f32, f32 +}