Skip to content

Commit 62d1a08

Browse files
authored
[LV] Use ExtractLane(LastActiveLane, V) live outs when tail-folding. (#149042)
Building on top of #148817, introduce a new abstract LastActiveLane opcode that gets lowered to Not(Mask) → FirstActiveLane(NotMask) → Sub(result, 1). When folding the tail, update all extracts for uses outside the loop the extract the value of the last actice lane. See also #148603 PR: #149042
1 parent 1e4b9e4 commit 62d1a08

21 files changed

+1626
-584
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2097,24 +2097,6 @@ bool LoopVectorizationLegality::canFoldTailByMasking() const {
20972097
for (const auto &Reduction : getReductionVars())
20982098
ReductionLiveOuts.insert(Reduction.second.getLoopExitInstr());
20992099

2100-
// TODO: handle non-reduction outside users when tail is folded by masking.
2101-
for (auto *AE : AllowedExit) {
2102-
// Check that all users of allowed exit values are inside the loop or
2103-
// are the live-out of a reduction.
2104-
if (ReductionLiveOuts.count(AE))
2105-
continue;
2106-
for (User *U : AE->users()) {
2107-
Instruction *UI = cast<Instruction>(U);
2108-
if (TheLoop->contains(UI))
2109-
continue;
2110-
LLVM_DEBUG(
2111-
dbgs()
2112-
<< "LV: Cannot fold tail by masking, loop has an outside user for "
2113-
<< *UI << "\n");
2114-
return false;
2115-
}
2116-
}
2117-
21182100
for (const auto &Entry : getInductionVars()) {
21192101
PHINode *OrigPhi = Entry.first;
21202102
for (User *U : OrigPhi->users()) {

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8895,7 +8895,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
88958895
if (FinalReductionResult == U || Parent->getParent())
88968896
continue;
88978897
U->replaceUsesOfWith(OrigExitingVPV, FinalReductionResult);
8898-
if (match(U, m_ExtractLastElement(m_VPValue())))
8898+
if (match(U, m_CombineOr(m_ExtractLastElement(m_VPValue()),
8899+
m_ExtractLane(m_VPValue(), m_VPValue()))))
88998900
cast<VPInstruction>(U)->replaceAllUsesWith(FinalReductionResult);
89008901
}
89018902

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1047,6 +1047,13 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10471047
// It produces the lane index across all unrolled iterations. Unrolling will
10481048
// add all copies of its original operand as additional operands.
10491049
FirstActiveLane,
1050+
// Calculates the last active lane index of the vector predicate operands.
1051+
// The predicates must be prefix-masks (all 1s before all 0s). Used when
1052+
// tail-folding to extract the correct live-out value from the last active
1053+
// iteration. It produces the lane index across all unrolled iterations.
1054+
// Unrolling will add all copies of its original operand as additional
1055+
// operands.
1056+
LastActiveLane,
10501057

10511058
// The opcodes below are used for VPInstructionWithType.
10521059
//

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
115115
case VPInstruction::ExtractLane:
116116
return inferScalarType(R->getOperand(1));
117117
case VPInstruction::FirstActiveLane:
118+
case VPInstruction::LastActiveLane:
118119
return Type::getIntNTy(Ctx, 64);
119120
case VPInstruction::ExtractLastElement:
120121
case VPInstruction::ExtractLastLanePerPart:

llvm/lib/Transforms/Vectorize/VPlanPatternMatch.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,24 @@ m_ExtractElement(const Op0_t &Op0, const Op1_t &Op1) {
395395
return m_VPInstruction<Instruction::ExtractElement>(Op0, Op1);
396396
}
397397

398+
template <typename Op0_t, typename Op1_t>
399+
inline VPInstruction_match<VPInstruction::ExtractLane, Op0_t, Op1_t>
400+
m_ExtractLane(const Op0_t &Op0, const Op1_t &Op1) {
401+
return m_VPInstruction<VPInstruction::ExtractLane>(Op0, Op1);
402+
}
403+
398404
template <typename Op0_t>
399405
inline VPInstruction_match<VPInstruction::ExtractLastLanePerPart, Op0_t>
400406
m_ExtractLastLanePerPart(const Op0_t &Op0) {
401407
return m_VPInstruction<VPInstruction::ExtractLastLanePerPart>(Op0);
402408
}
403409

410+
template <typename Op0_t>
411+
inline VPInstruction_match<VPInstruction::ExtractPenultimateElement, Op0_t>
412+
m_ExtractPenultimateElement(const Op0_t &Op0) {
413+
return m_VPInstruction<VPInstruction::ExtractPenultimateElement>(Op0);
414+
}
415+
404416
template <typename Op0_t, typename Op1_t, typename Op2_t>
405417
inline VPInstruction_match<VPInstruction::ActiveLaneMask, Op0_t, Op1_t, Op2_t>
406418
m_ActiveLaneMask(const Op0_t &Op0, const Op1_t &Op1, const Op2_t &Op2) {
@@ -429,6 +441,16 @@ m_FirstActiveLane(const Op0_t &Op0) {
429441
return m_VPInstruction<VPInstruction::FirstActiveLane>(Op0);
430442
}
431443

444+
template <typename Op0_t>
445+
inline VPInstruction_match<VPInstruction::LastActiveLane, Op0_t>
446+
m_LastActiveLane(const Op0_t &Op0) {
447+
return m_VPInstruction<VPInstruction::LastActiveLane>(Op0);
448+
}
449+
450+
inline VPInstruction_match<VPInstruction::StepVector> m_StepVector() {
451+
return m_VPInstruction<VPInstruction::StepVector>();
452+
}
453+
432454
template <unsigned Opcode, typename Op0_t>
433455
inline AllRecipe_match<Opcode, Op0_t> m_Unary(const Op0_t &Op0) {
434456
return AllRecipe_match<Opcode, Op0_t>(Op0);

llvm/lib/Transforms/Vectorize/VPlanPredicator.cpp

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ class VPPredicator {
4444
/// possibly inserting new recipes at \p Dst (using Builder's insertion point)
4545
VPValue *createEdgeMask(VPBasicBlock *Src, VPBasicBlock *Dst);
4646

47-
/// Returns the *entry* mask for \p VPBB.
48-
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
49-
return BlockMaskCache.lookup(VPBB);
50-
}
51-
5247
/// Record \p Mask as the *entry* mask of \p VPBB, which is expected to not
5348
/// already have a mask.
5449
void setBlockInMask(VPBasicBlock *VPBB, VPValue *Mask) {
@@ -68,6 +63,11 @@ class VPPredicator {
6863
}
6964

7065
public:
66+
/// Returns the *entry* mask for \p VPBB.
67+
VPValue *getBlockInMask(VPBasicBlock *VPBB) const {
68+
return BlockMaskCache.lookup(VPBB);
69+
}
70+
7171
/// Returns the precomputed predicate of the edge from \p Src to \p Dst.
7272
VPValue *getEdgeMask(const VPBasicBlock *Src, const VPBasicBlock *Dst) const {
7373
return EdgeMaskCache.lookup({Src, Dst});
@@ -301,5 +301,34 @@ VPlanTransforms::introduceMasksAndLinearize(VPlan &Plan, bool FoldTail) {
301301

302302
PrevVPBB = VPBB;
303303
}
304+
305+
// If we folded the tail and introduced a header mask, any extract of the
306+
// last element must be updated to extract from the last active lane of the
307+
// header mask instead (i.e., the lane corresponding to the last active
308+
// iteration).
309+
if (FoldTail) {
310+
assert(Plan.getExitBlocks().size() == 1 &&
311+
"only a single-exit block is supported currently");
312+
VPBasicBlock *EB = Plan.getExitBlocks().front();
313+
assert(EB->getSinglePredecessor() == Plan.getMiddleBlock() &&
314+
"the exit block must have middle block as single predecessor");
315+
316+
VPBuilder B(Plan.getMiddleBlock()->getTerminator());
317+
for (auto &P : EB->phis()) {
318+
auto *ExitIRI = cast<VPIRPhi>(&P);
319+
VPValue *Inc = ExitIRI->getIncomingValue(0);
320+
VPValue *Op;
321+
if (!match(Inc, m_ExtractLastElement(m_VPValue(Op))))
322+
continue;
323+
324+
// Compute the index of the last active lane.
325+
VPValue *HeaderMask = Predicator.getBlockInMask(Header);
326+
VPValue *LastActiveLane =
327+
B.createNaryOp(VPInstruction::LastActiveLane, HeaderMask);
328+
auto *Ext =
329+
B.createNaryOp(VPInstruction::ExtractLane, {LastActiveLane, Op});
330+
Inc->replaceAllUsesWith(Ext);
331+
}
332+
}
304333
return Predicator.getBlockMaskCache();
305334
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ unsigned VPInstruction::getNumOperandsForOpcode(unsigned Opcode) {
547547
case VPInstruction::ExtractLastLanePerPart:
548548
case VPInstruction::ExtractPenultimateElement:
549549
case VPInstruction::FirstActiveLane:
550+
case VPInstruction::LastActiveLane:
550551
case VPInstruction::Not:
551552
case VPInstruction::Unpack:
552553
return 1;
@@ -1156,6 +1157,29 @@ InstructionCost VPInstruction::computeCost(ElementCount VF,
11561157
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
11571158
return Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
11581159
}
1160+
case VPInstruction::LastActiveLane: {
1161+
Type *ScalarTy = Ctx.Types.inferScalarType(getOperand(0));
1162+
if (VF.isScalar())
1163+
return Ctx.TTI.getCmpSelInstrCost(Instruction::ICmp, ScalarTy,
1164+
CmpInst::makeCmpResultType(ScalarTy),
1165+
CmpInst::ICMP_EQ, Ctx.CostKind);
1166+
// Calculate the cost of determining the lane index: NOT + cttz_elts + SUB.
1167+
auto *PredTy = toVectorTy(ScalarTy, VF);
1168+
IntrinsicCostAttributes Attrs(Intrinsic::experimental_cttz_elts,
1169+
Type::getInt64Ty(Ctx.LLVMCtx),
1170+
{PredTy, Type::getInt1Ty(Ctx.LLVMCtx)});
1171+
InstructionCost Cost = Ctx.TTI.getIntrinsicInstrCost(Attrs, Ctx.CostKind);
1172+
// Add cost of NOT operation on the predicate.
1173+
Cost += Ctx.TTI.getArithmeticInstrCost(
1174+
Instruction::Xor, PredTy, Ctx.CostKind,
1175+
{TargetTransformInfo::OK_AnyValue, TargetTransformInfo::OP_None},
1176+
{TargetTransformInfo::OK_UniformConstantValue,
1177+
TargetTransformInfo::OP_None});
1178+
// Add cost of SUB operation on the index.
1179+
Cost += Ctx.TTI.getArithmeticInstrCost(
1180+
Instruction::Sub, Type::getInt64Ty(Ctx.LLVMCtx), Ctx.CostKind);
1181+
return Cost;
1182+
}
11591183
case VPInstruction::FirstOrderRecurrenceSplice: {
11601184
assert(VF.isVector() && "Scalar FirstOrderRecurrenceSplice?");
11611185
SmallVector<int> Mask(VF.getKnownMinValue());
@@ -1210,6 +1234,7 @@ bool VPInstruction::isVectorToScalar() const {
12101234
getOpcode() == Instruction::ExtractElement ||
12111235
getOpcode() == VPInstruction::ExtractLane ||
12121236
getOpcode() == VPInstruction::FirstActiveLane ||
1237+
getOpcode() == VPInstruction::LastActiveLane ||
12131238
getOpcode() == VPInstruction::ComputeAnyOfResult ||
12141239
getOpcode() == VPInstruction::ComputeFindIVResult ||
12151240
getOpcode() == VPInstruction::ComputeReductionResult ||
@@ -1275,6 +1300,7 @@ bool VPInstruction::opcodeMayReadOrWriteFromMemory() const {
12751300
case VPInstruction::ExtractPenultimateElement:
12761301
case VPInstruction::ActiveLaneMask:
12771302
case VPInstruction::FirstActiveLane:
1303+
case VPInstruction::LastActiveLane:
12781304
case VPInstruction::FirstOrderRecurrenceSplice:
12791305
case VPInstruction::LogicalAnd:
12801306
case VPInstruction::Not:
@@ -1451,6 +1477,9 @@ void VPInstruction::print(raw_ostream &O, const Twine &Indent,
14511477
case VPInstruction::FirstActiveLane:
14521478
O << "first-active-lane";
14531479
break;
1480+
case VPInstruction::LastActiveLane:
1481+
O << "last-active-lane";
1482+
break;
14541483
case VPInstruction::ReductionStartVector:
14551484
O << "reduction-start-vector";
14561485
break;

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -805,8 +805,8 @@ static VPValue *optimizeEarlyExitInductionUser(VPlan &Plan,
805805
VPValue *Op,
806806
ScalarEvolution &SE) {
807807
VPValue *Incoming, *Mask;
808-
if (!match(Op, m_VPInstruction<VPInstruction::ExtractLane>(
809-
m_FirstActiveLane(m_VPValue(Mask)), m_VPValue(Incoming))))
808+
if (!match(Op, m_ExtractLane(m_FirstActiveLane(m_VPValue(Mask)),
809+
m_VPValue(Incoming))))
810810
return nullptr;
811811

812812
auto *WideIV = getOptimizableIVOf(Incoming, SE);
@@ -1274,8 +1274,7 @@ static void simplifyRecipe(VPSingleDefRecipe *Def, VPTypeAnalysis &TypeInfo) {
12741274
}
12751275

12761276
// Look through ExtractPenultimateElement (BuildVector ....).
1277-
if (match(Def, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
1278-
m_BuildVector()))) {
1277+
if (match(Def, m_ExtractPenultimateElement(m_BuildVector()))) {
12791278
auto *BuildVector = cast<VPInstruction>(Def->getOperand(0));
12801279
Def->replaceAllUsesWith(
12811280
BuildVector->getOperand(BuildVector->getNumOperands() - 2));
@@ -2056,6 +2055,32 @@ bool VPlanTransforms::adjustFixedOrderRecurrences(VPlan &Plan,
20562055
// Set the first operand of RecurSplice to FOR again, after replacing
20572056
// all users.
20582057
RecurSplice->setOperand(0, FOR);
2058+
2059+
// Check for users extracting at the penultimate active lane of the FOR.
2060+
// If only a single lane is active in the current iteration, we need to
2061+
// select the last element from the previous iteration (from the FOR phi
2062+
// directly).
2063+
for (VPUser *U : RecurSplice->users()) {
2064+
if (!match(U, m_ExtractLane(m_LastActiveLane(m_VPValue()),
2065+
m_Specific(RecurSplice))))
2066+
continue;
2067+
2068+
VPBuilder B(cast<VPInstruction>(U));
2069+
VPValue *LastActiveLane = cast<VPInstruction>(U)->getOperand(0);
2070+
Type *I64Ty = Type::getInt64Ty(Plan.getContext());
2071+
VPValue *Zero = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 0));
2072+
VPValue *One = Plan.getOrAddLiveIn(ConstantInt::get(I64Ty, 1));
2073+
VPValue *PenultimateIndex =
2074+
B.createNaryOp(Instruction::Sub, {LastActiveLane, One});
2075+
VPValue *PenultimateLastIter =
2076+
B.createNaryOp(VPInstruction::ExtractLane,
2077+
{PenultimateIndex, FOR->getBackedgeValue()});
2078+
VPValue *LastPrevIter =
2079+
B.createNaryOp(VPInstruction::ExtractLastElement, FOR);
2080+
VPValue *Cmp = B.createICmp(CmpInst::ICMP_EQ, LastActiveLane, Zero);
2081+
VPValue *Sel = B.createSelect(Cmp, LastPrevIter, PenultimateLastIter);
2082+
cast<VPInstruction>(U)->replaceAllUsesWith(Sel);
2083+
}
20592084
}
20602085
return true;
20612086
}
@@ -3445,6 +3470,34 @@ void VPlanTransforms::convertToConcreteRecipes(VPlan &Plan) {
34453470
ToRemove.push_back(Expr);
34463471
}
34473472

3473+
// Expand LastActiveLane into Not + FirstActiveLane + Sub.
3474+
auto *LastActiveL = dyn_cast<VPInstruction>(&R);
3475+
if (LastActiveL &&
3476+
LastActiveL->getOpcode() == VPInstruction::LastActiveLane) {
3477+
// Create Not(Mask) for all operands.
3478+
SmallVector<VPValue *, 2> NotMasks;
3479+
for (VPValue *Op : LastActiveL->operands()) {
3480+
VPValue *NotMask = Builder.createNot(Op, LastActiveL->getDebugLoc());
3481+
NotMasks.push_back(NotMask);
3482+
}
3483+
3484+
// Create FirstActiveLane on the inverted masks.
3485+
VPValue *FirstInactiveLane = Builder.createNaryOp(
3486+
VPInstruction::FirstActiveLane, NotMasks,
3487+
LastActiveL->getDebugLoc(), "first.inactive.lane");
3488+
3489+
// Subtract 1 to get the last active lane.
3490+
VPValue *One = Plan.getOrAddLiveIn(
3491+
ConstantInt::get(Type::getInt64Ty(Plan.getContext()), 1));
3492+
VPValue *LastLane = Builder.createNaryOp(
3493+
Instruction::Sub, {FirstInactiveLane, One},
3494+
LastActiveL->getDebugLoc(), "last.active.lane");
3495+
3496+
LastActiveL->replaceAllUsesWith(LastLane);
3497+
ToRemove.push_back(LastActiveL);
3498+
continue;
3499+
}
3500+
34483501
VPValue *VectorStep;
34493502
VPValue *ScalarStep;
34503503
if (!match(&R, m_VPInstruction<VPInstruction::WideIVStep>(

llvm/lib/Transforms/Vectorize/VPlanUnroll.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
352352
VPValue *Op1;
353353
if (match(&R, m_VPInstruction<VPInstruction::AnyOf>(m_VPValue(Op1))) ||
354354
match(&R, m_FirstActiveLane(m_VPValue(Op1))) ||
355+
match(&R, m_LastActiveLane(m_VPValue(Op1))) ||
355356
match(&R, m_VPInstruction<VPInstruction::ComputeAnyOfResult>(
356357
m_VPValue(), m_VPValue(), m_VPValue(Op1))) ||
357358
match(&R, m_VPInstruction<VPInstruction::ComputeReductionResult>(
@@ -364,17 +365,21 @@ void UnrollState::unrollBlock(VPBlockBase *VPB) {
364365
continue;
365366
}
366367
VPValue *Op0;
367-
if (match(&R, m_VPInstruction<VPInstruction::ExtractLane>(
368-
m_VPValue(Op0), m_VPValue(Op1)))) {
368+
if (match(&R, m_ExtractLane(m_VPValue(Op0), m_VPValue(Op1)))) {
369369
addUniformForAllParts(cast<VPInstruction>(&R));
370370
for (unsigned Part = 1; Part != UF; ++Part)
371371
R.addOperand(getValueForPart(Op1, Part));
372372
continue;
373373
}
374374
if (match(&R, m_ExtractLastElement(m_VPValue(Op0))) ||
375-
match(&R, m_VPInstruction<VPInstruction::ExtractPenultimateElement>(
376-
m_VPValue(Op0)))) {
375+
match(&R, m_ExtractPenultimateElement(m_VPValue(Op0)))) {
377376
addUniformForAllParts(cast<VPSingleDefRecipe>(&R));
377+
if (isa<VPFirstOrderRecurrencePHIRecipe>(Op0)) {
378+
assert(match(&R, m_ExtractLastElement(m_VPValue())) &&
379+
"can only extract last element of FOR");
380+
continue;
381+
}
382+
378383
if (Plan.hasScalarVFOnly()) {
379384
auto *I = cast<VPInstruction>(&R);
380385
// Extracting from end with VF = 1 implies retrieving the last or

0 commit comments

Comments
 (0)