@@ -1685,13 +1685,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
16851685 BB->eraseFromParent ();
16861686}
16871687
1688- static void
1689- deleteDeadBlocksFromLoop (Loop &L,
1690- SmallVectorImpl<BasicBlock *> &ExitBlocks,
1691- DominatorTree &DT, LoopInfo &LI,
1692- MemorySSAUpdater *MSSAU,
1693- ScalarEvolution *SE,
1694- function_ref<void (Loop &, StringRef)> DestroyLoopCB) {
1688+ static void deleteDeadBlocksFromLoop (Loop &L,
1689+ SmallVectorImpl<BasicBlock *> &ExitBlocks,
1690+ DominatorTree &DT, LoopInfo &LI,
1691+ MemorySSAUpdater *MSSAU,
1692+ ScalarEvolution *SE,
1693+ LPMUpdater &LoopUpdater) {
16951694 // Find all the dead blocks tied to this loop, and remove them from their
16961695 // successors.
16971696 SmallSetVector<BasicBlock *, 8 > DeadBlockSet;
@@ -1741,7 +1740,7 @@ deleteDeadBlocksFromLoop(Loop &L,
17411740 }) &&
17421741 " If the child loop header is dead all blocks in the child loop must "
17431742 " be dead as well!" );
1744- DestroyLoopCB (*ChildL, ChildL->getName ());
1743+ LoopUpdater. markLoopAsDeleted (*ChildL, ChildL->getName ());
17451744 if (SE)
17461745 SE->forgetBlockAndLoopDispositions ();
17471746 LI.destroy (ChildL);
@@ -2085,8 +2084,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks,
20852084 ParentL->removeChildLoop (llvm::find (*ParentL, &L));
20862085 else
20872086 LI.removeLoop (llvm::find (LI, &L));
2088- // markLoopAsDeleted for L should be triggered by the caller (it is typically
2089- // done by using the UnswitchCB callback ).
2087+ // markLoopAsDeleted for L should be triggered by the caller (it is
2088+ // typically done within postUnswitch ).
20902089 if (SE)
20912090 SE->forgetBlockAndLoopDispositions ();
20922091 LI.destroy (&L);
@@ -2123,18 +2122,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
21232122 } while (!DomWorklist.empty ());
21242123}
21252124
2125+ void postUnswitch (Loop &L, LPMUpdater &U, StringRef LoopName,
2126+ bool CurrentLoopValid, bool PartiallyInvariant,
2127+ bool InjectedCondition, ArrayRef<Loop *> NewLoops) {
2128+ // If we did a non-trivial unswitch, we have added new (cloned) loops.
2129+ if (!NewLoops.empty ())
2130+ U.addSiblingLoops (NewLoops);
2131+
2132+ // If the current loop remains valid, we should revisit it to catch any
2133+ // other unswitch opportunities. Otherwise, we need to mark it as deleted.
2134+ if (CurrentLoopValid) {
2135+ if (PartiallyInvariant) {
2136+ // Mark the new loop as partially unswitched, to avoid unswitching on
2137+ // the same condition again.
2138+ auto &Context = L.getHeader ()->getContext ();
2139+ MDNode *DisableUnswitchMD = MDNode::get (
2140+ Context,
2141+ MDString::get (Context, " llvm.loop.unswitch.partial.disable" ));
2142+ MDNode *NewLoopID = makePostTransformationMetadata (
2143+ Context, L.getLoopID (), {" llvm.loop.unswitch.partial" },
2144+ {DisableUnswitchMD});
2145+ L.setLoopID (NewLoopID);
2146+ } else if (InjectedCondition) {
2147+ // Do the same for injection of invariant conditions.
2148+ auto &Context = L.getHeader ()->getContext ();
2149+ MDNode *DisableUnswitchMD = MDNode::get (
2150+ Context,
2151+ MDString::get (Context, " llvm.loop.unswitch.injection.disable" ));
2152+ MDNode *NewLoopID = makePostTransformationMetadata (
2153+ Context, L.getLoopID (), {" llvm.loop.unswitch.injection" },
2154+ {DisableUnswitchMD});
2155+ L.setLoopID (NewLoopID);
2156+ } else
2157+ U.revisitCurrentLoop ();
2158+ } else
2159+ U.markLoopAsDeleted (L, LoopName);
2160+ }
2161+
21262162static void unswitchNontrivialInvariants (
21272163 Loop &L, Instruction &TI, ArrayRef<Value *> Invariants,
21282164 IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI,
2129- AssumptionCache &AC,
2130- function_ref<void (bool , bool , bool , ArrayRef<Loop *>)> UnswitchCB,
2131- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
2132- function_ref<void(Loop &, StringRef)> DestroyLoopCB, bool InsertFreeze,
2133- bool InjectedCondition) {
2165+ AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
2166+ LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) {
21342167 auto *ParentBB = TI.getParent ();
21352168 BranchInst *BI = dyn_cast<BranchInst>(&TI);
21362169 SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
21372170
2171+ // Save the current loop name in a variable so that we can report it even
2172+ // after it has been deleted.
2173+ std::string LoopName (L.getName ());
2174+
21382175 // We can only unswitch switches, conditional branches with an invariant
21392176 // condition, or combining invariant conditions with an instruction or
21402177 // partially invariant instructions.
@@ -2447,7 +2484,7 @@ static void unswitchNontrivialInvariants(
24472484 // Now that our cloned loops have been built, we can update the original loop.
24482485 // First we delete the dead blocks from it and then we rebuild the loop
24492486 // structure taking these deletions into account.
2450- deleteDeadBlocksFromLoop (L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB );
2487+ deleteDeadBlocksFromLoop (L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater );
24512488
24522489 if (MSSAU && VerifyMemorySSA)
24532490 MSSAU->getMemorySSA ()->verifyMemorySSA ();
@@ -2583,7 +2620,8 @@ static void unswitchNontrivialInvariants(
25832620 for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
25842621 if (UpdatedL->getParentLoop () == ParentL)
25852622 SibLoops.push_back (UpdatedL);
2586- UnswitchCB (IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops);
2623+ postUnswitch (L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant,
2624+ InjectedCondition, SibLoops);
25872625
25882626 if (MSSAU && VerifyMemorySSA)
25892627 MSSAU->getMemorySSA ()->verifyMemorySSA ();
@@ -3430,12 +3468,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT,
34303468 Cond, &AC, L.getLoopPreheader ()->getTerminator (), &DT);
34313469}
34323470
3433- static bool unswitchBestCondition (
3434- Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
3435- AAResults &AA, TargetTransformInfo &TTI,
3436- function_ref<void (bool , bool , bool , ArrayRef<Loop *>)> UnswitchCB,
3437- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
3438- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
3471+ static bool unswitchBestCondition (Loop &L, DominatorTree &DT, LoopInfo &LI,
3472+ AssumptionCache &AC, AAResults &AA,
3473+ TargetTransformInfo &TTI, ScalarEvolution *SE,
3474+ MemorySSAUpdater *MSSAU,
3475+ LPMUpdater &LoopUpdater) {
34393476 // Collect all invariant conditions within this loop (as opposed to an inner
34403477 // loop which would be handled when visiting that inner loop).
34413478 SmallVector<NonTrivialUnswitchCandidate, 4 > UnswitchCandidates;
@@ -3498,8 +3535,8 @@ static bool unswitchBestCondition(
34983535 LLVM_DEBUG (dbgs () << " Unswitching non-trivial (cost = " << Best.Cost
34993536 << " ) terminator: " << *Best.TI << " \n " );
35003537 unswitchNontrivialInvariants (L, *Best.TI , Best.Invariants , PartialIVInfo, DT,
3501- LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB ,
3502- InsertFreeze, InjectedCondition);
3538+ LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze ,
3539+ InjectedCondition);
35033540 return true ;
35043541}
35053542
@@ -3518,20 +3555,18 @@ static bool unswitchBestCondition(
35183555// / true, we will attempt to do non-trivial unswitching as well as trivial
35193556// / unswitching.
35203557// /
3521- // / The `UnswitchCB` callback provided will be run after unswitching is
3522- // / complete, with the first parameter set to `true` if the provided loop
3523- // / remains a loop, and a list of new sibling loops created.
3558+ // / The `postUnswitch` function will be run after unswitching is complete
3559+ // / with information on whether or not the provided loop remains a loop and
3560+ // / a list of new sibling loops created.
35243561// /
35253562// / If `SE` is non-null, we will update that analysis based on the unswitching
35263563// / done.
3527- static bool
3528- unswitchLoop (Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
3529- AAResults &AA, TargetTransformInfo &TTI, bool Trivial,
3530- bool NonTrivial,
3531- function_ref<void (bool , bool , bool , ArrayRef<Loop *>)> UnswitchCB,
3532- ScalarEvolution *SE, MemorySSAUpdater *MSSAU,
3533- ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI,
3534- function_ref<void(Loop &, StringRef)> DestroyLoopCB) {
3564+ static bool unswitchLoop (Loop &L, DominatorTree &DT, LoopInfo &LI,
3565+ AssumptionCache &AC, AAResults &AA,
3566+ TargetTransformInfo &TTI, bool Trivial,
3567+ bool NonTrivial, ScalarEvolution *SE,
3568+ MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI,
3569+ BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) {
35353570 assert (L.isRecursivelyLCSSAForm (DT, LI) &&
35363571 " Loops must be in LCSSA form before unswitching." );
35373572
@@ -3543,8 +3578,9 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
35433578 if (Trivial && unswitchAllTrivialConditions (L, DT, LI, SE, MSSAU)) {
35443579 // If we unswitched successfully we will want to clean up the loop before
35453580 // processing it further so just mark it as unswitched and return.
3546- UnswitchCB (/* CurrentLoopValid*/ true , /* PartiallyInvariant*/ false ,
3547- /* InjectedCondition*/ false , {});
3581+ postUnswitch (L, LoopUpdater, L.getName (),
3582+ /* CurrentLoopValid*/ true , /* PartiallyInvariant*/ false ,
3583+ /* InjectedCondition*/ false , {});
35483584 return true ;
35493585 }
35503586
@@ -3613,8 +3649,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
36133649
36143650 // Try to unswitch the best invariant condition. We prefer this full unswitch to
36153651 // a partial unswitch when possible below the threshold.
3616- if (unswitchBestCondition (L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU,
3617- DestroyLoopCB))
3652+ if (unswitchBestCondition (L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater))
36183653 return true ;
36193654
36203655 // No other opportunities to unswitch.
@@ -3634,61 +3669,14 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM,
36343669 LLVM_DEBUG (dbgs () << " Unswitching loop in " << F.getName () << " : " << L
36353670 << " \n " );
36363671
3637- // Save the current loop name in a variable so that we can report it even
3638- // after it has been deleted.
3639- std::string LoopName = std::string (L.getName ());
3640-
3641- auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid,
3642- bool PartiallyInvariant,
3643- bool InjectedCondition,
3644- ArrayRef<Loop *> NewLoops) {
3645- // If we did a non-trivial unswitch, we have added new (cloned) loops.
3646- if (!NewLoops.empty ())
3647- U.addSiblingLoops (NewLoops);
3648-
3649- // If the current loop remains valid, we should revisit it to catch any
3650- // other unswitch opportunities. Otherwise, we need to mark it as deleted.
3651- if (CurrentLoopValid) {
3652- if (PartiallyInvariant) {
3653- // Mark the new loop as partially unswitched, to avoid unswitching on
3654- // the same condition again.
3655- auto &Context = L.getHeader ()->getContext ();
3656- MDNode *DisableUnswitchMD = MDNode::get (
3657- Context,
3658- MDString::get (Context, " llvm.loop.unswitch.partial.disable" ));
3659- MDNode *NewLoopID = makePostTransformationMetadata (
3660- Context, L.getLoopID (), {" llvm.loop.unswitch.partial" },
3661- {DisableUnswitchMD});
3662- L.setLoopID (NewLoopID);
3663- } else if (InjectedCondition) {
3664- // Do the same for injection of invariant conditions.
3665- auto &Context = L.getHeader ()->getContext ();
3666- MDNode *DisableUnswitchMD = MDNode::get (
3667- Context,
3668- MDString::get (Context, " llvm.loop.unswitch.injection.disable" ));
3669- MDNode *NewLoopID = makePostTransformationMetadata (
3670- Context, L.getLoopID (), {" llvm.loop.unswitch.injection" },
3671- {DisableUnswitchMD});
3672- L.setLoopID (NewLoopID);
3673- } else
3674- U.revisitCurrentLoop ();
3675- } else
3676- U.markLoopAsDeleted (L, LoopName);
3677- };
3678-
3679- auto DestroyLoopCB = [&U](Loop &L, StringRef Name) {
3680- U.markLoopAsDeleted (L, Name);
3681- };
3682-
36833672 std::optional<MemorySSAUpdater> MSSAU;
36843673 if (AR.MSSA ) {
36853674 MSSAU = MemorySSAUpdater (AR.MSSA );
36863675 if (VerifyMemorySSA)
36873676 AR.MSSA ->verifyMemorySSA ();
36883677 }
36893678 if (!unswitchLoop (L, AR.DT , AR.LI , AR.AC , AR.AA , AR.TTI , Trivial, NonTrivial,
3690- UnswitchCB, &AR.SE , MSSAU ? &*MSSAU : nullptr , PSI, AR.BFI ,
3691- DestroyLoopCB))
3679+ &AR.SE , MSSAU ? &*MSSAU : nullptr , PSI, AR.BFI , U))
36923680 return PreservedAnalyses::all ();
36933681
36943682 if (AR.MSSA && VerifyMemorySSA)
0 commit comments