Skip to content

Commit 51a1aab

Browse files
authored
[mlir][math] Add clampf and clean math ExpandOps API (#151153)
This patch adds the `clampf` operation to the math dialect. The semantics op are defined as: ``` clampf(x, min_v, max_v) = max(min(x, min_v), max_v) ``` The reasoning behind adding this operation is that some GPU vendors offer specialized intrinsics for this operation, or subsets of this operation. For example, [__saturatef](https://docs.nvidia.com/cuda/cuda-math-api/cuda_math_api/group__CUDA__MATH__INTRINSIC__SINGLE.html#group__cuda__math__intrinsic__single_1ga2c84f08e0db7117a14509d21c3aec04e) in NVIDIA GPUs, or `__builtin_amdgcn_fmed3f` in AMD GPUs. This patch also removes `test-expand-math` in favor of `math-expand-ops`. Finally, it removes individual expansion population API calls like `populateExpandCoshPattern` in favor of: ```C++ void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef<StringRef> opMnemonics = {}); ```
1 parent 036b33d commit 51a1aab

File tree

11 files changed

+188
-147
lines changed

11 files changed

+188
-147
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,37 @@ def Math_CeilOp : Math_FloatUnaryOp<"ceil"> {
352352
let hasFolder = 1;
353353
}
354354

355+
//===----------------------------------------------------------------------===//
356+
// ClampFOp
357+
//===----------------------------------------------------------------------===//
358+
359+
def Math_ClampFOp : Math_FloatTernaryOp<"clampf"> {
360+
let summary = "floating point clamping operation";
361+
let description = [{
362+
The `clampf` operation takes three operands and returns one result, each of
363+
these is required to be the same type. Operands must be of floating point type
364+
(i.e., scalar, tensor or vector).
365+
366+
The semantics of the operation are described by:
367+
```
368+
clampf(value, min, max) = maxf(minf(value, min), max)
369+
```
370+
371+
Example:
372+
373+
```mlir
374+
%d = math.clampf %value to [%min, %max] : f64
375+
```
376+
}];
377+
let arguments = (ins FloatLike:$value, FloatLike:$min, FloatLike:$max,
378+
DefaultValuedAttr<Arith_FastMathAttr,
379+
"::mlir::arith::FastMathFlags::none">:$fastmath);
380+
let assemblyFormat = [{
381+
$value `to` ` ` `[` $min `,` $max `]` (`fastmath` `` $fastmath^)?
382+
attr-dict `:` type($result)
383+
}];
384+
}
385+
355386
//===----------------------------------------------------------------------===//
356387
// CopySignOp
357388
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,16 @@ class ConversionTarget;
2323
class RewritePatternSet;
2424
class TypeConverter;
2525

26-
void populateExpandCtlzPattern(RewritePatternSet &patterns);
27-
void populateExpandTanPattern(RewritePatternSet &patterns);
28-
void populateExpandSinhPattern(RewritePatternSet &patterns);
29-
void populateExpandCoshPattern(RewritePatternSet &patterns);
30-
void populateExpandTanhPattern(RewritePatternSet &patterns);
31-
void populateExpandAsinhPattern(RewritePatternSet &patterns);
32-
void populateExpandAcoshPattern(RewritePatternSet &patterns);
33-
void populateExpandAtanhPattern(RewritePatternSet &patterns);
34-
void populateExpandFmaFPattern(RewritePatternSet &patterns);
35-
void populateExpandCeilFPattern(RewritePatternSet &patterns);
36-
void populateExpandExp2FPattern(RewritePatternSet &patterns);
37-
void populateExpandPowFPattern(RewritePatternSet &patterns);
38-
void populateExpandFPowIPattern(RewritePatternSet &patterns);
39-
void populateExpandRoundFPattern(RewritePatternSet &patterns);
40-
void populateExpandRoundEvenPattern(RewritePatternSet &patterns);
41-
void populateExpandRsqrtPattern(RewritePatternSet &patterns);
26+
namespace math {
27+
/// Adds patterns to expand math operations into other more fundamental
28+
/// operations. For example, hyperbolic functions are expanded into expressions
29+
/// using `exp`. If `opMnemonics` is empty then all available patterns will be
30+
/// added, otherwise only the patterns corresponding to ops in `opMnemonics`
31+
/// will be added to the set.
32+
void populateExpansionPatterns(RewritePatternSet &patterns,
33+
ArrayRef<StringRef> opMnemonics = {});
34+
} // namespace math
35+
4236
void populateMathAlgebraicSimplificationPatterns(RewritePatternSet &patterns);
4337

4438
struct MathPolynomialApproximationOptions {

mlir/include/mlir/Dialect/Math/Transforms/Passes.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,4 +44,24 @@ def MathExtendToSupportedTypes : Pass<"math-extend-to-supported-types"> {
4444
let dependentDialects = ["math::MathDialect", "arith::ArithDialect"];
4545
}
4646

47+
def MathExpandOpsPass : Pass<"math-expand-ops"> {
48+
let summary = "Expand math operations.";
49+
let description = [{
50+
Expands some math operations into more fundamental operations, allowing them
51+
to be subsequently lowered through these. For example, hyperbolic functions
52+
are transformed into their expanded form containing only `exp` functions.
53+
54+
The `ops` parameter can be used to apply only a subset of all the
55+
available expansions, these must correspond to the operation mnemonic.
56+
For example, `ops=sinh,acosh` will expand only `math.sinh` and
57+
`math.acosh` operations. If the list is empty, then all expansions are
58+
applied.
59+
}];
60+
let dependentDialects = ["arith::ArithDialect"];
61+
let options = [
62+
ListOption<"opMnemonics", "ops", "std::string",
63+
"Operations to expand.">
64+
];
65+
}
66+
4767
#endif // MLIR_DIALECT_MATH_TRANSFORMS_PASSES

mlir/lib/Dialect/Math/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
add_mlir_dialect_library(MLIRMathTransforms
22
AlgebraicSimplification.cpp
3-
ExpandPatterns.cpp
3+
ExpandOps.cpp
44
ExtendToSupportedTypes.cpp
55
PolynomialApproximation.cpp
66
UpliftToFMA.cpp

mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp renamed to mlir/lib/Dialect/Math/Transforms/ExpandOps.cpp

Lines changed: 77 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,18 @@
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
1414
#include "mlir/Dialect/Math/IR/Math.h"
1515
#include "mlir/Dialect/Math/Transforms/Passes.h"
16-
#include "mlir/Dialect/SCF/IR/SCF.h"
17-
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1816
#include "mlir/IR/Builders.h"
17+
#include "mlir/IR/Matchers.h"
1918
#include "mlir/IR/TypeUtilities.h"
20-
#include "mlir/Transforms/DialectConversion.h"
19+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2120

2221
using namespace mlir;
2322

23+
namespace mlir::math {
24+
#define GEN_PASS_DEF_MATHEXPANDOPSPASS
25+
#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26+
} // namespace mlir::math
27+
2428
/// Create a float constant.
2529
static Value createFloatConst(Location loc, Type type, APFloat value,
2630
OpBuilder &b) {
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
661665
return success();
662666
}
663667

664-
void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
665-
patterns.add(convertCtlzOp);
666-
}
667-
668-
void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) {
669-
patterns.add(convertSinhOp);
670-
}
671-
672-
void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) {
673-
patterns.add(convertCoshOp);
674-
}
675-
676-
void mlir::populateExpandTanPattern(RewritePatternSet &patterns) {
677-
patterns.add(convertTanOp);
678-
}
679-
680-
void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
681-
patterns.add(convertTanhOp);
682-
}
683-
684-
void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) {
685-
patterns.add(convertAsinhOp);
686-
}
687-
688-
void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) {
689-
patterns.add(convertAcoshOp);
690-
}
691-
692-
void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) {
693-
patterns.add(convertAtanhOp);
694-
}
695-
696-
void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) {
697-
patterns.add(convertFmaFOp);
698-
}
699-
700-
void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) {
701-
patterns.add(convertCeilOp);
702-
}
703-
704-
void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) {
705-
patterns.add(convertExp2fOp);
706-
}
707-
708-
void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) {
709-
patterns.add(convertPowfOp);
710-
}
711-
712-
void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) {
713-
patterns.add(convertFPowIOp);
714-
}
715-
716-
void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) {
717-
patterns.add(convertRoundOp);
668+
// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
669+
static LogicalResult convertClampfOp(math::ClampFOp op,
670+
PatternRewriter &rewriter) {
671+
auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
672+
op.getMin(), op.getFastmath());
673+
rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(),
674+
op.getFastmath());
675+
return success();
718676
}
719677

