Skip to content

Unroll loop into switch #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions llvm/include/llvm/Transforms/Utils/UnrollLoop.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

#include "llvm/ADT/DenseMap.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LazyValueInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/InstructionCost.h"

namespace llvm {
Expand Down Expand Up @@ -110,6 +112,16 @@ void simplifyLoopAfterUnroll(Loop *L, bool SimplifyIVs, LoopInfo *LI,
AAResults *AA = nullptr);

MDNode *GetUnrollMetadata(MDNode *LoopID, StringRef Name);
/// Analyses the loop \p L and, when it qualifies, replaces it with a switch
/// in the preheader that dispatches into a straight-line chain of copies of
/// the loop body (the rewrite itself is done by UnrollLoopIntoSwitch).
///
/// The loop is left untouched and LoopUnrollResult::Unmodified is returned
/// unless all of the following hold: the loop bounds are known, the header is
/// also the latch, the final induction-variable value and the step are
/// constants, the direction is known, the step magnitude is not greater than
/// one, the induction variable is an integer of at most 64 bits, the loop has
/// a preheader, \p LVI gives a non-full range for the initial value on the
/// preheader-to-header edge, and the derived number of copies does not exceed
/// a small fixed limit.
///
/// \param L             Loop to transform.
/// \param LVI           Used to bound the initial induction-variable value.
/// \param SE            Used to compute the loop bounds; invalidated for \p L
///                      when the transformation is applied.
/// \param LI            Loop info, updated when the loop is removed.
/// \param DT            Dominator tree, updated for the new blocks.
/// \param ForgetAllSCEV If true, all loops are forgotten in \p SE instead of
///                      only the topmost loop containing \p L.
///
/// NOTE(review): the call visible in LoopUnrollPass::run labels the argument
/// passed in the last position as /*PreserveLCSSA*/; confirm that the caller
/// really intends to set ForgetAllSCEV.
LoopUnrollResult tryUnrollLoopIntoSwitch(Loop &L, LazyValueInfo &LVI,
ScalarEvolution &SE, LoopInfo &LI,
DominatorTree &DT,
bool ForgetAllSCEV = false);
/// Replaces the single-block loop \p L by \p UnrollCount copies of its body
/// chained one after another, and turns the preheader terminator into a
/// switch on \p SwitchValue that selects the copy at which execution starts.
///
/// \param L                Loop to transform; its header must also be its
///                         latch, and it needs a preheader and a single exit
///                         block.
/// \param UnrollCount      Total number of body copies, including the
///                         original block. Values of 0 or 1 leave the loop
///                         unmodified.
/// \param SwitchValue      Integer value the switch is performed on.
/// \param FirstSwitchValue Case value that jumps to the first copy (the
///                         original body block).
/// \param nextSwitchValue  Maps the case value of one copy to the case value
///                         of the copy that follows it.
/// \param SE               Scalar evolution; cached information about the
///                         loop is dropped.
/// \param LI               Loop info; \p L is erased from it on success.
/// \param DT               Dominator tree; updated for the new blocks and the
///                         exit block.
/// \param ForgetAllSCEV    If true, all loops are forgotten in \p SE instead
///                         of only the topmost loop containing \p L.
/// \returns LoopUnrollResult::FullyUnrolled when the loop was rewritten,
///          LoopUnrollResult::Unmodified otherwise.
LoopUnrollResult UnrollLoopIntoSwitch(
Loop &L, unsigned UnrollCount, Value *SwitchValue,
ConstantInt *FirstSwitchValue,
std::function<ConstantInt *(ConstantInt *)> nextSwitchValue,
ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT,
bool ForgetAllSCEV = false);

TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences(
Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI,
Expand Down
24 changes: 15 additions & 9 deletions llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LazyValueInfo.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
Expand Down Expand Up @@ -1598,6 +1599,7 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
auto &AC = AM.getResult<AssumptionAnalysis>(F);
auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
LazyValueInfo &LVI = AM.getResult<LazyValueAnalysis>(F);
AAResults &AA = AM.getResult<AAManager>(F);

LoopAnalysisManager *LAM = nullptr;
Expand Down Expand Up @@ -1645,17 +1647,21 @@ PreservedAnalyses LoopUnrollPass::run(Function &F,
if (PSI && PSI->hasHugeWorkingSetSize())
LocalAllowPeeling = false;
std::string LoopName = std::string(L.getName());
LoopUnrollResult Result =
tryUnrollLoopIntoSwitch(L, LVI, SE, LI, DT, /*PreserveLCSSA*/ true);

// The API here is quite complex to call and we allow to select some
// flavors of unrolling during construction time (by setting UnrollOpts).
LoopUnrollResult Result = tryToUnrollLoop(
&L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI,
/*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false,
UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV,
/*Count*/ std::nullopt,
/*Threshold*/ std::nullopt, UnrollOpts.AllowPartial,
UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound, LocalAllowPeeling,
UnrollOpts.AllowProfileBasedPeeling, UnrollOpts.FullUnrollMaxCount,
&AA);
if (Result == LoopUnrollResult::Unmodified)
Result = tryToUnrollLoop(
&L, DT, &LI, SE, TTI, AC, ORE, BFI, PSI,
/*PreserveLCSSA*/ true, UnrollOpts.OptLevel, /*OnlyFullUnroll*/ false,
UnrollOpts.OnlyWhenForced, UnrollOpts.ForgetSCEV,
/*Count*/ std::nullopt,
/*Threshold*/ std::nullopt, UnrollOpts.AllowPartial,
UnrollOpts.AllowRuntime, UnrollOpts.AllowUpperBound,
LocalAllowPeeling, UnrollOpts.AllowProfileBasedPeeling,
UnrollOpts.FullUnrollMaxCount, &AA);
Changed |= Result != LoopUnrollResult::Unmodified;

// The parent must not be damaged by unrolling!
Expand Down
303 changes: 303 additions & 0 deletions llvm/lib/Transforms/Utils/LoopUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,3 +1105,306 @@ MDNode *llvm::GetUnrollMetadata(MDNode *LoopID, StringRef Name) {
}
return nullptr;
}

/// Returns the number of values in the inclusive range [Lower, Upper].
///
/// Both bounds are read as unsigned integers of the same bit width (at most
/// 64 bits). When Upper is smaller than Lower the range is taken to wrap
/// around through the largest value of that bit width, so for 8-bit values
/// the range [255, 0] contains two elements.
///
/// The count is computed modulo 2^64: a range that covers every value of a
/// 64-bit type yields 0.
static int64_t getElementCountInRange(const APInt &Lower, const APInt &Upper) {
  assert(Lower.getBitWidth() == Upper.getBitWidth() && "Bitwidths must match");
  assert(Lower.getBitWidth() <= 64 && "Bitwidth is too big");

  const unsigned BitWidth = Lower.getBitWidth();
  // Build the mask without ever shifting a 64-bit value by 64 positions,
  // which would be undefined behaviour.
  const uint64_t Mask =
      BitWidth >= 64 ? ~uint64_t(0) : ((uint64_t(1) << BitWidth) - 1);

  // Unsigned subtraction reduced modulo 2^BitWidth is the distance from Lower
  // to Upper in both the ordinary and the wrapping case; adding one makes the
  // count inclusive of both end points.
  const uint64_t Distance =
      (Upper.getZExtValue() - Lower.getZExtValue()) & Mask;
  return Distance + 1;
}

/// Checks whether the single-block loop \p L can be turned into a switch over
/// its initial induction-variable value and, if so, performs the rewrite via
/// UnrollLoopIntoSwitch.
///
/// The number of body copies is derived from the constant final IV value and
/// from the range that \p LVI reports for the initial IV value on the
/// preheader-to-header edge. Every failed precondition leaves the IR untouched
/// and returns LoopUnrollResult::Unmodified.
///
/// \param L             Loop to transform.
/// \param LVI           Supplies the range of the initial IV value.
/// \param SE            Used to compute the loop bounds.
/// \param LI            Forwarded to UnrollLoopIntoSwitch.
/// \param DT            Forwarded to UnrollLoopIntoSwitch.
/// \param ForgetAllSCEV Forwarded to UnrollLoopIntoSwitch.
LoopUnrollResult llvm::tryUnrollLoopIntoSwitch(Loop &L, LazyValueInfo &LVI,
                                               ScalarEvolution &SE,
                                               LoopInfo &LI, DominatorTree &DT,
                                               bool ForgetAllSCEV) {
  // Largest number of body copies this transformation is willing to create.
  // TODO: How should we determine the max number of iterations?
  constexpr uint64_t MaxSwitchUnrollCount = 20;

  auto Bounds = L.getBounds(SE);
  if (!Bounds) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; loop bounds are unknown.\n");
    return LoopUnrollResult::Unmodified;
  }
  if (L.getHeader() != L.getLoopLatch()) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; loop header != loop latch.\n");
    return LoopUnrollResult::Unmodified;
  }

