Skip to content

Commit 3060d05

Browse files
committed
Resolve select <all ones>, foo, bar -> vp.select <all ones>, foo, bar, EVL
1 parent 06c6cff commit 3060d05

File tree

1 file changed

+32
-12
lines changed

1 file changed

+32
-12
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,15 +2443,23 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
24432443
VPRecipeBase &CurRecipe,
24442444
VPTypeAnalysis &TypeInfo,
24452445
VPValue &AllOneMask, VPValue &EVL) {
2446-
// FIXME: Don't transform recipes to EVL recipes if they're not masked by the
2447-
// header mask.
2446+
// Derive a new mask by removing the header mask. Return nullptr if a new mask
2447+
// cannot be derived because the original mask does not contain the header
2448+
// mask.
24482449
auto GetNewMask = [&](VPValue *OrigMask) -> VPValue * {
24492450
assert(OrigMask && "Unmasked recipe when folding tail");
24502451
// HeaderMask will be handled using EVL.
2452+
if (HeaderMask == OrigMask)
2453+
return &AllOneMask;
24512454
VPValue *Mask;
24522455
if (match(OrigMask, m_LogicalAnd(m_Specific(HeaderMask), m_VPValue(Mask))))
24532456
return Mask;
2454-
return HeaderMask == OrigMask ? nullptr : OrigMask;
2457+
return nullptr;
2458+
};
2459+
2460+
// TODO: Can be simplified by SimplifyRecipes.
2461+
auto OptimizeAllTrueMask = [](VPValue *Mask) -> VPValue * {
2462+
return match(Mask, m_True()) ? nullptr : Mask;
24552463
};
24562464

24572465
/// Adjust any end pointers so that they point to the end of EVL lanes not VF.
@@ -2474,23 +2482,35 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
24742482
};
24752483

24762484
return TypeSwitch<VPRecipeBase *, VPRecipeBase *>(&CurRecipe)
2477-
.Case<VPWidenLoadRecipe>([&](VPWidenLoadRecipe *L) {
2485+
.Case<VPWidenLoadRecipe>([&](VPWidenLoadRecipe *L) -> VPRecipeBase * {
24782486
VPValue *NewMask = GetNewMask(L->getMask());
2487+
if (!NewMask)
2488+
return nullptr;
24792489
VPValue *NewAddr = GetNewAddr(L->getAddr());
2480-
return new VPWidenLoadEVLRecipe(*L, NewAddr, EVL, NewMask);
2490+
return new VPWidenLoadEVLRecipe(*L, NewAddr, EVL,
2491+
OptimizeAllTrueMask(NewMask));
24812492
})
2482-
.Case<VPWidenStoreRecipe>([&](VPWidenStoreRecipe *S) {
2493+
.Case<VPWidenStoreRecipe>([&](VPWidenStoreRecipe *S) -> VPRecipeBase * {
24832494
VPValue *NewMask = GetNewMask(S->getMask());
2495+
if (!NewMask)
2496+
return nullptr;
24842497
VPValue *NewAddr = GetNewAddr(S->getAddr());
2485-
return new VPWidenStoreEVLRecipe(*S, NewAddr, EVL, NewMask);
2498+
return new VPWidenStoreEVLRecipe(*S, NewAddr, EVL,
2499+
OptimizeAllTrueMask(NewMask));
24862500
})
2487-
.Case<VPInterleaveRecipe>([&](VPInterleaveRecipe *IR) {
2501+
.Case<VPInterleaveRecipe>([&](VPInterleaveRecipe *IR) -> VPRecipeBase * {
24882502
VPValue *NewMask = GetNewMask(IR->getMask());
2489-
return new VPInterleaveEVLRecipe(*IR, EVL, NewMask);
2503+
if (!NewMask)
2504+
return nullptr;
2505+
return new VPInterleaveEVLRecipe(*IR, EVL,
2506+
OptimizeAllTrueMask(NewMask));
24902507
})
2491-
.Case<VPReductionRecipe>([&](VPReductionRecipe *Red) {
2508+
.Case<VPReductionRecipe>([&](VPReductionRecipe *Red) -> VPRecipeBase * {
24922509
VPValue *NewMask = GetNewMask(Red->getCondOp());
2493-
return new VPReductionEVLRecipe(*Red, EVL, NewMask);
2510+
if (!NewMask)
2511+
return nullptr;
2512+
return new VPReductionEVLRecipe(*Red, EVL,
2513+
OptimizeAllTrueMask(NewMask));
24942514
})
24952515
.Case<VPInstruction>([&](VPInstruction *VPI) -> VPRecipeBase * {
24962516
VPValue *Cond, *LHS, *RHS;
@@ -2504,7 +2524,7 @@ static VPRecipeBase *optimizeMaskToEVL(VPValue *HeaderMask,
25042524

25052525
VPValue *NewMask = GetNewMask(Cond);
25062526
if (!NewMask)
2507-
NewMask = &AllOneMask;
2527+
return nullptr;
25082528
return new VPWidenIntrinsicRecipe(
25092529
Intrinsic::vp_merge, {NewMask, LHS, RHS, &EVL},
25102530
TypeInfo.inferScalarType(LHS), VPI->getDebugLoc());

0 commit comments

Comments
 (0)