diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp index 691aac39f7b47..a31f17b1936d6 100644 --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -15510,6 +15510,78 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr, return SE.getConstant(*ExprVal + DivisorVal - Rem); } +static bool collectDivisibilityInformation( + ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, + DenseMap &DivInfo, + DenseMap &Multiples, ScalarEvolution &SE) { + // If we have LHS == 0, check if LHS is computing a property of some unknown + // SCEV %v which we can rewrite %v to express explicitly. + if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero())) + return false; + // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to + // explicitly express that. + const SCEVUnknown *URemLHS = nullptr; + const SCEV *URemRHS = nullptr; + if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) + return false; + + const SCEV *Multiple = + SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS); + DivInfo[URemLHS] = Multiple; + if (auto *C = dyn_cast(URemRHS)) + Multiples[URemLHS] = C->getAPInt(); + return true; +} + +// Check if the condition is a divisibility guard (A % B == 0). +static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS, + ScalarEvolution &SE) { + const SCEV *X, *Y; + return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero(); +} + +// Apply divisibility by \p Divisor on MinMaxExpr with constant values, +// recursively. This is done by aligning up/down the constant value to the +// Divisor. +static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr, + APInt Divisor, + ScalarEvolution &SE) { + // Return true if \p Expr is a MinMax SCEV expression with a non-negative + // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS + // the non-constant operand and in \p LHS the constant operand. + auto IsMinMaxSCEVWithNonNegativeConstant = + [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, + const SCEV *&RHS) { + if (auto *MinMax = dyn_cast(Expr)) { + if (MinMax->getNumOperands() != 2) + return false; + if (auto *C = dyn_cast(MinMax->getOperand(0))) { + if (C->getAPInt().isNegative()) + return false; + SCTy = MinMax->getSCEVType(); + LHS = MinMax->getOperand(0); + RHS = MinMax->getOperand(1); + return true; + } + } + return false; + }; + + const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; + SCEVTypes SCTy; + if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, + MinMaxRHS)) + return MinMaxExpr; + auto IsMin = isa(MinMaxExpr) || isa(MinMaxExpr); + assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!"); + auto *DivisibleExpr = + IsMin ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE) + : getNextSCEVDivisibleByDivisor(MinMaxLHS, Divisor, SE); + SmallVector Ops = { + applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr}; + return SE.getMinMaxExpr(SCTy, Ops); +} + void ScalarEvolution::LoopGuards::collectFromBlock( ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards, const BasicBlock *Block, const BasicBlock *Pred, @@ -15520,19 +15592,13 @@ void ScalarEvolution::LoopGuards::collectFromBlock( SmallVector ExprsToRewrite; auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS, - DenseMap - &RewriteMap) { + DenseMap &RewriteMap, + const LoopGuards &DivGuards) { // WARNING: It is generally unsound to apply any wrap flags to the proposed // replacement SCEV which isn't directly implied by the structure of that // SCEV. In particular, using contextual facts to imply flags is *NOT* // legal. See the scoping rules for flags in the header to understand why. - // If LHS is a constant, apply information to the other expression. - if (isa(LHS)) { - std::swap(LHS, RHS); - Predicate = CmpInst::getSwappedPredicate(Predicate); - } - // Check for a condition of the form (-C1 + X < C2). InstCombine will // create this form when combining two checks of the form (X u< C2 + C1) and // (X >=u C1). @@ -15565,67 +15631,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock( if (MatchRangeCheckIdiom()) return; - // Return true if \p Expr is a MinMax SCEV expression with a non-negative - // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS - // the non-constant operand and in \p LHS the constant operand. - auto IsMinMaxSCEVWithNonNegativeConstant = - [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS, - const SCEV *&RHS) { - const APInt *C; - SCTy = Expr->getSCEVType(); - return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) && - match(LHS, m_scev_APInt(C)) && C->isNonNegative(); - }; - - // Apply divisibilty by \p Divisor on MinMaxExpr with constant values, - // recursively. This is done by aligning up/down the constant value to the - // Divisor. - std::function - ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr, - const SCEV *Divisor) { - auto *ConstDivisor = dyn_cast(Divisor); - if (!ConstDivisor) - return MinMaxExpr; - const APInt &DivisorVal = ConstDivisor->getAPInt(); - - const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr; - SCEVTypes SCTy; - if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS, - MinMaxRHS)) - return MinMaxExpr; - auto IsMin = - isa(MinMaxExpr) || isa(MinMaxExpr); - assert(SE.isKnownNonNegative(MinMaxLHS) && - "Expected non-negative operand!"); - auto *DivisibleExpr = - IsMin - ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE) - : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE); - SmallVector Ops = { - ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr}; - return SE.getMinMaxExpr(SCTy, Ops); - }; - - // If we have LHS == 0, check if LHS is computing a property of some unknown - // SCEV %v which we can rewrite %v to express explicitly. - if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) { - // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to - // explicitly express that. - const SCEVUnknown *URemLHS = nullptr; - const SCEV *URemRHS = nullptr; - if (match(LHS, - m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) { - auto I = RewriteMap.find(URemLHS); - const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS; - RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS); - const auto *Multiple = - SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS); - RewriteMap[URemLHS] = Multiple; - ExprsToRewrite.push_back(URemLHS); - return; - } - } - // Do not apply information for constants or if RHS contains an AddRec. if (isa(LHS) || SE.containsAddRecurrence(RHS)) return; @@ -15655,7 +15660,9 @@ void ScalarEvolution::LoopGuards::collectFromBlock( }; const SCEV *RewrittenLHS = GetMaybeRewritten(LHS); - const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS); + // Apply divisibility information when computing the constant multiple. + const APInt &DividesBy = + SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS)); // Collect rewrites for LHS and its transitive operands based on the // condition. @@ -15840,8 +15847,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock( // Now apply the information from the collected conditions to // Guards.RewriteMap. Conditions are processed in reverse order, so the - // earliest conditions is processed first. This ensures the SCEVs with the + // earliest conditions is processed first, except guards with divisibility + // information, which are moved to the back. This ensures the SCEVs with the // shortest dependency chains are constructed first. + SmallVector> + GuardsToProcess; for (auto [Term, EnterIfTrue] : reverse(Terms)) { SmallVector Worklist; SmallPtrSet Visited; @@ -15856,7 +15866,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock( EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate(); const auto *LHS = SE.getSCEV(Cmp->getOperand(0)); const auto *RHS = SE.getSCEV(Cmp->getOperand(1)); - CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap); + // If LHS is a constant, apply information to the other expression. + // TODO: If LHS is not a constant, check if using CompareSCEVComplexity + // can improve results. + if (isa(LHS)) { + std::swap(LHS, RHS); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + GuardsToProcess.emplace_back(Predicate, LHS, RHS); continue; } @@ -15869,6 +15886,31 @@ void ScalarEvolution::LoopGuards::collectFromBlock( } } + // Process divisibility guards in reverse order to populate DivGuards early. + DenseMap Multiples; + LoopGuards DivGuards(SE); + for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) { + if (!isDivisibilityGuard(LHS, RHS, SE)) + continue; + collectDivisibilityInformation(Predicate, LHS, RHS, DivGuards.RewriteMap, + Multiples, SE); + } + + for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) + CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivGuards); + + // Apply divisibility information last. This ensures it is applied to the + // outermost expression after other rewrites for the given value. + for (const auto &[K, Divisor] : Multiples) { + const SCEV *DivisorSCEV = SE.getConstant(Divisor); + Guards.RewriteMap[K] = + SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr( + Guards.rewrite(K), Divisor, SE), + DivisorSCEV), + DivisorSCEV); + ExprsToRewrite.push_back(K); + } + // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of // the replacement expressions are contained in the ranges of the replaced // expressions. diff --git a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll index 14ee00d77197c..2763860e79875 100644 --- a/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll +++ b/llvm/test/Transforms/IndVarSimplify/loop-guard-order.ll @@ -114,7 +114,7 @@ define i32 @urem_order1(i32 %n) { ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[IV:%.*]] = phi i32 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] ; CHECK-NEXT: call void @foo() -; CHECK-NEXT: [[IV_NEXT]] = add i32 [[IV]], 3 +; CHECK-NEXT: [[IV_NEXT]] = add nuw i32 [[IV]], 3 ; CHECK-NEXT: [[EC:%.*]] = icmp eq i32 [[IV_NEXT]], [[N]] ; CHECK-NEXT: br i1 [[EC]], label %[[EXIT_LOOPEXIT:.*]], label %[[LOOP]] ; CHECK: [[EXIT_LOOPEXIT]]: @@ -205,13 +205,12 @@ define i64 @test_loop_with_div_order_1(i64 %n) { ; CHECK-NEXT: [[PARITY_CHECK:%.*]] = icmp eq i64 [[IS_ODD]], 0 ; CHECK-NEXT: br i1 [[PARITY_CHECK]], label %[[LOOP_PREHEADER:.*]], label %[[EXIT]] ; CHECK: [[LOOP_PREHEADER]]: -; CHECK-NEXT: [[UMAX:%.*]] = call i64 @llvm.umax.i64(i64 [[UPPER_BOUND]], i64 1) ; CHECK-NEXT: br label %[[LOOP:.*]] ; CHECK: [[LOOP]]: ; CHECK-NEXT: [[IV:%.*]] = phi i64 [ [[IV_NEXT:%.*]], %[[LOOP]] ], [ 0, %[[LOOP_PREHEADER]] ] ; CHECK-NEXT: [[DUMMY:%.*]] = load volatile i64, ptr null, align 8 ; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1 -; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UMAX]] +; CHECK-NEXT: [[EXITCOND:%.*]] = icmp ne i64 [[IV_NEXT]], [[UPPER_BOUND]] ; CHECK-NEXT: br i1 [[EXITCOND]], label %[[LOOP]], label %[[EXIT_LOOPEXIT:.*]] ; CHECK: [[EXIT_LOOPEXIT]]: ; CHECK-NEXT: br label %[[EXIT]]