2525#include " mlir/IR/BuiltinAttributes.h"
2626#include " mlir/IR/IntegerSet.h"
2727#include " mlir/IR/Visitors.h"
28- #include " mlir/Transforms/DialectConversion .h"
28+ #include " mlir/Transforms/WalkPatternRewriteDriver .h"
2929#include " llvm/ADT/DenseMap.h"
3030#include " llvm/Support/Debug.h"
3131#include < optional>
@@ -451,10 +451,10 @@ static void rewriteStore(fir::StoreOp storeOp,
451451}
452452
453453static void rewriteMemoryOps (Block *block, mlir::PatternRewriter &rewriter) {
454- for (auto &bodyOp : block->getOperations ()) {
454+ for (auto &bodyOp : llvm::make_early_inc_range ( block->getOperations () )) {
455455 if (isa<fir::LoadOp>(bodyOp))
456456 rewriteLoad (cast<fir::LoadOp>(bodyOp), rewriter);
457- if (isa<fir::StoreOp>(bodyOp))
457+ else if (isa<fir::StoreOp>(bodyOp))
458458 rewriteStore (cast<fir::StoreOp>(bodyOp), rewriter);
459459 }
460460}
@@ -476,6 +476,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
476476 loop.dump (););
477477 LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
478478 functionAnalysis.getChildLoopAnalysis (loop);
479+ if (!loopAnalysis.canPromoteToAffine ())
480+ return rewriter.notifyMatchFailure (loop, " cannot promote to affine" );
479481 auto &loopOps = loop.getBody ()->getOperations ();
480482 auto resultOp = cast<fir::ResultOp>(loop.getBody ()->getTerminator ());
481483 auto results = resultOp.getOperands ();
@@ -576,12 +578,14 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
576578public:
577579 using OpRewritePattern::OpRewritePattern;
578580 AffineIfConversion (mlir::MLIRContext *context, AffineFunctionAnalysis &afa)
579- : OpRewritePattern(context) {}
581+ : OpRewritePattern(context), functionAnalysis(afa) {}
580582 llvm::LogicalResult
581583 matchAndRewrite (fir::IfOp op,
582584 mlir::PatternRewriter &rewriter) const override {
583585 LLVM_DEBUG (llvm::dbgs () << " AffineIfConversion: rewriting if:\n " ;
584586 op.dump (););
587+ if (!functionAnalysis.getChildIfAnalysis (op).canPromoteToAffine ())
588+ return rewriter.notifyMatchFailure (op, " cannot promote to affine" );
585589 auto &ifOps = op.getThenRegion ().front ().getOperations ();
586590 auto affineCondition = AffineIfCondition (op.getCondition ());
587591 if (!affineCondition.hasIntegerSet ()) {
@@ -611,6 +615,8 @@ class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> {
611615 rewriter.replaceOp (op, affineIf.getOperation ()->getResults ());
612616 return success ();
613617 }
618+
619+ AffineFunctionAnalysis &functionAnalysis;
614620};
615621
616622// / Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases
@@ -627,28 +633,11 @@ class AffineDialectPromotion
627633 mlir::RewritePatternSet patterns (context);
628634 patterns.insert <AffineIfConversion>(context, functionAnalysis);
629635 patterns.insert <AffineLoopConversion>(context, functionAnalysis);
630- mlir::ConversionTarget target = *context;
631- target.addLegalDialect <mlir::affine::AffineDialect, FIROpsDialect,
632- mlir::scf::SCFDialect, mlir::arith::ArithDialect,
633- mlir::func::FuncDialect>();
634- target.addDynamicallyLegalOp <IfOp>([&functionAnalysis](fir::IfOp op) {
635- return !(functionAnalysis.getChildIfAnalysis (op).canPromoteToAffine ());
636- });
637- target.addDynamicallyLegalOp <DoLoopOp>([&functionAnalysis](
638- fir::DoLoopOp op) {
639- return !(functionAnalysis.getChildLoopAnalysis (op).canPromoteToAffine ());
640- });
641-
642636 LLVM_DEBUG (llvm::dbgs ()
643637 << " AffineDialectPromotion: running promotion on: \n " ;
644638 function.print (llvm::dbgs ()););
645639 // apply the patterns
646- if (mlir::failed (mlir::applyPartialConversion (function, target,
647- std::move (patterns)))) {
648- mlir::emitError (mlir::UnknownLoc::get (context),
649- " error in converting to affine dialect\n " );
650- signalPassFailure ();
651- }
640+ walkAndApplyPatterns (function, std::move (patterns));
652641 }
653642};
654643} // namespace
0 commit comments