720-
void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) {
721-
patterns.add(convertRoundEvenOp);
678+
void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns,
679+
ArrayRef<StringRef> opMnemonics) {
680+
auto filter = [&](StringRef name) {
681+
// This should be a static assert and `consume_front` take a twine, but none
682+
// is currently possible. TODO: augment `StringRef::consume_front` and make
683+
// `getDialectNamespace` use `std::string_view`.
684+
assert("math" == MathDialect::getDialectNamespace());
685+
name.consume_front("math.");
686+
return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
687+
};
688+
if (filter(CountLeadingZerosOp::getOperationName()))
689+
patterns.add(convertCtlzOp);
690+
if (filter(SinhOp::getOperationName()))
691+
patterns.add(convertSinhOp);
692+
if (filter(CoshOp::getOperationName()))
693+
patterns.add(convertCoshOp);
694+
if (filter(TanOp::getOperationName()))
695+
patterns.add(convertTanOp);
696+
if (filter(TanhOp::getOperationName()))
697+
patterns.add(convertTanhOp);
698+
if (filter(AsinhOp::getOperationName()))
699+
patterns.add(convertAsinhOp);
700+
if (filter(AcoshOp::getOperationName()))
701+
patterns.add(convertAcoshOp);
702+
if (filter(AtanhOp::getOperationName()))
703+
patterns.add(convertAtanhOp);
704+
if (filter(FmaOp::getOperationName()))
705+
patterns.add(convertFmaFOp);
706+
if (filter(CeilOp::getOperationName()))
707+
patterns.add(convertCeilOp);
708+
if (filter(Exp2Op::getOperationName()))
709+
patterns.add(convertExp2fOp);
710+
if (filter(PowFOp::getOperationName()))
711+
patterns.add(convertPowfOp);
712+
if (filter(FPowIOp::getOperationName()))
713+
patterns.add(convertFPowIOp);
714+
if (filter(RoundOp::getOperationName()))
715+
patterns.add(convertRoundOp);
716+
if (filter(RoundEvenOp::getOperationName()))
717+
patterns.add(convertRoundEvenOp);
718+
if (filter(RsqrtOp::getOperationName()))
719+
patterns.add(convertRsqrtOp);
720+
if (filter(ClampFOp::getOperationName()))
721+
patterns.add(convertClampfOp);
722722
}
723723

