Skip to content

Commit 95c7af6

Browse files
committed
[LV] Vectorize conditional scalar assignments
Based on Michael Maitland's previous work (#121222). This PR uses the existing recurrence-detection code instead of introducing a new pass just for conditional scalar assignment (CSA) auto-vectorization. I've also made the recipes more generic. The transform is enabled by default to see the impact on tests; if there are regressions, we can put it behind a command-line option.
1 parent e1aa2df commit 95c7af6

19 files changed

+2027
-244
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ enum class RecurKind {
7070
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
7171
///< (x,y) is increasing loop induction, and both x and y
7272
///< are integer type, producing a UMax reduction.
73+
FindLast, ///< FindLast reduction with select(cmp(),x,y) where x and y
74+
///< can be any scalar type, one is the current recurrence
75+
///< value, and the other is an arbitrary value.
7376
// clang-format on
7477
// TODO: Any_of and FindLast reduction need not be restricted to integer type
7578
// only.
@@ -183,6 +186,12 @@ class RecurrenceDescriptor {
183186
PHINode *OrigPhi, Instruction *I,
184187
ScalarEvolution &SE);
185188

189+
/// Returns a struct describing whether the instruction is of the form
190+
/// Select(Cmp(A, B), X, Y)
191+
/// where one of (X, Y) is the Phi value and the other is an arbitrary value.
192+
LLVM_ABI static InstDesc isFindLastPattern(Instruction *I, PHINode *Phi,
193+
Loop *TheLoop);
194+
186195
/// Returns a struct describing if the instruction is a
187196
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
188197
LLVM_ABI static InstDesc isConditionalRdxPattern(Instruction *I);
@@ -299,6 +308,12 @@ class RecurrenceDescriptor {
299308
isFindLastIVRecurrenceKind(Kind);
300309
}
301310

311+
/// Returns true if the recurrence kind is of the form
312+
/// select(cmp(),x,y) where one of (x,y) is an arbitrary value.
313+
static bool isFindLastRecurrenceKind(RecurKind Kind) {
314+
return Kind == RecurKind::FindLast;
315+
}
316+
302317
/// Returns the type of the recurrence. This type can be narrower than the
303318
/// actual type of the Phi if the recurrence has been type-promoted.
304319
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5656
case RecurKind::FindFirstIVUMin:
5757
case RecurKind::FindLastIVSMax:
5858
case RecurKind::FindLastIVUMax:
59+
// TODO: Make type-agnostic.
60+
case RecurKind::FindLast:
5961
return true;
6062
}
6163
return false;
@@ -426,6 +428,8 @@ bool RecurrenceDescriptor::AddReductionVar(
426428
++NumCmpSelectPatternInst;
427429
if (isAnyOfRecurrenceKind(Kind) && IsASelect)
428430
++NumCmpSelectPatternInst;
431+
if (isFindLastRecurrenceKind(Kind) && IsASelect)
432+
++NumCmpSelectPatternInst;
429433

430434
// Check whether we found a reduction operator.
431435
FoundReduxOp |= !IsAPhi && Cur != Start;
@@ -789,6 +793,38 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
789793
return InstDesc(false, I);
790794
}
791795

