Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2051,11 +2051,11 @@ static VPActiveLaneMaskPHIRecipe *addVPLaneMaskPhiAndUpdateExitBranch(
return LaneMaskPhi;
}

/// Collect all VPValues representing a header mask through the (ICMP_ULE,
/// WideCanonicalIV, backedge-taken-count) pattern.
/// Collect the header mask with the pattern:
/// (ICMP_ULE, WideCanonicalIV, backedge-taken-count)
/// TODO: Introduce explicit recipe for header-mask instead of searching
/// for the header-mask pattern manually.
static SmallVector<VPValue *> collectAllHeaderMasks(VPlan &Plan) {
static VPSingleDefRecipe *findHeaderMask(VPlan &Plan) {
SmallVector<VPValue *> WideCanonicalIVs;
auto *FoundWidenCanonicalIVUser =
find_if(Plan.getCanonicalIV()->users(),
Expand All @@ -2079,21 +2079,22 @@ static SmallVector<VPValue *> collectAllHeaderMasks(VPlan &Plan) {
WideCanonicalIVs.push_back(WidenOriginalIV);
}

// Walk users of wide canonical IVs and collect to all compares of the form
// Walk users of wide canonical IVs and find the single compare of the form
// (ICMP_ULE, WideCanonicalIV, backedge-taken-count).
SmallVector<VPValue *> HeaderMasks;
VPSingleDefRecipe *HeaderMask = nullptr;
for (auto *Wide : WideCanonicalIVs) {
for (VPUser *U : SmallVector<VPUser *>(Wide->users())) {
auto *HeaderMask = dyn_cast<VPInstruction>(U);
if (!HeaderMask || !vputils::isHeaderMask(HeaderMask, Plan))
auto *VPI = dyn_cast<VPInstruction>(U);
if (!VPI || !vputils::isHeaderMask(VPI, Plan))
continue;

assert(HeaderMask->getOperand(0) == Wide &&
assert(VPI->getOperand(0) == Wide &&
"WidenCanonicalIV must be the first operand of the compare");
HeaderMasks.push_back(HeaderMask);
assert(!HeaderMask && "Multiple header masks found?");
HeaderMask = VPI;
}
}
return HeaderMasks;
return HeaderMask;
}

void VPlanTransforms::addActiveLaneMask(
Expand All @@ -2109,6 +2110,7 @@ void VPlanTransforms::addActiveLaneMask(
[](VPUser *U) { return isa<VPWidenCanonicalIVRecipe>(U); });
assert(FoundWidenCanonicalIVUser &&
"Must have widened canonical IV when tail folding!");
VPSingleDefRecipe *HeaderMask = findHeaderMask(Plan);
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to hoist the call upwards here because once we create the VPInstruction::ActiveLaneMask the assert is violated

auto *WideCanonicalIV =
cast<VPWidenCanonicalIVRecipe>(*FoundWidenCanonicalIVUser);
VPSingleDefRecipe *LaneMask;
Expand All @@ -2122,11 +2124,11 @@ void VPlanTransforms::addActiveLaneMask(
"active.lane.mask");
}

// Walk users of WideCanonicalIV and replace all compares of the form
// (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an
// active-lane-mask.
for (VPValue *HeaderMask : collectAllHeaderMasks(Plan))
HeaderMask->replaceAllUsesWith(LaneMask);
// Walk users of WideCanonicalIV and replace the header mask of the form
// (ICMP_ULE, WideCanonicalIV, backedge-taken-count) with an active-lane-mask,
// removing the old one to ensure there is always only a single header mask.
HeaderMask->replaceAllUsesWith(LaneMask);
HeaderMask->eraseFromParent();
}

/// Try to optimize a \p CurRecipe masked by \p HeaderMask to a corresponding
Expand Down Expand Up @@ -2252,6 +2254,10 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
}
}

VPValue *HeaderMask = findHeaderMask(Plan);
if (!HeaderMask)
return;

// Replace header masks with a mask equivalent to predicating by EVL:
//
// icmp ule widen-canonical-iv backedge-taken-count
Expand All @@ -2263,10 +2269,8 @@ static void transformRecipestoEVLRecipes(VPlan &Plan, VPValue &EVL) {
VPValue *EVLMask = Builder.createICmp(
CmpInst::ICMP_ULT,
Builder.createNaryOp(VPInstruction::StepVector, {}, EVLType), &EVL);
for (VPValue *HeaderMask : collectAllHeaderMasks(Plan)) {
HeaderMask->replaceAllUsesWith(EVLMask);
ToErase.push_back(HeaderMask->getDefiningRecipe());
}
HeaderMask->replaceAllUsesWith(EVLMask);
ToErase.push_back(HeaderMask->getDefiningRecipe());

// Try to optimize header mask recipes away to their EVL variants.
// TODO: Split optimizeMaskToEVL out and move into
Expand Down