Skip to content

Commit 6d3342f

Browse files
committed
[profcheck][SimplifyCFG] Propagate !prof from switch to select
1 parent c42de45 commit 6d3342f

File tree

2 files changed

+117
-41
lines changed

2 files changed

+117
-41
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
#include <cstdint>
8585
#include <iterator>
8686
#include <map>
87+
#include <numeric>
8788
#include <optional>
8889
#include <set>
8990
#include <tuple>
@@ -6318,9 +6319,12 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
63186319
// Helper function that checks if it is possible to transform a switch with only
63196320
// two cases (or two cases + default) that produces a result into a select.
63206321
// TODO: Handle switches with more than 2 cases that map to the same result.
6322+
// The branch weights correspond to the provided Condition (i.e. if Condition is
6323+
// modified from the original SwitchInst, the caller must adjust the weights)
63216324
static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63226325
Constant *DefaultResult, Value *Condition,
6323-
IRBuilder<> &Builder, const DataLayout &DL) {
6326+
IRBuilder<> &Builder, const DataLayout &DL,
6327+
ArrayRef<uint32_t> BranchWeights) {
63246328
// If we are selecting between only two cases transform into a simple
63256329
// select or a two-way select if default is possible.
63266330
// Example:
@@ -6329,6 +6333,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63296333
// case 20: return 2; ----> %2 = icmp eq i32 %a, 20
63306334
// default: return 4; %3 = select i1 %2, i32 2, i32 %1
63316335
// }
6336+
6337+
const bool HasBranchWeights =
6338+
!BranchWeights.empty() && !ProfcheckDisableMetadataFixes;
6339+
63326340
if (ResultVector.size() == 2 && ResultVector[0].second.size() == 1 &&
63336341
ResultVector[1].second.size() == 1) {
63346342
ConstantInt *FirstCase = ResultVector[0].second[0];
@@ -6337,13 +6345,37 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63376345
if (DefaultResult) {
63386346
Value *ValueCompare =
63396347
Builder.CreateICmpEQ(Condition, SecondCase, "switch.selectcmp");
6340-
SelectValue = Builder.CreateSelect(ValueCompare, ResultVector[1].first,
6341-
DefaultResult, "switch.select");
6348+
SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect(
6349+
ValueCompare, ResultVector[1].first, DefaultResult, "switch.select"));
6350+
SelectValue = SelectValueInst;
6351+
if (HasBranchWeights) {
6352+
// We start with 3 probabilities, where the numerator is the
6353+
// corresponding BranchWeights[i], and the denominator is the sum over
6354+
// BranchWeights. We want the probability and negative probability of
6355+
// Condition == SecondCase.
6356+
assert(BranchWeights.size() == 3);
6357+
setBranchWeights(SelectValueInst, BranchWeights[2],
6358+
BranchWeights[0] + BranchWeights[1],
6359+
/*IsExpected=*/false);
6360+
}
63426361
}
63436362
Value *ValueCompare =
63446363
Builder.CreateICmpEQ(Condition, FirstCase, "switch.selectcmp");
6345-
return Builder.CreateSelect(ValueCompare, ResultVector[0].first,
6346-
SelectValue, "switch.select");
6364+
SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect(
6365+
ValueCompare, ResultVector[0].first, SelectValue, "switch.select"));
6366+
if (HasBranchWeights) {
6367+
// We may have had a DefaultResult. Base the position of the first and
6368+
// second's branch weights accordingly. Also the proability that Condition
6369+
// != FirstCase needs to take that into account.
6370+
assert(BranchWeights.size() >= 2);
6371+
size_t FirstCasePos = (Condition != nullptr);
6372+
size_t SecondCasePos = FirstCasePos + 1;
6373+
uint32_t DefaultCase = (Condition != nullptr) ? BranchWeights[0] : 0;
6374+
setBranchWeights(Ret, BranchWeights[FirstCasePos],
6375+
DefaultCase + BranchWeights[SecondCasePos],
6376+
/*IsExpected=*/false);
6377+
}
6378+
return Ret;
63476379
}
63486380

63496381
// Handle the degenerate case where two cases have the same result value.
@@ -6379,8 +6411,16 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63796411
Value *And = Builder.CreateAnd(Condition, AndMask);
63806412
Value *Cmp = Builder.CreateICmpEQ(
63816413
And, Constant::getIntegerValue(And->getType(), AndMask));
6382-
return Builder.CreateSelect(Cmp, ResultVector[0].first,
6383-
DefaultResult);
6414+
SelectInst *Ret = cast<SelectInst>(
6415+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6416+
if (HasBranchWeights) {
6417+
// We know there's a Default case. We base the resulting branch
6418+
// weights off its probability.
6419+
assert(BranchWeights.size() >= 2);
6420+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6421+
BranchWeights[0], /*IsExpected=*/false);
6422+
}
6423+
return Ret;
63846424
}
63856425
}
63866426

@@ -6397,7 +6437,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63976437
Value *And = Builder.CreateAnd(Condition, ~BitMask, "switch.and");
63986438
Value *Cmp = Builder.CreateICmpEQ(
63996439
And, Constant::getNullValue(And->getType()), "switch.selectcmp");
6400-
return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6440+
SelectInst *Ret = cast<SelectInst>(
6441+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6442+
if (HasBranchWeights) {
6443+
assert(BranchWeights.size() >= 2);
6444+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6445+
BranchWeights[0], /*IsExpected=*/false);
6446+
}
6447+
return Ret;
64016448
}
64026449
}
64036450

@@ -6408,7 +6455,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64086455
Value *Cmp2 = Builder.CreateICmpEQ(Condition, CaseValues[1],
64096456
"switch.selectcmp.case2");
64106457
Value *Cmp = Builder.CreateOr(Cmp1, Cmp2, "switch.selectcmp");
6411-
return Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult);
6458+
SelectInst *Ret = cast<SelectInst>(
6459+
Builder.CreateSelect(Cmp, ResultVector[0].first, DefaultResult));
6460+
if (HasBranchWeights) {
6461+
assert(BranchWeights.size() >= 2);
6462+
setBranchWeights(Ret, accumulate(drop_begin(BranchWeights), 0),
6463+
BranchWeights[0], /*IsExpected=*/false);
6464+
}
6465+
return Ret;
64126466
}
64136467
}
64146468

@@ -6469,8 +6523,18 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
64696523

64706524
assert(PHI != nullptr && "PHI for value select not found");
64716525
Builder.SetInsertPoint(SI);
6472-
Value *SelectValue =
6473-
foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder, DL);
6526+
SmallVector<uint32_t, 4> BranchWeights;
6527+
if (!ProfcheckDisableMetadataFixes) {
6528+
[[maybe_unused]] auto HasWeights =
6529+
extractBranchWeights(getBranchWeightMDNode(*SI), BranchWeights);
6530+
assert(!HasWeights == (BranchWeights.empty()));
6531+
}
6532+
assert(BranchWeights.empty() ||
6533+
(BranchWeights.size() >=
6534+
UniqueResults.size() + (DefaultResult != nullptr)));
6535+
6536+
Value *SelectValue = foldSwitchToSelect(UniqueResults, DefaultResult, Cond,
6537+
Builder, DL, BranchWeights);
64746538
if (!SelectValue)
64756539
return false;
64766540

0 commit comments

Comments
 (0)