1919#include " mlir/IR/BuiltinAttributes.h"
2020#include " mlir/IR/IRMapping.h"
2121#include " mlir/IR/Matchers.h"
22+ #include " mlir/IR/Operation.h"
23+ #include " mlir/IR/OperationSupport.h"
2224#include " mlir/IR/PatternMatch.h"
2325#include " mlir/Interfaces/FunctionInterfaces.h"
2426#include " mlir/Interfaces/ParallelCombiningOpInterface.h"
2527#include " mlir/Interfaces/ValueBoundsOpInterface.h"
2628#include " mlir/Transforms/InliningUtils.h"
2729#include " llvm/ADT/MapVector.h"
2830#include " llvm/ADT/SmallPtrSet.h"
31+ #include " llvm/Support/Casting.h"
32+ #include " llvm/Support/DebugLog.h"
33+ #include < optional>
2934
3035using namespace mlir ;
3136using namespace mlir ::scf;
@@ -105,6 +110,24 @@ static TerminatorTy verifyAndGetTerminator(Operation *op, Region ®ion,
105110 return nullptr ;
106111}
107112
113+ // / Helper function to compute the difference between two values. This is used
114+ // / by the loop implementations to compute the trip count.
115+ static std::optional<llvm::APSInt> computeUbMinusLb (Value lb, Value ub,
116+ bool isSigned) {
117+ llvm::APSInt diff;
118+ auto addOp = ub.getDefiningOp <arith::AddIOp>();
119+ if (!addOp)
120+ return std::nullopt ;
121+ if ((isSigned && !addOp.hasNoSignedWrap ()) ||
122+ (!isSigned && !addOp.hasNoUnsignedWrap ()))
123+ return std::nullopt ;
124+
125+ if (addOp.getLhs () != lb ||
126+ !matchPattern (addOp.getRhs (), m_ConstantInt (&diff)))
127+ return std::nullopt ;
128+ return diff;
129+ }
130+
108131// ===----------------------------------------------------------------------===//
109132// ExecuteRegionOp
110133// ===----------------------------------------------------------------------===//
@@ -408,11 +431,19 @@ std::optional<ResultRange> ForOp::getLoopResults() { return getResults(); }
/// Promotes the loop body of a forOp to its containing block if it can be
/// determined that the loop has a single iteration. If the loop is known to
/// have zero iterations, it is erased and its results are replaced with the
/// init args.
410433LogicalResult ForOp::promoteIfSingleIteration (RewriterBase &rewriter) {
411- std::optional<int64_t > tripCount =
412- constantTripCount (getLowerBound (), getUpperBound (), getStep ());
413- if (!tripCount.has_value () || tripCount != 1 )
434+ std::optional<APInt> tripCount = getStaticTripCount ();
435+ LDBG () << " promoteIfSingleIteration tripCount is " << tripCount
436+ << " for loop "
437+ << OpWithFlags (getOperation (), OpPrintingFlags ().skipRegions ());
438+ if (!tripCount.has_value () || tripCount->getSExtValue () > 1 )
414439 return failure ();
415440
441+ if (*tripCount == 0 ) {
442+ rewriter.replaceAllUsesWith (getResults (), getInitArgs ());
443+ rewriter.eraseOp (*this );
444+ return success ();
445+ }
446+
416447 // Replace all results with the yielded values.
417448 auto yieldOp = cast<scf::YieldOp>(getBody ()->getTerminator ());
418449 rewriter.replaceAllUsesWith (getResults (), getYieldedValues ());
@@ -646,7 +677,8 @@ SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
646677LogicalResult scf::ForallOp::promoteIfSingleIteration (RewriterBase &rewriter) {
647678 for (auto [lb, ub, step] :
648679 llvm::zip (getMixedLowerBound (), getMixedUpperBound (), getMixedStep ())) {
649- auto tripCount = constantTripCount (lb, ub, step);
680+ auto tripCount =
681+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
650682 if (!tripCount.has_value () || *tripCount != 1 )
651683 return failure ();
652684 }
@@ -1003,27 +1035,6 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
10031035 }
10041036};
10051037
1006- // / Util function that tries to compute a constant diff between u and l.
1007- // / Returns std::nullopt when the difference between two AffineValueMap is
1008- // / dynamic.
1009- static std::optional<APInt> computeConstDiff (Value l, Value u) {
1010- IntegerAttr clb, cub;
1011- if (matchPattern (l, m_Constant (&clb)) && matchPattern (u, m_Constant (&cub))) {
1012- llvm::APInt lbValue = clb.getValue ();
1013- llvm::APInt ubValue = cub.getValue ();
1014- return ubValue - lbValue;
1015- }
1016-
1017- // Else a simple pattern match for x + c or c + x
1018- llvm::APInt diff;
1019- if (matchPattern (
1020- u, m_Op<arith::AddIOp>(matchers::m_Val (l), m_ConstantInt (&diff))) ||
1021- matchPattern (
1022- u, m_Op<arith::AddIOp>(m_ConstantInt (&diff), matchers::m_Val (l))))
1023- return diff;
1024- return std::nullopt ;
1025- }
1026-
10271038// / Rewriting pattern that erases loops that are known not to iterate, replaces
10281039// / single-iteration loops with their bodies, and removes empty loops that
10291040// / iterate at least once and only return values defined outside of the loop.
@@ -1032,34 +1043,21 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10321043
10331044 LogicalResult matchAndRewrite (ForOp op,
10341045 PatternRewriter &rewriter) const override {
1035- // If the upper bound is the same as the lower bound, the loop does not
1036- // iterate, just remove it.
1037- if (op.getLowerBound () == op.getUpperBound ()) {
1046+ std::optional<APInt> tripCount = op.getStaticTripCount ();
1047+ if (!tripCount.has_value ())
1048+ return rewriter.notifyMatchFailure (op,
1049+ " can't compute constant trip count" );
1050+
1051+ if (tripCount->isZero ()) {
1052+ LDBG () << " SimplifyTrivialLoops tripCount is 0 for loop "
1053+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
10381054 rewriter.replaceOp (op, op.getInitArgs ());
10391055 return success ();
10401056 }
10411057
1042- std::optional<APInt> diff =
1043- computeConstDiff (op.getLowerBound (), op.getUpperBound ());
1044- if (!diff)
1045- return failure ();
1046-
1047- // If the loop is known to have 0 iterations, remove it.
1048- bool zeroOrLessIterations =
1049- diff->isZero () || (!op.getUnsignedCmp () && diff->isNegative ());
1050- if (zeroOrLessIterations) {
1051- rewriter.replaceOp (op, op.getInitArgs ());
1052- return success ();
1053- }
1054-
1055- std::optional<llvm::APInt> maybeStepValue = op.getConstantStep ();
1056- if (!maybeStepValue)
1057- return failure ();
1058-
1059- // If the loop is known to have 1 iteration, inline its body and remove the
1060- // loop.
1061- llvm::APInt stepValue = *maybeStepValue;
1062- if (stepValue.sge (*diff)) {
1058+ if (tripCount->getSExtValue () == 1 ) {
1059+ LDBG () << " SimplifyTrivialLoops tripCount is 1 for loop "
1060+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
10631061 SmallVector<Value, 4 > blockArgs;
10641062 blockArgs.reserve (op.getInitArgs ().size () + 1 );
10651063 blockArgs.push_back (op.getLowerBound ());
@@ -1072,11 +1070,14 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10721070 Block &block = op.getRegion ().front ();
10731071 if (!llvm::hasSingleElement (block))
10741072 return failure ();
1075- // If the loop is empty, iterates at least once, and only returns values
1073+ // The loop is empty and iterates at least once, if it only returns values
10761074 // defined outside of the loop, remove it and replace it with yield values.
10771075 if (llvm::any_of (op.getYieldedValues (),
10781076 [&](Value v) { return !op.isDefinedOutsideOfLoop (v); }))
10791077 return failure ();
1078+ LDBG () << " SimplifyTrivialLoops empty body loop allows replacement with "
1079+ " yield operands for loop "
1080+ << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
10801081 rewriter.replaceOp (op, op.getYieldedValues ());
10811082 return success ();
10821083 }
@@ -1172,6 +1173,11 @@ Speculation::Speculatability ForOp::getSpeculatability() {
11721173 return Speculation::NotSpeculatable;
11731174}
11741175
1176+ std::optional<APInt> ForOp::getStaticTripCount () {
1177+ return constantTripCount (getLowerBound (), getUpperBound (), getStep (),
1178+ /* isSigned=*/ !getUnsignedCmp (), computeUbMinusLb);
1179+ }
1180+
11751181// ===----------------------------------------------------------------------===//
11761182// ForallOp
11771183// ===----------------------------------------------------------------------===//
@@ -1768,7 +1774,8 @@ struct ForallOpSingleOrZeroIterationDimsFolder
17681774 for (auto [lb, ub, step, iv] :
17691775 llvm::zip (op.getMixedLowerBound (), op.getMixedUpperBound (),
17701776 op.getMixedStep (), op.getInductionVars ())) {
1771- auto numIterations = constantTripCount (lb, ub, step);
1777+ auto numIterations =
1778+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
17721779 if (numIterations.has_value ()) {
17731780 // Remove the loop if it performs zero iterations.
17741781 if (*numIterations == 0 ) {
@@ -1839,7 +1846,8 @@ struct ForallOpReplaceConstantInductionVar : public OpRewritePattern<ForallOp> {
18391846 op.getMixedStep (), op.getInductionVars ())) {
18401847 if (iv.hasNUses (0 ))
18411848 continue ;
1842- auto numIterations = constantTripCount (lb, ub, step);
1849+ auto numIterations =
1850+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
18431851 if (!numIterations.has_value () || numIterations.value () != 1 ) {
18441852 continue ;
18451853 }
@@ -3084,7 +3092,8 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
30843092 for (auto [lb, ub, step, iv] :
30853093 llvm::zip (op.getLowerBound (), op.getUpperBound (), op.getStep (),
30863094 op.getInductionVars ())) {
3087- auto numIterations = constantTripCount (lb, ub, step);
3095+ auto numIterations =
3096+ constantTripCount (lb, ub, step, /* isSigned=*/ true , computeUbMinusLb);
30883097 if (numIterations.has_value ()) {
30893098 // Remove the loop if it performs zero iterations.
30903099 if (*numIterations == 0 ) {
0 commit comments