diff --git a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 5560647376523..53dc12d280e63 100644 --- a/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -1682,13 +1682,12 @@ deleteDeadClonedBlocks(Loop &L, ArrayRef ExitBlocks, BB->eraseFromParent(); } -static void -deleteDeadBlocksFromLoop(Loop &L, - SmallVectorImpl &ExitBlocks, - DominatorTree &DT, LoopInfo &LI, - MemorySSAUpdater *MSSAU, - ScalarEvolution *SE, - function_ref DestroyLoopCB) { +static void deleteDeadBlocksFromLoop(Loop &L, + SmallVectorImpl &ExitBlocks, + DominatorTree &DT, LoopInfo &LI, + MemorySSAUpdater *MSSAU, + ScalarEvolution *SE, + LPMUpdater &LoopUpdater) { // Find all the dead blocks tied to this loop, and remove them from their // successors. SmallSetVector DeadBlockSet; @@ -1738,7 +1737,7 @@ deleteDeadBlocksFromLoop(Loop &L, }) && "If the child loop header is dead all blocks in the child loop must " "be dead as well!"); - DestroyLoopCB(*ChildL, ChildL->getName()); + LoopUpdater.markLoopAsDeleted(*ChildL, ChildL->getName()); if (SE) SE->forgetBlockAndLoopDispositions(); LI.destroy(ChildL); @@ -2082,8 +2081,8 @@ static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef ExitBlocks, ParentL->removeChildLoop(llvm::find(*ParentL, &L)); else LI.removeLoop(llvm::find(LI, &L)); - // markLoopAsDeleted for L should be triggered by the caller (it is typically - // done by using the UnswitchCB callback). + // markLoopAsDeleted for L should be triggered by the caller (it is + // typically done within postUnswitch). if (SE) SE->forgetBlockAndLoopDispositions(); LI.destroy(&L); @@ -2120,18 +2119,56 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { } while (!DomWorklist.empty()); } +void postUnswitch(Loop &L, LPMUpdater &U, StringRef LoopName, + bool CurrentLoopValid, bool PartiallyInvariant, + bool InjectedCondition, ArrayRef NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + if (!NewLoops.empty()) + U.addSiblingLoops(NewLoops); + + // If the current loop remains valid, we should revisit it to catch any + // other unswitch opportunities. Otherwise, we need to mark it as deleted. + if (CurrentLoopValid) { + if (PartiallyInvariant) { + // Mark the new loop as partially unswitched, to avoid unswitching on + // the same condition again. + auto &Context = L.getHeader()->getContext(); + MDNode *DisableUnswitchMD = MDNode::get( + Context, + MDString::get(Context, "llvm.loop.unswitch.partial.disable")); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, L.getLoopID(), {"llvm.loop.unswitch.partial"}, + {DisableUnswitchMD}); + L.setLoopID(NewLoopID); + } else if (InjectedCondition) { + // Do the same for injection of invariant conditions. + auto &Context = L.getHeader()->getContext(); + MDNode *DisableUnswitchMD = MDNode::get( + Context, + MDString::get(Context, "llvm.loop.unswitch.injection.disable")); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, L.getLoopID(), {"llvm.loop.unswitch.injection"}, + {DisableUnswitchMD}); + L.setLoopID(NewLoopID); + } else + U.revisitCurrentLoop(); + } else + U.markLoopAsDeleted(L, LoopName); +} + static void unswitchNontrivialInvariants( Loop &L, Instruction &TI, ArrayRef Invariants, IVConditionInfo &PartialIVInfo, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, - function_ref)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref DestroyLoopCB, bool InsertFreeze, - bool InjectedCondition) { + AssumptionCache &AC, ScalarEvolution *SE, MemorySSAUpdater *MSSAU, + LPMUpdater &LoopUpdater, bool InsertFreeze, bool InjectedCondition) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast(&TI); SwitchInst *SI = BI ? nullptr : cast(&TI); + // Save the current loop name in a variable so that we can report it even + // after it has been deleted. + std::string LoopName(L.getName()); + // We can only unswitch switches, conditional branches with an invariant // condition, or combining invariant conditions with an instruction or // partially invariant instructions. @@ -2444,7 +2481,7 @@ static void unswitchNontrivialInvariants( // Now that our cloned loops have been built, we can update the original loop. // First we delete the dead blocks from it and then we rebuild the loop // structure taking these deletions into account. - deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE,DestroyLoopCB); + deleteDeadBlocksFromLoop(L, ExitBlocks, DT, LI, MSSAU, SE, LoopUpdater); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -2580,7 +2617,8 @@ static void unswitchNontrivialInvariants( for (Loop *UpdatedL : llvm::concat(NonChildClonedLoops, HoistedLoops)) if (UpdatedL->getParentLoop() == ParentL) SibLoops.push_back(UpdatedL); - UnswitchCB(IsStillLoop, PartiallyInvariant, InjectedCondition, SibLoops); + postUnswitch(L, LoopUpdater, LoopName, IsStillLoop, PartiallyInvariant, + InjectedCondition, SibLoops); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -3427,12 +3465,11 @@ static bool shouldInsertFreeze(Loop &L, Instruction &TI, DominatorTree &DT, Cond, &AC, L.getLoopPreheader()->getTerminator(), &DT); } -static bool unswitchBestCondition( - Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - AAResults &AA, TargetTransformInfo &TTI, - function_ref)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - function_ref DestroyLoopCB) { +static bool unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, AAResults &AA, + TargetTransformInfo &TTI, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU, + LPMUpdater &LoopUpdater) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). SmallVector UnswitchCandidates; @@ -3495,8 +3532,8 @@ static bool unswitchBestCondition( LLVM_DEBUG(dbgs() << " Unswitching non-trivial (cost = " << Best.Cost << ") terminator: " << *Best.TI << "\n"); unswitchNontrivialInvariants(L, *Best.TI, Best.Invariants, PartialIVInfo, DT, - LI, AC, UnswitchCB, SE, MSSAU, DestroyLoopCB, - InsertFreeze, InjectedCondition); + LI, AC, SE, MSSAU, LoopUpdater, InsertFreeze, + InjectedCondition); return true; } @@ -3515,20 +3552,18 @@ static bool unswitchBestCondition( /// true, we will attempt to do non-trivial unswitching as well as trivial /// unswitching. /// -/// The `UnswitchCB` callback provided will be run after unswitching is -/// complete, with the first parameter set to `true` if the provided loop -/// remains a loop, and a list of new sibling loops created. +/// The `postUnswitch` function will be run after unswitching is complete +/// with information on whether or not the provided loop remains a loop and +/// a list of new sibling loops created. /// /// If `SE` is non-null, we will update that analysis based on the unswitching /// done. -static bool -unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, - AAResults &AA, TargetTransformInfo &TTI, bool Trivial, - bool NonTrivial, - function_ref)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU, - ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI, - function_ref DestroyLoopCB) { +static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, AAResults &AA, + TargetTransformInfo &TTI, bool Trivial, + bool NonTrivial, ScalarEvolution *SE, + MemorySSAUpdater *MSSAU, ProfileSummaryInfo *PSI, + BlockFrequencyInfo *BFI, LPMUpdater &LoopUpdater) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -3540,8 +3575,9 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. - UnswitchCB(/*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false, - /*InjectedCondition*/ false, {}); + postUnswitch(L, LoopUpdater, L.getName(), + /*CurrentLoopValid*/ true, /*PartiallyInvariant*/ false, + /*InjectedCondition*/ false, {}); return true; } @@ -3610,8 +3646,7 @@ unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU, - DestroyLoopCB)) + if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, SE, MSSAU, LoopUpdater)) return true; // No other opportunities to unswitch. @@ -3631,52 +3666,6 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, LLVM_DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); - // Save the current loop name in a variable so that we can report it even - // after it has been deleted. - std::string LoopName = std::string(L.getName()); - - auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, - bool PartiallyInvariant, - bool InjectedCondition, - ArrayRef NewLoops) { - // If we did a non-trivial unswitch, we have added new (cloned) loops. - if (!NewLoops.empty()) - U.addSiblingLoops(NewLoops); - - // If the current loop remains valid, we should revisit it to catch any - // other unswitch opportunities. Otherwise, we need to mark it as deleted. - if (CurrentLoopValid) { - if (PartiallyInvariant) { - // Mark the new loop as partially unswitched, to avoid unswitching on - // the same condition again. - auto &Context = L.getHeader()->getContext(); - MDNode *DisableUnswitchMD = MDNode::get( - Context, - MDString::get(Context, "llvm.loop.unswitch.partial.disable")); - MDNode *NewLoopID = makePostTransformationMetadata( - Context, L.getLoopID(), {"llvm.loop.unswitch.partial"}, - {DisableUnswitchMD}); - L.setLoopID(NewLoopID); - } else if (InjectedCondition) { - // Do the same for injection of invariant conditions. - auto &Context = L.getHeader()->getContext(); - MDNode *DisableUnswitchMD = MDNode::get( - Context, - MDString::get(Context, "llvm.loop.unswitch.injection.disable")); - MDNode *NewLoopID = makePostTransformationMetadata( - Context, L.getLoopID(), {"llvm.loop.unswitch.injection"}, - {DisableUnswitchMD}); - L.setLoopID(NewLoopID); - } else - U.revisitCurrentLoop(); - } else - U.markLoopAsDeleted(L, LoopName); - }; - - auto DestroyLoopCB = [&U](Loop &L, StringRef Name) { - U.markLoopAsDeleted(L, Name); - }; - std::optional MSSAU; if (AR.MSSA) { MSSAU = MemorySSAUpdater(AR.MSSA); @@ -3684,8 +3673,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, AR.MSSA->verifyMemorySSA(); } if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, - UnswitchCB, &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, - DestroyLoopCB)) + &AR.SE, MSSAU ? &*MSSAU : nullptr, PSI, AR.BFI, U)) return PreservedAnalyses::all(); if (AR.MSSA && VerifyMemorySSA)