724-
void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) {
725-
patterns.add(convertRsqrtOp);
726-
}
724+
//===----------------------------------------------------------------------===//
725+
// MathExpandOpsPass pass
726+
//===----------------------------------------------------------------------===//
727+
namespace {
728+
struct MathExpandOpsPass final
729+
: math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
730+
using MathExpandOpsPassBase::MathExpandOpsPassBase;
731+
732+
void runOnOperation() override {
733+
RewritePatternSet patterns(&getContext());
734+
SmallVector<StringRef> mnemonics =
735+
llvm::to_vector_of<StringRef>(opMnemonics);
736+
math::populateExpansionPatterns(patterns, mnemonics);
737+
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
738+
return signalPassFailure();
739+
}
740+
};
741+
} // namespace

mlir/test/Dialect/Math/expand-math.mlir

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
// RUN: mlir-opt %s --split-input-file -test-expand-math | FileCheck %s
1+
// RUN: mlir-opt %s --split-input-file -math-expand-ops | FileCheck %s
2+
// RUN: mlir-opt %s --split-input-file -math-expand-ops=ops=tanh,tan | FileCheck %s --check-prefix=CHECK-FILTER
23

34
// CHECK-LABEL: func @tanh
45
func.func @tanh(%arg: f32) -> f32 {
6+
// CHECK-FILTER-NOT: math.tanh
57
%res = math.tanh %arg : f32
68
return %res : f32
79
}
@@ -27,6 +29,7 @@ func.func @tanh(%arg: f32) -> f32 {
2729
// CHECK-LABEL: func @vector_tanh
2830
func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
2931
// CHECK-NOT: math.tanh
32+
// CHECK-FILTER-NOT: math.tanh
3033
%res = math.tanh %arg : vector<4xf32>
3134
return %res : vector<4xf32>
3235
}
@@ -35,6 +38,7 @@ func.func @vector_tanh(%arg: vector<4xf32>) -> vector<4xf32> {
3538

3639
// CHECK-LABEL: func @tan
3740
func.func @tan(%arg: f32) -> f32 {
41+
// CHECK-FILTER-NOT: math.tan
3842
%res = math.tan %arg : f32
3943
return %res : f32
4044
}
@@ -49,6 +53,7 @@ func.func @tan(%arg: f32) -> f32 {
4953

5054
// CHECK-LABEL: func @vector_tan
5155
func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
56+
// CHECK-FILTER-NOT: math.tan
5257
%res = math.tan %arg : vector<4xf32>
5358
return %res : vector<4xf32>
5459
}
@@ -58,6 +63,7 @@ func.func @vector_tan(%arg: vector<4xf32>) -> vector<4xf32> {
5863
// -----
5964

6065
func.func @ctlz(%arg: i32) -> i32 {
66+
// CHECK-FILTER: math.ctlz
6167
%res = math.ctlz %arg : i32
6268
return %res : i32
6369
}
@@ -112,6 +118,7 @@ func.func @ctlz(%arg: i32) -> i32 {
112118
// -----
113119

114120
func.func @ctlz_vector(%arg: vector<4xi32>) -> vector<4xi32> {
121+
// CHECK-FILTER: math.ctlz
115122
%res = math.ctlz %arg : vector<4xi32>
116123
return %res : vector<4xi32>
117124
}
@@ -145,6 +152,7 @@ func.func @ceilf_func(%a: f64) -> f64 {
145152
// CHECK-NEXT: [[INCR:%.+]] = arith.select [[COMP]], [[CST_0]], [[CST]]
146153
// CHECK-NEXT: [[ADDF:%.+]] = arith.addf [[COPYSIGN]], [[INCR]]
147154
// CHECK-NEXT: return [[ADDF]]
155+
// CHECK-FILTER: math.ceil
148156
%ret = math.ceil %a : f64
149157
return %ret : f64
150158
}
@@ -158,6 +166,7 @@ func.func @exp2f_func(%a: f64) -> f64 {
158166
// CHECK: [[MULF:%.+]] = arith.mulf [[ARG0]], [[CST]]
159167
// CHECK: [[EXP:%.+]] = math.exp [[MULF]]
160168
// CHECK: return [[EXP]]
169+
// CHECK-FILTER: math.exp2
161170
%ret = math.exp2 %a : f64
162171
return %ret : f64
163172
}
@@ -813,3 +822,27 @@ func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
813822
%a = math.rsqrt %arg : tensor<*xf32>
814823
return %a: tensor<*xf32>
815824
}
825+
826+
// -----
827+
828+
// CHECK-LABEL: func.func @clampf_scalar_op
829+
// CHECK-SAME: (%[[ARG:.*]]: f16, %[[MIN:.*]]: f16, %[[MAX:.*]]: f16)
830+
// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] : f16
831+
// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] : f16
832+
// CHECK: return %[[V1]] : f16
833+
834+
func.func @clampf_scalar_op(%arg: f16, %min: f16, %max: f16) -> f16 {
835+
%a = math.clampf %arg to [%min, %max] : f16
836+
return %a: f16
837+
}
838+
839+
// CHECK-LABEL: func.func @clampf_vector_op
840+
// CHECK-SAME: (%[[ARG:.*]]: vector<3x4xf32>, %[[MIN:.*]]: vector<3x4xf32>, %[[MAX:.*]]: vector<3x4xf32>)
841+
// CHECK: %[[V0:.*]] = arith.minimumf %[[ARG]], %[[MIN]] fastmath<fast> : vector<3x4xf32>
842+
// CHECK: %[[V1:.*]] = arith.maximumf %[[V0]], %[[MAX]] fastmath<fast> : vector<3x4xf32>
843+
// CHECK: return %[[V1]] : vector<3x4xf32>
844+
845+
func.func @clampf_vector_op(%arg: vector<3x4xf32>, %min: vector<3x4xf32>, %max: vector<3x4xf32>) -> vector<3x4xf32>{
846+
%a = math.clampf %arg to [%min, %max] fastmath<fast> : vector<3x4xf32>
847+
return %a: vector<3x4xf32>
848+
}

mlir/test/Dialect/Math/ops.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s | mlir-opt | FileCheck %s
1+
// RUN: mlir-opt %s --verify-roundtrip | FileCheck %s
22
// RUN: mlir-opt %s --mlir-print-op-generic | mlir-opt | FileCheck %s
33

44
// CHECK-LABEL: func @atan(
@@ -337,3 +337,16 @@ func.func @fpclassify(%f: f32, %d: f64, %v: vector<4xf32>, %t: tensor<4x?xf32>)
337337
math.isnormal %t : tensor<4x?xf32>
338338
return
339339
}
340+
341+
// CHECK-LABEL: func @clampf(
342+
func.func @clampf(%av: vector<3x4xf32>, %mv: vector<3x4xf32>, %Mv: vector<3x4xf32>,
343+
%as: f32, %ms: f32, %Ms: f32,
344+
%at: tensor<?xf80>, %mt: tensor<?xf80>, %Mt: tensor<?xf80>) {
345+
// CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] fastmath<fast> : vector<3x4xf32>
346+
%rv = math.clampf %av to [%mv, %Mv] fastmath<fast> : vector<3x4xf32>
347+
// CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : f32
348+
%rs = math.clampf %as to [%ms, %Ms] fastmath<none> : f32
349+
// CHECK: math.clampf %{{.*}} to [%{{.*}}, %{{.*}}] : tensor<?xf80>
350+
%rt = math.clampf %at to [%mt, %Mt] : tensor<?xf80>
351+
return
352+
}

mlir/test/lib/Dialect/Math/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Exclude tests from libMLIR.so
22
add_mlir_library(MLIRMathTestPasses
33
TestAlgebraicSimplification.cpp
4-
TestExpandMath.cpp
54
TestPolynomialApproximation.cpp
65

76
EXCLUDE_FROM_LIBMLIR

0 commit comments

Comments
 (0)