|
13 | 13 | #include "mlir/Dialect/Arith/IR/Arith.h" |
14 | 14 | #include "mlir/Dialect/Math/IR/Math.h" |
15 | 15 | #include "mlir/Dialect/Math/Transforms/Passes.h" |
16 | | -#include "mlir/Dialect/SCF/IR/SCF.h" |
17 | | -#include "mlir/Dialect/Vector/IR/VectorOps.h" |
18 | 16 | #include "mlir/IR/Builders.h" |
| 17 | +#include "mlir/IR/Matchers.h" |
19 | 18 | #include "mlir/IR/TypeUtilities.h" |
20 | | -#include "mlir/Transforms/DialectConversion.h" |
| 19 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
21 | 20 |
|
22 | 21 | using namespace mlir; |
23 | 22 |
|
| 23 | +namespace mlir::math { |
| 24 | +#define GEN_PASS_DEF_MATHEXPANDOPSPASS |
| 25 | +#include "mlir/Dialect/Math/Transforms/Passes.h.inc" |
| 26 | +} // namespace mlir::math |
| 27 | + |
24 | 28 | /// Create a float constant. |
25 | 29 | static Value createFloatConst(Location loc, Type type, APFloat value, |
26 | 30 | OpBuilder &b) { |
@@ -661,66 +665,77 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, |
661 | 665 | return success(); |
662 | 666 | } |
663 | 667 |
|
664 | | -void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { |
665 | | - patterns.add(convertCtlzOp); |
666 | | -} |
667 | | - |
668 | | -void mlir::populateExpandSinhPattern(RewritePatternSet &patterns) { |
669 | | - patterns.add(convertSinhOp); |
670 | | -} |
671 | | - |
672 | | -void mlir::populateExpandCoshPattern(RewritePatternSet &patterns) { |
673 | | - patterns.add(convertCoshOp); |
674 | | -} |
675 | | - |
676 | | -void mlir::populateExpandTanPattern(RewritePatternSet &patterns) { |
677 | | - patterns.add(convertTanOp); |
678 | | -} |
679 | | - |
680 | | -void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { |
681 | | - patterns.add(convertTanhOp); |
682 | | -} |
683 | | - |
684 | | -void mlir::populateExpandAsinhPattern(RewritePatternSet &patterns) { |
685 | | - patterns.add(convertAsinhOp); |
686 | | -} |
687 | | - |
688 | | -void mlir::populateExpandAcoshPattern(RewritePatternSet &patterns) { |
689 | | - patterns.add(convertAcoshOp); |
690 | | -} |
691 | | - |
692 | | -void mlir::populateExpandAtanhPattern(RewritePatternSet &patterns) { |
693 | | - patterns.add(convertAtanhOp); |
694 | | -} |
695 | | - |
696 | | -void mlir::populateExpandFmaFPattern(RewritePatternSet &patterns) { |
697 | | - patterns.add(convertFmaFOp); |
698 | | -} |
699 | | - |
700 | | -void mlir::populateExpandCeilFPattern(RewritePatternSet &patterns) { |
701 | | - patterns.add(convertCeilOp); |
702 | | -} |
703 | | - |
704 | | -void mlir::populateExpandExp2FPattern(RewritePatternSet &patterns) { |
705 | | - patterns.add(convertExp2fOp); |
706 | | -} |
707 | | - |
708 | | -void mlir::populateExpandPowFPattern(RewritePatternSet &patterns) { |
709 | | - patterns.add(convertPowfOp); |
710 | | -} |
711 | | - |
712 | | -void mlir::populateExpandFPowIPattern(RewritePatternSet &patterns) { |
713 | | - patterns.add(convertFPowIOp); |
714 | | -} |
715 | | - |
716 | | -void mlir::populateExpandRoundFPattern(RewritePatternSet &patterns) { |
717 | | - patterns.add(convertRoundOp); |
| 668 | +// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf` |
| 669 | +static LogicalResult convertClampfOp(math::ClampFOp op, |
| 670 | + PatternRewriter &rewriter) { |
| 671 | + auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(), |
| 672 | + op.getMin(), op.getFastmath()); |
| 673 | + rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMax(), |
| 674 | + op.getFastmath()); |
| 675 | + return success(); |
718 | 676 | } |
719 | 677 |
|
720 | | -void mlir::populateExpandRoundEvenPattern(RewritePatternSet &patterns) { |
721 | | - patterns.add(convertRoundEvenOp); |
| 678 | +void mlir::math::populateExpansionPatterns(RewritePatternSet &patterns, |
| 679 | + ArrayRef<StringRef> opMnemonics) { |
| 680 | + auto filter = [&](StringRef name) { |
| 681 | + // This should be a static assert and `consume_front` take a twine, but none |
| 682 | + // is currently possible. TODO: augment `StringRef::consume_front` and make |
| 683 | + // `getDialectNamespace` use `std::string_view`. |
| 684 | + assert("math" == MathDialect::getDialectNamespace()); |
| 685 | + name.consume_front("math."); |
| 686 | + return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0); |
| 687 | + }; |
| 688 | + if (filter(CountLeadingZerosOp::getOperationName())) |
| 689 | + patterns.add(convertCtlzOp); |
| 690 | + if (filter(SinhOp::getOperationName())) |
| 691 | + patterns.add(convertSinhOp); |
| 692 | + if (filter(CoshOp::getOperationName())) |
| 693 | + patterns.add(convertCoshOp); |
| 694 | + if (filter(TanOp::getOperationName())) |
| 695 | + patterns.add(convertTanOp); |
| 696 | + if (filter(TanhOp::getOperationName())) |
| 697 | + patterns.add(convertTanhOp); |
| 698 | + if (filter(AsinhOp::getOperationName())) |
| 699 | + patterns.add(convertAsinhOp); |
| 700 | + if (filter(AcoshOp::getOperationName())) |
| 701 | + patterns.add(convertAcoshOp); |
| 702 | + if (filter(AtanhOp::getOperationName())) |
| 703 | + patterns.add(convertAtanhOp); |
| 704 | + if (filter(FmaOp::getOperationName())) |
| 705 | + patterns.add(convertFmaFOp); |
| 706 | + if (filter(CeilOp::getOperationName())) |
| 707 | + patterns.add(convertCeilOp); |
| 708 | + if (filter(Exp2Op::getOperationName())) |
| 709 | + patterns.add(convertExp2fOp); |
| 710 | + if (filter(PowFOp::getOperationName())) |
| 711 | + patterns.add(convertPowfOp); |
| 712 | + if (filter(FPowIOp::getOperationName())) |
| 713 | + patterns.add(convertFPowIOp); |
| 714 | + if (filter(RoundOp::getOperationName())) |
| 715 | + patterns.add(convertRoundOp); |
| 716 | + if (filter(RoundEvenOp::getOperationName())) |
| 717 | + patterns.add(convertRoundEvenOp); |
| 718 | + if (filter(RsqrtOp::getOperationName())) |
| 719 | + patterns.add(convertRsqrtOp); |
| 720 | + if (filter(ClampFOp::getOperationName())) |
| 721 | + patterns.add(convertClampfOp); |
722 | 722 | } |
723 | 723 |
|
724 | | -void mlir::populateExpandRsqrtPattern(RewritePatternSet &patterns) { |
725 | | - patterns.add(convertRsqrtOp); |
726 | | -} |
| 724 | +//===----------------------------------------------------------------------===// |
| 725 | +// MathExpandOpsPass pass |
| 726 | +//===----------------------------------------------------------------------===// |
| 727 | +namespace { |
| 728 | +struct MathExpandOpsPass final |
| 729 | + : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> { |
| 730 | + using MathExpandOpsPassBase::MathExpandOpsPassBase; |
| 731 | + |
| 732 | + void runOnOperation() override { |
| 733 | + RewritePatternSet patterns(&getContext()); |
| 734 | + SmallVector<StringRef> mnemonics = |
| 735 | + llvm::to_vector_of<StringRef>(opMnemonics); |
| 736 | + math::populateExpansionPatterns(patterns, mnemonics); |
| 737 | + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) |
| 738 | + return signalPassFailure(); |
| 739 | + } |
| 740 | +}; |
| 741 | +} // namespace |
0 commit comments