@@ -15765,19 +15765,25 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1576515765 GetNextSCEVDividesByDivisor(One, DividesBy);
1576615766 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1576715767 } else {
15768+ // LHS != RHS can be rewritten as (LHS - RHS) = UMax(1, LHS - RHS),
15769+ // but creating the subtraction eagerly is expensive. Track the
15770+ // inequalities in a separate map, and materialize the rewrite lazily
15771+ // when encountering a suitable subtraction while re-writing.
1576815772 if (LHS->getType()->isPointerTy()) {
1576915773 LHS = SE.getLosslessPtrToIntExpr(LHS);
1577015774 RHS = SE.getLosslessPtrToIntExpr(RHS);
1577115775 if (isa<SCEVCouldNotCompute>(LHS) || isa<SCEVCouldNotCompute>(RHS))
1577215776 break;
1577315777 }
15774- auto AddSubRewrite = [&](const SCEV *A, const SCEV *B) {
15775- const SCEV *Sub = SE.getMinusSCEV(A, B);
15776- AddRewrite(Sub, Sub,
15777- SE.getUMaxExpr(Sub, SE.getOne(From->getType())));
15778- };
15779- AddSubRewrite(LHS, RHS);
15780- AddSubRewrite(RHS, LHS);
15778+ const SCEVConstant *C;
15779+ const SCEV *A, *B;
15780+ if (match(RHS, m_scev_Add(m_SCEVConstant(C), m_SCEV(A))) &&
15781+ match(LHS, m_scev_Add(m_scev_Specific(C), m_SCEV(B)))) {
15782+ RHS = A;
15783+ LHS = B;
15784+ }
15785+ Guards.NotEqualMap[LHS].insert(RHS);
15786+ Guards.NotEqualMap[RHS].insert(LHS);
1578115787 continue;
1578215788 }
1578315789 break;
@@ -15911,13 +15917,15 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1591115917 class SCEVLoopGuardRewriter
1591215918 : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
1591315919 const DenseMap<const SCEV *, const SCEV *> &Map;
15920+ const DenseMap<const SCEV *, SmallPtrSet<const SCEV *, 2>> &NotEqualMap;
1591415921
1591515922 SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
1591615923
1591715924 public:
1591815925 SCEVLoopGuardRewriter(ScalarEvolution &SE,
1591915926 const ScalarEvolution::LoopGuards &Guards)
15920- : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
15927+ : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap),
15928+ NotEqualMap(Guards.NotEqualMap) {
1592115929 if (Guards.PreserveNUW)
1592215930 FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
1592315931 if (Guards.PreserveNSW)
@@ -15972,14 +15980,35 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1597215980 }
1597315981
1597415982 const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
15983+ // Helper to check if S is a subtraction (A - B) where A != B, and if so,
15984+ // return UMax(S, 1).
15985+ auto RewriteSubtraction = [&](const SCEV *S) -> const SCEV * {
15986+ const SCEV *LHS, *RHS;
15987+ if (MatchBinarySub(S, LHS, RHS)) {
15988+ auto It = NotEqualMap.find(LHS);
15989+ if (It != NotEqualMap.end() && It->second.contains(RHS))
15990+ return SE.getUMaxExpr(S, SE.getOne(S->getType()));
15991+ }
15992+ return nullptr;
15993+ };
15994+
15995+ // Check if Expr itself is a subtraction pattern with guard info.
15996+ if (const SCEV *Rewritten = RewriteSubtraction(Expr))
15997+ return Rewritten;
15998+
1597515999 // Trip count expressions sometimes consist of adding 3 operands, i.e.
1597616000 // (Const + A + B). There may be guard info for A + B, and if so, apply
1597716001 // it.
1597816002 // TODO: Could more generally apply guards to Add sub-expressions.
1597916003 if (isa<SCEVConstant>(Expr->getOperand(0)) &&
1598016004 Expr->getNumOperands() == 3) {
15981- if (const SCEV *S = Map.lookup(
15982- SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2))))
16005+ const SCEV *Add =
16006+ SE.getAddExpr(Expr->getOperand(1), Expr->getOperand(2));
16007+ if (const SCEV *Rewritten = RewriteSubtraction(Add))
16008+ return SE.getAddExpr(
16009+ Expr->getOperand(0), Rewritten,
16010+ ScalarEvolution::maskFlags(Expr->getNoWrapFlags(), FlagMask));
16011+ if (const SCEV *S = Map.lookup(Add))
1598316012 return SE.getAddExpr(Expr->getOperand(0), S);
1598416013 }
1598516014 SmallVector<const SCEV *, 2> Operands;
@@ -16014,7 +16043,7 @@ const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
1601416043 }
1601516044 };
1601616045
16017- if (RewriteMap.empty())
16046+ if (RewriteMap.empty() && NotEqualMap.empty() )
1601816047 return Expr;
1601916048
1602016049 SCEVLoopGuardRewriter Rewriter(SE, *this);
0 commit comments