8484#include < cstdint>
8585#include < iterator>
8686#include < map>
87+ #include < numeric>
8788#include < optional>
8889#include < set>
8990#include < tuple>
@@ -6318,9 +6319,12 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
63186319// Helper function that checks if it is possible to transform a switch with only
63196320// two cases (or two cases + default) that produces a result into a select.
63206321// TODO: Handle switches with more than 2 cases that map to the same result.
6322+ // The branch weights correspond to the provided Condition (i.e. if Condition is
6323+ // modified from the original SwitchInst, the caller must adjust the weights)
63216324static Value *foldSwitchToSelect (const SwitchCaseResultVectorTy &ResultVector,
63226325 Constant *DefaultResult, Value *Condition,
6323- IRBuilder<> &Builder, const DataLayout &DL) {
6326+ IRBuilder<> &Builder, const DataLayout &DL,
6327+ ArrayRef<uint32_t > BranchWeights) {
63246328 // If we are selecting between only two cases transform into a simple
63256329 // select or a two-way select if default is possible.
63266330 // Example:
@@ -6329,6 +6333,10 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63296333 // case 20: return 2; ----> %2 = icmp eq i32 %a, 20
63306334 // default: return 4; %3 = select i1 %2, i32 2, i32 %1
63316335 // }
6336+
6337+ const bool HasBranchWeights =
6338+ !BranchWeights.empty () && !ProfcheckDisableMetadataFixes;
6339+
63326340 if (ResultVector.size () == 2 && ResultVector[0 ].second .size () == 1 &&
63336341 ResultVector[1 ].second .size () == 1 ) {
63346342 ConstantInt *FirstCase = ResultVector[0 ].second [0 ];
@@ -6337,13 +6345,37 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63376345 if (DefaultResult) {
63386346 Value *ValueCompare =
63396347 Builder.CreateICmpEQ (Condition, SecondCase, " switch.selectcmp" );
6340- SelectValue = Builder.CreateSelect (ValueCompare, ResultVector[1 ].first ,
6341- DefaultResult, " switch.select" );
6348+ SelectInst *SelectValueInst = cast<SelectInst>(Builder.CreateSelect (
6349+ ValueCompare, ResultVector[1 ].first , DefaultResult, " switch.select" ));
6350+ SelectValue = SelectValueInst;
6351+ if (HasBranchWeights) {
6352+ // We start with 3 probabilities, where the numerator is the
6353+ // corresponding BranchWeights[i], and the denominator is the sum over
6354+ // BranchWeights. We want the probability and negative probability of
6355+ // Condition == SecondCase.
6356+ assert (BranchWeights.size () == 3 );
6357+ setBranchWeights (SelectValueInst, BranchWeights[2 ],
6358+ BranchWeights[0 ] + BranchWeights[1 ],
6359+ /* IsExpected=*/ false );
6360+ }
63426361 }
63436362 Value *ValueCompare =
63446363 Builder.CreateICmpEQ (Condition, FirstCase, " switch.selectcmp" );
6345- return Builder.CreateSelect (ValueCompare, ResultVector[0 ].first ,
6346- SelectValue, " switch.select" );
6364+ SelectInst *Ret = cast<SelectInst>(Builder.CreateSelect (
6365+ ValueCompare, ResultVector[0 ].first , SelectValue, " switch.select" ));
6366+ if (HasBranchWeights) {
6367+ // We may have had a DefaultResult. Base the position of the first and
6368+ // second's branch weights accordingly. Also the proability that Condition
6369+ // != FirstCase needs to take that into account.
6370+ assert (BranchWeights.size () >= 2 );
6371+ size_t FirstCasePos = (Condition != nullptr );
6372+ size_t SecondCasePos = FirstCasePos + 1 ;
6373+ uint32_t DefaultCase = (Condition != nullptr ) ? BranchWeights[0 ] : 0 ;
6374+ setBranchWeights (Ret, BranchWeights[FirstCasePos],
6375+ DefaultCase + BranchWeights[SecondCasePos],
6376+ /* IsExpected=*/ false );
6377+ }
6378+ return Ret;
63476379 }
63486380
63496381 // Handle the degenerate case where two cases have the same result value.
@@ -6379,8 +6411,16 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63796411 Value *And = Builder.CreateAnd (Condition, AndMask);
63806412 Value *Cmp = Builder.CreateICmpEQ (
63816413 And, Constant::getIntegerValue (And->getType (), AndMask));
6382- return Builder.CreateSelect (Cmp, ResultVector[0 ].first ,
6383- DefaultResult);
6414+ SelectInst *Ret = cast<SelectInst>(
6415+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6416+ if (HasBranchWeights) {
6417+ // We know there's a Default case. We base the resulting branch
6418+ // weights off its probability.
6419+ assert (BranchWeights.size () >= 2 );
6420+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6421+ BranchWeights[0 ], /* IsExpected=*/ false );
6422+ }
6423+ return Ret;
63846424 }
63856425 }
63866426
@@ -6397,7 +6437,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
63976437 Value *And = Builder.CreateAnd (Condition, ~BitMask, " switch.and" );
63986438 Value *Cmp = Builder.CreateICmpEQ (
63996439 And, Constant::getNullValue (And->getType ()), " switch.selectcmp" );
6400- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6440+ SelectInst *Ret = cast<SelectInst>(
6441+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6442+ if (HasBranchWeights) {
6443+ assert (BranchWeights.size () >= 2 );
6444+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6445+ BranchWeights[0 ], /* IsExpected=*/ false );
6446+ }
6447+ return Ret;
64016448 }
64026449 }
64036450
@@ -6408,7 +6455,14 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
64086455 Value *Cmp2 = Builder.CreateICmpEQ (Condition, CaseValues[1 ],
64096456 " switch.selectcmp.case2" );
64106457 Value *Cmp = Builder.CreateOr (Cmp1, Cmp2, " switch.selectcmp" );
6411- return Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult);
6458+ SelectInst *Ret = cast<SelectInst>(
6459+ Builder.CreateSelect (Cmp, ResultVector[0 ].first , DefaultResult));
6460+ if (HasBranchWeights) {
6461+ assert (BranchWeights.size () >= 2 );
6462+ setBranchWeights (Ret, accumulate (drop_begin (BranchWeights), 0 ),
6463+ BranchWeights[0 ], /* IsExpected=*/ false );
6464+ }
6465+ return Ret;
64126466 }
64136467 }
64146468
@@ -6469,8 +6523,18 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
64696523
64706524 assert (PHI != nullptr && " PHI for value select not found" );
64716525 Builder.SetInsertPoint (SI);
6472- Value *SelectValue =
6473- foldSwitchToSelect (UniqueResults, DefaultResult, Cond, Builder, DL);
6526+ SmallVector<uint32_t , 4 > BranchWeights;
6527+ if (!ProfcheckDisableMetadataFixes) {
6528+ [[maybe_unused]] auto HasWeights =
6529+ extractBranchWeights (getBranchWeightMDNode (*SI), BranchWeights);
6530+ assert (!HasWeights == (BranchWeights.empty ()));
6531+ }
6532+ assert (BranchWeights.empty () ||
6533+ (BranchWeights.size () >=
6534+ UniqueResults.size () + (DefaultResult != nullptr )));
6535+
6536+ Value *SelectValue = foldSwitchToSelect (UniqueResults, DefaultResult, Cond,
6537+ Builder, DL, BranchWeights);
64746538 if (!SelectValue)
64756539 return false ;
64766540
0 commit comments