  ConstantInt *FinalIVValue =
      dyn_cast_or_null<ConstantInt>(&Bounds->getFinalIVValue());
  // We can extend it in the future to support more complex cases
  if (!FinalIVValue) {
    LLVM_DEBUG(
        dbgs()
        << " Can't unroll into switch; final value of IV is not constant.\n");
    return LoopUnrollResult::Unmodified;
  }
  ConstantInt *StepSizeValue =
      dyn_cast_or_null<ConstantInt>(Bounds->getStepValue());
  if (!StepSizeValue) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; step size is not constant.\n");
    return LoopUnrollResult::Unmodified;
  }
  auto Direction = Bounds->getDirection();
  if (Direction == Loop::LoopBounds::Direction::Unknown) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; direction of IV is unknown.\n");
    return LoopUnrollResult::Unmodified;
  }

  // Reject wide or non-integer IVs before any APInt is converted to uint64_t
  // below; those conversions assert on values that need more than 64 bits.
  // NOTE(review): this presumes the step constant has the same type as the
  // IV - confirm against Loop::getBounds.
  Value &InitialValue = Bounds->getInitialIVValue();
  if (!InitialValue.getType()->isIntegerTy() ||
      InitialValue.getType()->getIntegerBitWidth() > 64) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; loop induction variable "
                         "is not an integer.\n");
    return LoopUnrollResult::Unmodified;
  }

  const APInt &StepAP = StepSizeValue->getValue();
  const uint64_t StepSize =
      Direction == Loop::LoopBounds::Direction::Increasing
          ? StepAP.getZExtValue()
          : StepAP.abs().getZExtValue();
  // Only unit steps are supported. Comparing against 1 (rather than "> 1")
  // also keeps a zero step away from the divisions further down.
  if (StepSize != 1) {
    // TODO: To handle steps > 1 we need:
    // 1) Make sure that FirstSwitchValue is a multiple of StepSize
    // 2) Handle the case when SwitchValue is not a multiple of StepSize
    // (potentially fallback to unrolled loop)
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; step size is too big.\n");
    return LoopUnrollResult::Unmodified;
  }

  BasicBlock *PreHeader = L.getLoopPreheader();
  if (!PreHeader) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; loop has no preheader.\n");
    return LoopUnrollResult::Unmodified;
  }
  auto InitialValueRange =
      LVI.getConstantRangeOnEdge(&InitialValue, PreHeader, L.getHeader());
  if (InitialValueRange.isFullSet()) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; no bound for number of "
                         "iterations available.\n");
    return LoopUnrollResult::Unmodified;
  }

  // Bail out instead of only asserting: with assertions disabled a null
  // pointer would otherwise be dereferenced just below.
  OverflowingBinaryOperator *StepInstOp =
      dyn_cast_or_null<OverflowingBinaryOperator>(&Bounds->getStepInst());
  if (!StepInstOp) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; step instruction is not "
                         "an overflowing binary operator.\n");
    return LoopUnrollResult::Unmodified;
  }

  // The initial value may coincide with the final value while the step
  // instruction carries neither nsw nor nuw: no iteration bound is derived.
  if (InitialValueRange.contains(FinalIVValue->getValue()) &&
      !StepInstOp->hasNoSignedWrap() && !StepInstOp->hasNoUnsignedWrap()) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; cannot establish upper "
                         "on the number of iterations.\n");
    return LoopUnrollResult::Unmodified;
  }

  uint64_t UnrollCount = 0;
  ConstantInt *FirstSwitchValue = nullptr;
  if (Direction == Loop::LoopBounds::Direction::Decreasing) {
    // The range's upper end is treated as exclusive here: the largest
    // possible initial value, and hence the first case value, is Upper - 1.
    const APInt &UpperBound = InitialValueRange.getUpper();
    UnrollCount =
        (getElementCountInRange(FinalIVValue->getValue(), UpperBound) - 1) /
        StepSize;
    FirstSwitchValue = dyn_cast<ConstantInt>(
        ConstantInt::get(FinalIVValue->getType(),
                         UpperBound - APInt(UpperBound.getBitWidth(), 1)));
  } else {
    // Counting up: the smallest possible initial value is the first case.
    const APInt &LowerBound = InitialValueRange.getLower();
    UnrollCount =
        getElementCountInRange(LowerBound, FinalIVValue->getValue()) / StepSize;
    FirstSwitchValue = dyn_cast<ConstantInt>(
        ConstantInt::get(FinalIVValue->getType(), LowerBound));
  }
  LLVM_DEBUG(dbgs() << " Unroll count: " << UnrollCount << "\n");
  if (UnrollCount > MaxSwitchUnrollCount) {
    LLVM_DEBUG(
        dbgs()
        << " Can't unroll into switch; number of iterations is too big.\n");
    return LoopUnrollResult::Unmodified;
  }

  // UnrollCount is bounded by MaxSwitchUnrollCount, so the narrowing is safe.
  // The callback is invoked synchronously by UnrollLoopIntoSwitch, so the
  // by-reference capture does not outlive this frame.
  return UnrollLoopIntoSwitch(
      L, static_cast<unsigned>(UnrollCount), &InitialValue, FirstSwitchValue,
      [&](ConstantInt *CurrentSwitchValue) {
        return dyn_cast<ConstantInt>(ConstantFoldBinaryInstruction(
            StepInstOp->getOpcode(), CurrentSwitchValue, StepSizeValue));
      },
      SE, LI, DT, ForgetAllSCEV);
}