796+
RecurrenceDescriptor::InstDesc
797+
RecurrenceDescriptor::isFindLastPattern(Instruction *I, PHINode *Phi,
798+
Loop *TheLoop) {
799+
// Must be a scalar.
800+
Type *Type = Phi->getType();
801+
if (!Type->isIntegerTy() && !Type->isFloatingPointTy() &&
802+
!Type->isPointerTy())
803+
return InstDesc(false, I);
804+
805+
SelectInst *Select = dyn_cast<SelectInst>(I);
806+
if (!Select)
807+
return InstDesc(false, I);
808+
809+
// FIXME: Support more complex patterns, including multiple selects.
810+
// Both Phi and Select must be used only outside the loop,
811+
// except for each other.
812+
auto IsOnlyUsedOutsideLoop = [&](Value *V, Value *Ignore) {
813+
return all_of(V->users(), [Ignore, TheLoop](User *U) {
814+
if (U == Ignore)
815+
return true;
816+
if (auto *I = dyn_cast<Instruction>(U))
817+
return !TheLoop->contains(I);
818+
return false;
819+
});
820+
};
821+
if (!IsOnlyUsedOutsideLoop(Phi, Select) ||
822+
!IsOnlyUsedOutsideLoop(Select, Phi))
823+
return InstDesc(false, I);
824+
825+
return InstDesc(I, RecurKind::FindLast);
826+
}
827+
792828
RecurrenceDescriptor::InstDesc
793829
RecurrenceDescriptor::isMinMaxPattern(Instruction *I, RecurKind Kind,
794830
const InstDesc &Prev) {
@@ -927,6 +963,8 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
927963
return isConditionalRdxPattern(I);
928964
if (isFindIVRecurrenceKind(Kind) && SE)
929965
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
966+
if (isFindLastRecurrenceKind(Kind))
967+
return isFindLastPattern(I, OrigPhi, L);
930968
[[fallthrough]];
931969
case Instruction::FCmp:
932970
case Instruction::ICmp:
@@ -1123,7 +1161,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
11231161
<< "\n");
11241162
return true;
11251163
}
1126-
1164+
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC, DT,
1165+
SE)) {
1166+
LLVM_DEBUG(dbgs() << "Found a FindLast reduction PHI." << *Phi << "\n");
1167+
return true;
1168+
}
11271169
// Not a reduction of known type.
11281170
return false;
11291171
}
@@ -1245,6 +1287,7 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12451287
case RecurKind::SMin:
12461288
case RecurKind::UMax:
12471289
case RecurKind::UMin:
1290+
case RecurKind::FindLast:
12481291
return Instruction::ICmp;
12491292
case RecurKind::FMax:
12501293
case RecurKind::FMin:

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10041004
}
10051005
break;
10061006
}
1007+
case Intrinsic::experimental_vector_extract_last_active:
1008+
if (ST->isSVEAvailable()) {
1009+
auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]);
1010+
// This should turn into chained clastb instructions.
1011+
return LegalCost;
1012+
}
1013+
break;
10071014
default:
10081015
break;
10091016
}
@@ -5325,6 +5332,7 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
53255332
case RecurKind::FMax:
53265333
case RecurKind::FMulAdd:
53275334
case RecurKind::AnyOf:
5335+
case RecurKind::FindLast:
53285336
return true;
53295337
default:
53305338
return false;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4047,6 +4047,7 @@ static bool willGenerateVectors(VPlan &Plan, ElementCount VF,
40474047
case VPDef::VPWidenIntrinsicSC:
40484048
case VPDef::VPWidenSC:
40494049
case VPDef::VPWidenSelectSC:
4050+
case VPDef::VPWidenSelectVectorSC:
40504051
case VPDef::VPBlendSC:
40514052
case VPDef::VPFirstOrderRecurrencePHISC:
40524053
case VPDef::VPHistogramSC:
@@ -4546,6 +4547,11 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
45464547
any_of(Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis(),
45474548
IsaPred<VPReductionPHIRecipe>);
45484549

4550+
// FIXME: implement interleaving for FindLast transform correctly.
4551+
for (auto &[_, RdxDesc] : Legal->getReductionVars())
4552+
if (RecurrenceDescriptor::isFindLastRecurrenceKind(RdxDesc.getRecurrenceKind()))
4553+
return 1;
4554+
45494555
// If we did not calculate the cost for VF (because the user selected the VF)
45504556
// then we calculate the cost of VF here.
45514557
if (LoopCost == 0) {
@@ -8687,6 +8693,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
86878693
*Plan, Builder))
86888694
return nullptr;
86898695

