@@ -15496,6 +15496,110 @@ void ScalarEvolution::LoopGuards::collectFromPHI(
1549615496 }
1549715497}
1549815498
15499+ // Round the constant \p Expr down to the closest multiple of \p DivisorVal
15500+ // and return it as a SCEV. \p Expr is returned unchanged if it is not a
15501+ // non-negative constant or if \p DivisorVal is not strictly positive.
15502+ static const SCEV *getPreviousSCEVDivisibleByDivisor(const SCEV *Expr,
15503+                                                      const APInt &DivisorVal,
15504+                                                      ScalarEvolution &SE) {
15505+   const APInt *C;
15506+   const bool IsNonNegConst = match(Expr, m_scev_APInt(C)) && !C->isNegative();
15507+   if (!IsNonNegConst || DivisorVal.isNonPositive())
15508+     return Expr;
15509+   // Align down by dropping the remainder: Expr - Expr % Divisor.
15510+   const APInt Remainder = C->urem(DivisorVal);
15511+   return SE.getConstant(*C - Remainder);
15512+ }
15513+
15514+ // Round the constant \p Expr up to the closest multiple of \p DivisorVal and
15515+ // return it as a SCEV. \p Expr is returned unchanged if it is not a
15516+ // non-negative constant or if \p DivisorVal is not strictly positive.
15517+ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
15518+                                                  const APInt &DivisorVal,
15519+                                                  ScalarEvolution &SE) {
15520+   const APInt *C;
15521+   const bool IsNonNegConst = match(Expr, m_scev_APInt(C)) && !C->isNegative();
15522+   if (!IsNonNegConst || DivisorVal.isNonPositive())
15523+     return Expr;
15524+   const APInt Remainder = C->urem(DivisorVal);
15525+   // Already a multiple: keep Expr. Otherwise step up to the next multiple,
15526+   // i.e. Expr + Divisor - Expr % Divisor.
15527+   return Remainder.isZero() ? Expr
15528+                             : SE.getConstant(*C + DivisorVal - Remainder);
15529+ }
15530+
15531+ static bool collectDivisibilityInformation(
15532+     ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15533+     DenseMap<const SCEV *, const SCEV *> &DivInfo,
15534+     DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15535+   // Record that the guard `A % B == 0` makes the SCEVUnknown A a multiple of
15536+   // B. Only `LHS == 0` qualifies; returns true iff \p DivInfo was updated.
15537+   if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15538+     return false;
15539+   // LHS must be A % B with A a SCEVUnknown. B is matched as an arbitrary
15540+   // SCEV, so it is not necessarily a constant.
15541+   const SCEVUnknown *URemLHS = nullptr;
15542+   const SCEV *URemRHS = nullptr;
15543+   if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15544+     return false;
15545+   // Rewrite A itself (not the remainder) to (A /u B) * B; note B if constant.
15546+   DivInfo[URemLHS] = SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15547+   if (const auto *C = dyn_cast<SCEVConstant>(URemRHS))
15548+     Multiples[URemLHS] = C->getAPInt();
15549+   return true;
15550+ }
15551+
15552+ // Return true if \p LHS is a urem and \p RHS is zero, i.e. `A % B == 0`.
15553+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15554+                                 ScalarEvolution &SE) {
15555+   const SCEV *A, *B; // Bound by the matcher; not used otherwise.
15556+   return match(LHS, m_scev_URem(m_SCEV(A), m_SCEV(B), SE)) && RHS->isZero();
15557+ }
15558+
15559+ // Apply divisibility by \p Divisor to \p MinMaxExpr, recursively.
15560+ //
15561+ // Only a two-operand min/max expression whose first operand is a non-negative
15562+ // constant is rewritten: that constant is aligned down (for smin/umin) or up
15563+ // (otherwise) to a multiple of \p Divisor, and the second operand is handled
15564+ // by a recursive call. Every other expression is returned unchanged.
15565+ //
15566+ // \p Divisor is cast<> to SCEVConstant once a rewritable expression is found.
15567+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15568+                                                  const SCEV *Divisor,
15569+                                                  ScalarEvolution &SE) {
15570+   // Only two-operand min/max expressions are candidates.
15571+   const auto *MinMax = dyn_cast<SCEVMinMaxExpr>(MinMaxExpr);
15572+   if (!MinMax || MinMax->getNumOperands() != 2)
15573+     return MinMaxExpr;
15574+
15575+   // The first operand has to be a non-negative constant.
15576+   const auto *ConstOp = dyn_cast<SCEVConstant>(MinMax->getOperand(0));
15577+   if (!ConstOp || ConstOp->getAPInt().isNegative())
15578+     return MinMaxExpr;
15579+   assert(SE.isKnownNonNegative(ConstOp) && "Expected non-negative operand!");
15580+
15581+   // Read the expression kind and the divisor value before building new SCEVs.
15582+   const SCEVTypes Kind = MinMax->getSCEVType();
15583+   const APInt &DivisorVal = cast<SCEVConstant>(Divisor)->getAPInt();
15584+
15585+   // A min is bounded above by the constant, so round the constant down; a max
15586+   // is bounded below by it, so round it up.
15587+   const SCEV *AlignedConst;
15588+   if (isa<SCEVSMinExpr, SCEVUMinExpr>(MinMaxExpr))
15589+     AlignedConst = getPreviousSCEVDivisibleByDivisor(ConstOp, DivisorVal, SE);
15590+   else
15591+     AlignedConst = getNextSCEVDivisibleByDivisor(ConstOp, DivisorVal, SE);
15592+
15593+   // The other operand may itself be a min/max with a constant operand.
15594+   const SCEV *Rest =
15595+       applyDivisibilityOnMinMaxExpr(MinMax->getOperand(1), Divisor, SE);
15596+
15597+   // Rebuild an expression of the same kind from the rewritten operand and the
15598+   // aligned constant.
15599+   SmallVector<const SCEV *> Ops = {Rest, AlignedConst};
15600+   return SE.getMinMaxExpr(Kind, Ops);
15601+ }
15602+
1549915603void ScalarEvolution::LoopGuards::collectFromBlock(
1550015604 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1550115605 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15506,19 +15610,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1550615610 SmallVector<const SCEV *> ExprsToRewrite;
1550715611 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1550815612 const SCEV *RHS,
15509- DenseMap<const SCEV *, const SCEV *>
15510- &RewriteMap) {
15613+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15614+ const DenseMap<const SCEV *, const SCEV *>
15615+ &DivInfo) {
1551115616 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1551215617 // replacement SCEV which isn't directly implied by the structure of that
1551315618 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1551415619 // legal. See the scoping rules for flags in the header to understand why.
1551515620
15516- // If LHS is a constant, apply information to the other expression.
15517- if (isa<SCEVConstant>(LHS)) {
15518- std::swap(LHS, RHS);
15519- Predicate = CmpInst::getSwappedPredicate(Predicate);
15520- }
15521-
1552215621 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1552315622 // create this form when combining two checks of the form (X u< C2 + C1) and
1552415623 // (X >=u C1).
@@ -15551,105 +15650,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1555115650 if (MatchRangeCheckIdiom())
1555215651 return;
1555315652
15554- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15555- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15556- // the non-constant operand and in \p LHS the constant operand.
15557- auto IsMinMaxSCEVWithNonNegativeConstant =
15558- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15559- const SCEV *&RHS) {
15560- if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15561- if (MinMax->getNumOperands() != 2)
15562- return false;
15563- if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15564- if (C->getAPInt().isNegative())
15565- return false;
15566- SCTy = MinMax->getSCEVType();
15567- LHS = MinMax->getOperand(0);
15568- RHS = MinMax->getOperand(1);
15569- return true;
15570- }
15571- }
15572- return false;
15573- };
15574-
15575- // Return a new SCEV that modifies \p Expr to the closest number divides by
15576- // \p Divisor and greater or equal than Expr. For now, only handle constant
15577- // Expr.
15578- auto GetNextSCEVDividesByDivisor = [&](const SCEV *Expr,
15579- const APInt &DivisorVal) {
15580- const APInt *ExprVal;
15581- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15582- DivisorVal.isNonPositive())
15583- return Expr;
15584- APInt Rem = ExprVal->urem(DivisorVal);
15585- if (Rem.isZero())
15586- return Expr;
15587- // return the SCEV: Expr + Divisor - Expr % Divisor
15588- return SE.getConstant(*ExprVal + DivisorVal - Rem);
15589- };
15590-
15591- // Return a new SCEV that modifies \p Expr to the closest number divides by
15592- // \p Divisor and less or equal than Expr. For now, only handle constant
15593- // Expr.
15594- auto GetPreviousSCEVDividesByDivisor = [&](const SCEV *Expr,
15595- const APInt &DivisorVal) {
15596- const APInt *ExprVal;
15597- if (!match(Expr, m_scev_APInt(ExprVal)) || ExprVal->isNegative() ||
15598- DivisorVal.isNonPositive())
15599- return Expr;
15600- APInt Rem = ExprVal->urem(DivisorVal);
15601- // return the SCEV: Expr - Expr % Divisor
15602- return SE.getConstant(*ExprVal - Rem);
15603- };
15604-
15605- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15606- // recursively. This is done by aligning up/down the constant value to the
15607- // Divisor.
15608- std::function<const SCEV *(const SCEV *, const SCEV *)>
15609- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15610- const SCEV *Divisor) {
15611- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15612- if (!ConstDivisor)
15613- return MinMaxExpr;
15614- const APInt &DivisorVal = ConstDivisor->getAPInt();
15615-
15616- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15617- SCEVTypes SCTy;
15618- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15619- MinMaxRHS))
15620- return MinMaxExpr;
15621- auto IsMin =
15622- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15623- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15624- "Expected non-negative operand!");
15625- auto *DivisibleExpr =
15626- IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, DivisorVal)
15627- : GetNextSCEVDividesByDivisor(MinMaxLHS, DivisorVal);
15628- SmallVector<const SCEV *> Ops = {
15629- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15630- return SE.getMinMaxExpr(SCTy, Ops);
15631- };
15632-
15633- // If we have LHS == 0, check if LHS is computing a property of some unknown
15634- // SCEV %v which we can rewrite %v to express explicitly.
15635- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15636- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15637- // explicitly express that.
15638- const SCEVUnknown *URemLHS = nullptr;
15639- const SCEV *URemRHS = nullptr;
15640- if (match(LHS,
15641- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15642- auto I = RewriteMap.find(URemLHS);
15643- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15644- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15645- const auto *Multiple =
15646- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15647- RewriteMap[URemLHS] = Multiple;
15648- ExprsToRewrite.push_back(URemLHS);
15649- return;
15650- }
15651- }
15652-
1565315653 // Do not apply information for constants or if RHS contains an AddRec.
1565415654 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1565515655 return;
@@ -15679,7 +15679,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1567915679 };
1568015680
1568115681 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15682- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15682+ // Apply divisibility information when computing the constant multiple.
15683+ LoopGuards DivGuards(SE);
15684+ DivGuards.RewriteMap = DivInfo;
15685+ const APInt &DividesBy =
15686+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1568315687
1568415688 // Collect rewrites for LHS and its transitive operands based on the
1568515689 // condition.
@@ -15694,31 +15698,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1569415698 // predicate.
1569515699 const SCEV *One = SE.getOne(RHS->getType());
1569615700 switch (Predicate) {
15697- case CmpInst::ICMP_ULT:
15698- if (RHS->getType()->isPointerTy())
15699- return;
15700- RHS = SE.getUMaxExpr(RHS, One);
15701- [[fallthrough]];
15702- case CmpInst::ICMP_SLT: {
15703- RHS = SE.getMinusSCEV(RHS, One);
15704- RHS = GetPreviousSCEVDividesByDivisor (RHS, DividesBy);
15705- break;
15706- }
15707- case CmpInst::ICMP_UGT:
15708- case CmpInst::ICMP_SGT:
15709- RHS = SE.getAddExpr(RHS, One);
15710- RHS = GetNextSCEVDividesByDivisor (RHS, DividesBy);
15711- break;
15712- case CmpInst::ICMP_ULE:
15713- case CmpInst::ICMP_SLE:
15714- RHS = GetPreviousSCEVDividesByDivisor (RHS, DividesBy);
15715- break;
15716- case CmpInst::ICMP_UGE:
15717- case CmpInst::ICMP_SGE:
15718- RHS = GetNextSCEVDividesByDivisor (RHS, DividesBy);
15719- break;
15720- default:
15721- break;
15701+ case CmpInst::ICMP_ULT:
15702+ if (RHS->getType()->isPointerTy())
15703+ return;
15704+ RHS = SE.getUMaxExpr(RHS, One);
15705+ [[fallthrough]];
15706+ case CmpInst::ICMP_SLT: {
15707+ RHS = SE.getMinusSCEV(RHS, One);
15708+ RHS = getPreviousSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15709+ break;
15710+ }
15711+ case CmpInst::ICMP_UGT:
15712+ case CmpInst::ICMP_SGT:
15713+ RHS = SE.getAddExpr(RHS, One);
15714+ RHS = getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15715+ break;
15716+ case CmpInst::ICMP_ULE:
15717+ case CmpInst::ICMP_SLE:
15718+ RHS = getPreviousSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15719+ break;
15720+ case CmpInst::ICMP_UGE:
15721+ case CmpInst::ICMP_SGE:
15722+ RHS = getNextSCEVDivisibleByDivisor (RHS, DividesBy, SE );
15723+ break;
15724+ default:
15725+ break;
1572215726 }
1572315727
1572415728 SmallVector<const SCEV *, 16> Worklist(1, LHS);
@@ -15769,7 +15773,7 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1576915773 case CmpInst::ICMP_NE:
1577015774 if (match(RHS, m_scev_Zero())) {
1577115775 const SCEV *OneAlignedUp =
15772- GetNextSCEVDividesByDivisor (One, DividesBy);
15776+ getNextSCEVDivisibleByDivisor (One, DividesBy, SE );
1577315777 To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
1577415778 } else {
1577515779 if (LHS->getType()->isPointerTy()) {
@@ -15857,8 +15861,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585715861
1585815862 // Now apply the information from the collected conditions to
1585915863 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15860- // earliest conditions is processed first. This ensures the SCEVs with the
15864+ // earliest condition is processed first, except guards with divisibility
15865+ // information, which are moved to the back. This ensures the SCEVs with the
1586115866 // shortest dependency chains are constructed first.
15867+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15868+ GuardsToProcess;
1586215869 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1586315870 SmallVector<Value *, 8> Worklist;
1586415871 SmallPtrSet<Value *, 8> Visited;
@@ -15873,7 +15880,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1587315880 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1587415881 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1587515882 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15876- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15883+ // If LHS is a constant, apply information to the other expression.
15884+ if (isa<SCEVConstant>(LHS)) {
15885+ std::swap(LHS, RHS);
15886+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15887+ }
15888+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1587715889 continue;
1587815890 }
1587915891
@@ -15886,6 +15898,30 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1588615898 }
1588715899 }
1588815900
15901+ // Process divisibility guards in reverse order to populate DivInfo early.
15902+ DenseMap<const SCEV *, APInt> Multiples;
15903+ DenseMap<const SCEV *, const SCEV *> DivInfo;
15904+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15905+ if (!isDivisibilityGuard(LHS, RHS, SE))
15906+ continue;
15907+ collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15908+ }
15909+
15910+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15911+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15912+
15913+ // Apply divisibility information last. This ensures it is applied to the
15914+ // outermost expression after other rewrites for the given value.
15915+ for (const auto &[K, V] : Multiples) {
15916+ const SCEV *DivisorSCEV = SE.getConstant(V);
15917+ Guards.RewriteMap[K] =
15918+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15919+ Guards.rewrite(K), DivisorSCEV, SE),
15920+ DivisorSCEV),
15921+ DivisorSCEV);
15922+ ExprsToRewrite.push_back(K);
15923+ }
15924+
1588915925 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1589015926 // the replacement expressions are contained in the ranges of the replaced
1589115927 // expressions.
0 commit comments