// Unrolls a loop with a known upper bound N on the number of iterations into
// N copies of the body block chained together, plus a switch that jumps to
// the (N - n)-th block, where n is the number of iterations computed at
// runtime.
// For example, the following C++ code:
//   __builtin_assume(n <= 10);
//   do {
//     // loop body
//   } while (--n >= 0);
// will be transformed into the following:
//   switch i32 %n, label %sw.0 [
//     i32 10, label %sw.10
//     i32 9, label %sw.9
//     ...
//     i32 1, label %sw.1
//   ]
// sw.10:
//   %n.10 = phi i32 [ %n, %entry ]
//   ...
//   br label %sw.9
// sw.9:
//   %n.9 = phi i32 [ %n, %entry ], [ %n.10, %sw.10 ]
//   ...
//   br label %sw.8
// ...
// sw.0:
//   %n.0 = phi i32 [ %n, %entry ], [ %n.1, %sw.1 ]
//   ...
//   br label %exit
//
// UnrollCount is the number of copies the loop is unrolled into (11 in the
// example above).
// SwitchValue is the value the switch is performed on; it selects the block
// execution starts in (n in the example above).
// FirstSwitchValue is the value of SwitchValue that selects the first copy
// (10 in the example above).
// nextSwitchValue is a function that takes the case value of the previous
// copy and returns the case value of the next copy (n - 1 in the example
// above).
// If ForgetAllSCEV is set, all loops are forgotten in SE; otherwise only the
// topmost loop containing L is.
//
// Returns LoopUnrollResult::Unmodified, without touching the IR, when any
// precondition fails, and LoopUnrollResult::FullyUnrolled once the loop has
// been replaced.
LoopUnrollResult llvm::UnrollLoopIntoSwitch(
    Loop &L, unsigned UnrollCount, Value *SwitchValue,
    ConstantInt *FirstSwitchValue,
    std::function<ConstantInt *(ConstantInt *)> nextSwitchValue,
    ScalarEvolution &SE, LoopInfo &LI, DominatorTree &DT, bool ForgetAllSCEV) {
  if (UnrollCount <= 1) {
    return LoopUnrollResult::Unmodified;
  }
  if (!SwitchValue->getType()->isIntegerTy()) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; loop induction variable "
                         "is not an integer.\n");
    return LoopUnrollResult::Unmodified;
  }
  if (!L.getLoopPreheader()) {
    LLVM_DEBUG(
        dbgs()
        << " Can't unroll into switch; loop preheader-insertion failed.\n");
    return LoopUnrollResult::Unmodified;
  }

  if (!L.getLoopLatch()) {
    LLVM_DEBUG(
        dbgs()
        << " Can't unroll into switch; loop exit-block-insertion failed.\n");
    return LoopUnrollResult::Unmodified;
  }

  // Loops with indirectbr cannot be cloned.
  if (!L.isSafeToClone()) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; Loop body cannot be cloned.\n");
    return LoopUnrollResult::Unmodified;
  }

  if (L.getHeader()->hasAddressTaken()) {
    // The loop-rotate pass can be helpful to avoid this in many cases.
    LLVM_DEBUG(
        dbgs() << " Won't unroll loop: address of header block is taken.\n");
    return LoopUnrollResult::Unmodified;
  }

  if (!L.getExitBlock()) {
    LLVM_DEBUG(dbgs() << " Can't unroll; loop has multiple exit blocks.\n");
    return LoopUnrollResult::Unmodified;
  }

  if (L.getHeader() != L.getLoopLatch()) {
    LLVM_DEBUG(
        dbgs() << " Can't unroll into switch; loop header != loop latch.\n");
    return LoopUnrollResult::Unmodified;
  }

  BasicBlock *PreHeader = L.getLoopPreheader();
  BasicBlock *Body = L.getHeader();
  BasicBlock *ExitBlock = L.getExitBlock();

  // Finish every legality check before the first IR mutation, so that a
  // rejected loop is really left unmodified even without assertions.
  BranchInst *PreHeaderBranchInst =
      dyn_cast_or_null<BranchInst>(PreHeader->getTerminator());
  if (!PreHeaderBranchInst) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; preheader terminator is "
                         "not a branch.\n");
    return LoopUnrollResult::Unmodified;
  }
  // The exit PHIs and the dominator-tree update below both rely on the loop
  // body being the only predecessor of the exit block.
  if (ExitBlock->getSinglePredecessor() != Body) {
    LLVM_DEBUG(dbgs() << " Can't unroll into switch; exit block has other "
                         "predecessors.\n");
    return LoopUnrollResult::Unmodified;
  }

  // Remember the value each header PHI receives over the backedge, then drop
  // the backedge entry: the first copy is only entered from the preheader.
  DenseMap<PHINode *, Value *> LoopBackedgeValuesOriginal;
  for (PHINode &PH : Body->phis()) {
    int Index = PH.getBasicBlockIndex(Body);
    assert(Index >= 0 && "PHI node has no incoming values from the latch");
    Value *IncomingValue = PH.getIncomingValue(Index);
    assert(IncomingValue && "PHI node has no incoming values from the latch");
    LoopBackedgeValuesOriginal[&PH] = IncomingValue;
    // Never delete the PHI here: it is a key of the map above and the range
    // being iterated.
    PH.removeIncomingValue(Index, /*DeletePHIIfEmpty=*/false);
  }

  // For every exit PHI, the value it originally receives from the loop body.
  DenseMap<PHINode *, Value *> OrigValueToExitPHIMap;
  for (PHINode &PH : ExitBlock->phis()) {
    assert(PH.getNumIncomingValues() == 1 &&
           "Exit block has multiple incoming values");
    OrigValueToExitPHIMap[&PH] = PH.getIncomingValue(0);
  }
  DenseMap<PHINode *, Value *> PrevValueToExitPHIMap(OrigValueToExitPHIMap);

  // Replace terminator in preheader with a switch
  SwitchInst *Switch = SwitchInst::Create(SwitchValue, Body, UnrollCount);
  ReplaceInstWithInst(PreHeaderBranchInst, Switch);
  Switch->addCase(FirstSwitchValue, Body);
  ConstantInt *PrevSwitchValue = FirstSwitchValue;

  if (ForgetAllSCEV)
    SE.forgetAllLoops();
  else {
    SE.forgetTopmostLoop(&L);
    SE.forgetBlockAndLoopDispositions();
  }

  // New blocks belong to whatever loop encloses L (none if L is top-level).
  Loop *ParentLoop = L.getParentLoop();

  BasicBlock *PrevBody = Body;
  DenseMap<PHINode *, Value *> LoopBackedgeValuesPrevious(
      LoopBackedgeValuesOriginal);
  for (unsigned IterCount = 1; IterCount < UnrollCount; IterCount++) {
    ValueToValueMapTy VMap;
    BasicBlock *NewBody = CloneBasicBlock(Body, VMap, "." + Twine(IterCount));
    auto BlockInsertPt = std::next(PrevBody->getIterator());
    Body->getParent()->insert(BlockInsertPt, NewBody);
    // The last copy gets no case of its own; it becomes the default
    // destination once the chain is complete.
    if (IterCount != UnrollCount - 1) {
      ConstantInt *CurrentSwitchValue = nextSwitchValue(PrevSwitchValue);
      assert(CurrentSwitchValue && "nextSwitchValue returned no case value");
      Switch->addCase(CurrentSwitchValue, NewBody);
      PrevSwitchValue = CurrentSwitchValue;
    }
    remapInstructionsInBlocks({NewBody}, VMap);

    // Translates a value of the original body into its clone in NewBody.
    // Values defined outside the body (constants, loop invariants) have no
    // clone and are used as they are.
    auto MapToClone = [&VMap](Value *V) -> Value * {
      Value *Mapped = VMap.lookup(V);
      return Mapped ? Mapped : V;
    };

    // Each cloned header PHI additionally receives, from the preceding copy,
    // the value that used to travel along the backedge.
    for (auto &[OrigPHINode, PrevIncomingValue] : LoopBackedgeValuesPrevious) {
      PHINode *NewPHINode = cast<PHINode>(VMap[OrigPHINode]);
      NewPHINode->addIncoming(PrevIncomingValue, PrevBody);
      PrevIncomingValue =
          MapToClone(LoopBackedgeValuesOriginal.lookup(OrigPHINode));
    }
    ReplaceInstWithInst(PrevBody->getTerminator(), BranchInst::Create(NewBody));
    PrevBody = NewBody;
    // Every copy is a direct switch target, so the preheader is its
    // immediate dominator.
    DT.addNewBlock(NewBody, PreHeader);
    if (ParentLoop)
      ParentLoop->addBasicBlockToLoop(NewBody, LI);

    // Update exit values map
    for (auto &[ExitPHINode, ExitValue] : PrevValueToExitPHIMap)
      ExitValue = MapToClone(OrigValueToExitPHIMap.lookup(ExitPHINode));
  }

  // The final copy leaves the chain for the exit block, and the exit PHIs now
  // take their values from it.
  ReplaceInstWithInst(PrevBody->getTerminator(), BranchInst::Create(ExitBlock));
  for (auto &[ExitPHINode, ExitValue] : PrevValueToExitPHIMap) {
    ExitPHINode->setIncomingBlock(0, PrevBody);
    ExitPHINode->setIncomingValue(0, ExitValue);
    SE.forgetLcssaPhiWithNewPredecessor(&L, ExitPHINode);
  }
  Switch->setDefaultDest(PrevBody);
  LI.erase(&L);
  DT.changeImmediateDominator(ExitBlock, PrevBody);

  return LoopUnrollResult::FullyUnrolled;
}