8696+
// Create whole-vector selects for find-last recurrences.
8697+
VPlanTransforms::runPass(VPlanTransforms::convertFindLastRecurrences,
8698+
*Plan, RecipeBuilder, Legal);
8699+
86908700
if (useActiveLaneMask(Style)) {
86918701
// TODO: Move checks to VPlanTransforms::addActiveLaneMask once
86928702
// TailFoldingStyle is visible there.
@@ -8779,6 +8789,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
87798789

87808790
RecurKind Kind = PhiR->getRecurrenceKind();
87818791
assert(
8792+
!RecurrenceDescriptor::isFindLastRecurrenceKind(Kind) &&
87828793
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
87838794
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
87848795
"AnyOf and FindIV reductions are not allowed for in-loop reductions");
@@ -8987,6 +8998,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
89878998
FinalReductionResult =
89888999
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
89899000
{PhiR, Start, NewExitingVPV}, ExitDL);
9001+
} else if (RecurrenceDescriptor::isFindLastRecurrenceKind(
9002+
RdxDesc.getRecurrenceKind())) {
9003+
FinalReductionResult = Builder.createNaryOp(
9004+
VPInstruction::ExtractLastActive, {NewExitingVPV}, ExitDL);
89909005
} else {
89919006
VPIRFlags Flags =
89929007
RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind)
@@ -9076,7 +9091,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
90769091
RecurKind RK = RdxDesc.getRecurrenceKind();
90779092
if ((!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
90789093
!RecurrenceDescriptor::isFindIVRecurrenceKind(RK) &&
9079-
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))) {
9094+
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) &&
9095+
!RecurrenceDescriptor::isFindLastRecurrenceKind(RK))) {
90809096
VPBuilder PHBuilder(Plan->getVectorPreheader());
90819097
VPValue *Iden = Plan->getOrAddLiveIn(
90829098
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
@@ -10069,6 +10085,21 @@ bool LoopVectorizePass::processLoop(Loop *L) {
1006910085
// Override IC if user provided an interleave count.
1007010086
IC = UserIC > 0 ? UserIC : IC;
1007110087

10088+
// FIXME: Enable interleaving for last_active reductions.
10089+
if (any_of(LVL.getReductionVars(), [&](auto &Reduction) -> bool {
10090+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
10091+
return RecurrenceDescriptor::isFindLastRecurrenceKind(RdxDesc.getRecurrenceKind());
10092+
})) {
10093+
LLVM_DEBUG(dbgs() << "LV: Not interleaving without vectorization due "
10094+
<< "to conditional scalar assignments.\n");
10095+
IntDiagMsg = {
10096+
"ConditionalAssignmentPreventsScalarInterleaving",
10097+
"Unable to interleave without vectorization due to conditional "
10098+
"assignments"};
10099+
InterleaveLoop = false;
10100+
IC = 1;
10101+
}
10102+
1007210103
// Emit diagnostic messages, if any.
1007310104
const char *VAPassName = Hints.vectorizeAnalysisPassName();
1007410105
if (!VectorizeLoop && !InterleaveLoop) {

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24868,6 +24868,7 @@ class HorizontalReduction {
2486824868
case RecurKind::FindFirstIVUMin:
2486924869
case RecurKind::FindLastIVSMax:
2487024870
case RecurKind::FindLastIVUMax:
24871+
case RecurKind::FindLast:
2487124872
case RecurKind::FMaxNum:
2487224873
case RecurKind::FMinNum:
2487324874
case RecurKind::FMaximumNum:
@@ -25009,6 +25010,7 @@ class HorizontalReduction {
2500925010
case RecurKind::FindFirstIVUMin:
2501025011
case RecurKind::FindLastIVSMax:
2501125012
case RecurKind::FindLastIVUMax:
25013+
case RecurKind::FindLast:
2501225014
case RecurKind::FMaxNum:
2501325015
case RecurKind::FMinNum:
2501425016
case RecurKind::FMaximumNum:
@@ -25115,6 +25117,7 @@ class HorizontalReduction {
2511525117
case RecurKind::FindFirstIVUMin:
2511625118
case RecurKind::FindLastIVSMax:
2511725119
case RecurKind::FindLastIVUMax:
25120+
case RecurKind::FindLast:
2511825121
case RecurKind::FMaxNum:
2511925122
case RecurKind::FMinNum:
2512025123
case RecurKind::FMaximumNum:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
548548
case VPRecipeBase::VPWidenIntrinsicSC:
549549
case VPRecipeBase::VPWidenSC:
550550
case VPRecipeBase::VPWidenSelectSC:
551+
case VPRecipeBase::VPWidenSelectVectorSC:
551552
case VPRecipeBase::VPBlendSC:
552553
case VPRecipeBase::VPPredInstPHISC:
553554
case VPRecipeBase::VPCanonicalIVPHISC:
@@ -1059,6 +1060,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10591060
ResumeForEpilogue,
10601061
/// Returns the value for vscale.
10611062
VScale,
1063+
/// Extracts the last active lane based on a predicate vector operand.
1064+
ExtractLastActive,
10621065
};
10631066

10641067
private:
@@ -1749,6 +1752,47 @@ struct LLVM_ABI_FOR_TEST VPWidenSelectRecipe : public VPRecipeWithIRFlags,
17491752

17501753
unsigned getOpcode() const { return Instruction::Select; }
17511754

1755+
VPValue *getCond() const { return getOperand(0); }
1756+
1757+
bool isInvariantCond() const {
1758+
return getCond()->isDefinedOutsideLoopRegions();
1759+
}
1760+
1761+
/// Returns true if the recipe only uses the first lane of operand \p Op.
1762+
bool onlyFirstLaneUsed(const VPValue *Op) const override {
1763+
assert(is_contained(operands(), Op) &&
1764+
"Op must be an operand of the recipe");
1765+
return Op == getCond() && isInvariantCond();
1766+
}
1767+
};
1768+
1769+
/// A recipe for selecting whole vector values.
1770+
struct VPWidenSelectVectorRecipe : public VPRecipeWithIRFlags {
1771+
VPWidenSelectVectorRecipe(ArrayRef<VPValue *> Operands)
1772+
: VPRecipeWithIRFlags(VPDef::VPWidenSelectVectorSC, Operands) {}
1773+
1774+
~VPWidenSelectVectorRecipe() override = default;
1775+
1776+
VPWidenSelectVectorRecipe *clone() override {
1777+
SmallVector<VPValue *, 3> Operands(operands());
1778+
return new VPWidenSelectVectorRecipe(Operands);
1779+
}
1780+
1781+
VP_CLASSOF_IMPL(VPDef::VPWidenSelectVectorSC)
1782+
1783+
/// Produce a widened version of the select instruction.
1784+
void execute(VPTransformState &State) override;
1785+
1786+
/// Return the cost of this VPWidenSelectVectorRecipe.
1787+
InstructionCost computeCost(ElementCount VF,
1788+
VPCostContext &Ctx) const override;
1789+
1790+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1791+
/// Print the recipe.
1792+
void print(raw_ostream &O, const Twine &Indent,
1793+
VPSlotTracker &SlotTracker) const override;
1794+
#endif
1795+
17521796
VPValue *getCond() const {
17531797
return getOperand(0);
17541798
}

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
115115
case VPInstruction::FirstActiveLane:
116116
return Type::getIntNTy(Ctx, 64);
117117
case VPInstruction::ExtractLastElement:
118-
case VPInstruction::ExtractPenultimateElement: {
118+
case VPInstruction::ExtractPenultimateElement:
119+
case VPInstruction::ExtractLastActive: {
119120
Type *BaseTy = inferScalarType(R->getOperand(0));
120121
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
121122
return VecTy->getElementType();
@@ -308,7 +309,11 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
308309
})
309310
.Case<VPExpressionRecipe>([this](const auto *R) {
310311
return inferScalarType(R->getOperandOfResultType());
311-
});
312+
})
313+
.Case<VPWidenSelectVectorRecipe>(
314+
[this](const VPWidenSelectVectorRecipe *R) {
315+
return inferScalarType(R->getOperand(1));
316+
});
312317

313318
assert(ResultTy && "could not infer type for the given VPValue");
314319
CachedTypes[V] = ResultTy;

0 commit comments

Comments
 (0)