Skip to content

Commit ca4f6c9

Browse files
committed
[LV] Vectorize conditional scalar assignments
Based on Michael Maitland's previous work: #121222. This PR uses the existing recurrence code instead of introducing a new pass just for conditional scalar assignment (CSA) auto-vectorization. I've also made the recipes more generic. I've enabled it by default to see the impact on tests; if there are regressions, we can put it behind a CLI option.
1 parent b6bbc4b commit ca4f6c9

19 files changed

+1958
-257
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ enum class RecurKind {
7070
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
7171
///< (x,y) is increasing loop induction, and both x and y
7272
///< are integer type, producing a UMax reduction.
73+
FindLast, ///< FindLast reduction with select(cmp(),x,y) where x and y
74+
///< are an integer type, one is the current recurrence value,
75+
///< and the other is an arbitrary value.
7376
// clang-format on
7477
// TODO: Any_of and FindLast reduction need not be restricted to integer type
7578
// only.
@@ -183,6 +186,12 @@ class RecurrenceDescriptor {
183186
PHINode *OrigPhi, Instruction *I,
184187
ScalarEvolution &SE);
185188

189+
/// Returns a struct describing whether the instruction is of the form
190+
/// Select(Cmp(A, B), X, Y)
191+
/// where one of (X, Y) is the Phi value and the other is an arbitrary value.
192+
LLVM_ABI static InstDesc isFindLastPattern(Instruction *I, PHINode *Phi,
193+
Loop *TheLoop);
194+
186195
/// Returns a struct describing if the instruction is a
187196
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
188197
LLVM_ABI static InstDesc isConditionalRdxPattern(Instruction *I);
@@ -305,6 +314,12 @@ class RecurrenceDescriptor {
305314
isFindLastIVRecurrenceKind(Kind);
306315
}
307316

317+
/// Returns true if the recurrence kind is of the form
318+
/// select(cmp(),x,y) where one of (x,y) is an arbitrary value.
319+
static bool isFindLastRecurrenceKind(RecurKind Kind) {
320+
return Kind == RecurKind::FindLast;
321+
}
322+
308323
/// Returns the type of the recurrence. This type can be narrower than the
309324
/// actual type of the Phi if the recurrence has been type-promoted.
310325
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5656
case RecurKind::FindFirstIVUMin:
5757
case RecurKind::FindLastIVSMax:
5858
case RecurKind::FindLastIVUMax:
59+
// TODO: Make type-agnostic.
60+
case RecurKind::FindLast:
5961
return true;
6062
}
6163
return false;
@@ -712,6 +714,27 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
712714
m_Value(NonRdxPhi)))))
713715
return InstDesc(false, I);
714716

717+
if (isFindLastRecurrenceKind(Kind)) {
718+
// Must be an integer or pointer scalar.
719+
Type *Type = OrigPhi->getType();
720+
if (!Type->isIntegerTy() && !Type->isPointerTy())
721+
return InstDesc(false, I);
722+
723+
// FIXME: Support more complex patterns, including multiple selects.
724+
// Phi or Select must be used only outside the loop,
725+
// except for each other.
726+
if (!all_of(I->users(), [OrigPhi, TheLoop](User *U) {
727+
if (U == OrigPhi)
728+
return true;
729+
if (auto *UI = dyn_cast<Instruction>(U))
730+
return !TheLoop->contains(UI);
731+
return false;
732+
}))
733+
return InstDesc(false, I);
734+
735+
return InstDesc(I, RecurKind::FindLast);
736+
}
737+
715738
// Returns either FindFirstIV/FindLastIV, if such a pattern is found, or
716739
// std::nullopt.
717740
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
@@ -920,7 +943,7 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
920943
Kind == RecurKind::Add || Kind == RecurKind::Mul ||
921944
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs)
922945
return isConditionalRdxPattern(I);
923-
if (isFindIVRecurrenceKind(Kind) && SE)
946+
if ((isFindIVRecurrenceKind(Kind) || isFindLastRecurrenceKind(Kind)) && SE)
924947
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
925948
[[fallthrough]];
926949
case Instruction::FCmp:
@@ -1118,7 +1141,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
11181141
<< "\n");
11191142
return true;
11201143
}
1121-
1144+
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC,
1145+
DT, SE)) {
1146+
LLVM_DEBUG(dbgs() << "Found a FindLast reduction PHI." << *Phi << "\n");
1147+
return true;
1148+
}
11221149
// Not a reduction of known type.
11231150
return false;
11241151
}
@@ -1248,6 +1275,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12481275
case RecurKind::FMaximumNum:
12491276
case RecurKind::FMinimumNum:
12501277
return Instruction::FCmp;
1278+
case RecurKind::FindLast:
1279+
return Instruction::Select;
12511280
default:
12521281
llvm_unreachable("Unknown recurrence operation");
12531282
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10321032
}
10331033
break;
10341034
}
1035+
case Intrinsic::experimental_vector_extract_last_active:
1036+
if (ST->isSVEAvailable()) {
1037+
auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]);
1038+
// This should turn into chained clastb instructions.
1039+
return LegalCost;
1040+
}
1041+
break;
10351042
default:
10361043
break;
10371044
}
@@ -5366,6 +5373,7 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
53665373
case RecurKind::FMax:
53675374
case RecurKind::FMulAdd:
53685375
case RecurKind::AnyOf:
5376+
case RecurKind::FindLast:
53695377
return true;
53705378
default:
53715379
return false;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4052,6 +4052,7 @@ static bool willGenerateVectors(VPlan &Plan, ElementCount VF,
40524052
case VPDef::VPWidenIntrinsicSC:
40534053
case VPDef::VPWidenSC:
40544054
case VPDef::VPWidenSelectSC:
4055+
case VPDef::VPWidenSelectVectorSC:
40554056
case VPDef::VPBlendSC:
40564057
case VPDef::VPFirstOrderRecurrencePHISC:
40574058
case VPDef::VPHistogramSC:
@@ -4559,6 +4560,12 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
45594560
any_of(Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis(),
45604561
IsaPred<VPReductionPHIRecipe>);
45614562

4563+
// FIXME: implement interleaving for FindLast transform correctly.
4564+
for (auto &[_, RdxDesc] : Legal->getReductionVars())
4565+
if (RecurrenceDescriptor::isFindLastRecurrenceKind(
4566+
RdxDesc.getRecurrenceKind()))
4567+
return 1;
4568+
45624569
// If we did not calculate the cost for VF (because the user selected the VF)
45634570
// then we calculate the cost of VF here.
45644571
if (LoopCost == 0) {
@@ -8472,6 +8479,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84728479
*Plan, Builder))
84738480
return nullptr;
84748481

8482+
// Create whole-vector selects for find-last recurrences.
8483+
VPlanTransforms::runPass(VPlanTransforms::convertFindLastRecurrences, *Plan,
8484+
RecipeBuilder, Legal);
8485+
84758486
if (useActiveLaneMask(Style)) {
84768487
// TODO: Move checks to VPlanTransforms::addActiveLaneMask once
84778488
// TailFoldingStyle is visible there.
@@ -8566,6 +8577,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
85668577

85678578
RecurKind Kind = PhiR->getRecurrenceKind();
85688579
assert(
8580+
!RecurrenceDescriptor::isFindLastRecurrenceKind(Kind) &&
85698581
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
85708582
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
85718583
"AnyOf and FindIV reductions are not allowed for in-loop reductions");
@@ -8774,6 +8786,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
87748786
FinalReductionResult =
87758787
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
87768788
{PhiR, Start, NewExitingVPV}, ExitDL);
8789+
} else if (RecurrenceDescriptor::isFindLastRecurrenceKind(
8790+
RdxDesc.getRecurrenceKind())) {
8791+
FinalReductionResult = Builder.createNaryOp(
8792+
VPInstruction::ExtractLastActive, {NewExitingVPV}, ExitDL);
87778793
} else {
87788794
VPIRFlags Flags =
87798795
RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind)
@@ -8869,7 +8885,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
88698885
RecurKind RK = RdxDesc.getRecurrenceKind();
88708886
if ((!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
88718887
!RecurrenceDescriptor::isFindIVRecurrenceKind(RK) &&
8872-
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))) {
8888+
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) &&
8889+
!RecurrenceDescriptor::isFindLastRecurrenceKind(RK))) {
88738890
VPBuilder PHBuilder(Plan->getVectorPreheader());
88748891
VPValue *Iden = Plan->getOrAddLiveIn(
88758892
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
@@ -9993,6 +10010,22 @@ bool LoopVectorizePass::processLoop(Loop *L) {
999310010
// Override IC if user provided an interleave count.
999410011
IC = UserIC > 0 ? UserIC : IC;
999510012

10013+
// FIXME: Enable interleaving for last_active reductions.
10014+
if (any_of(LVL.getReductionVars(), [&](auto &Reduction) -> bool {
10015+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
10016+
return RecurrenceDescriptor::isFindLastRecurrenceKind(
10017+
RdxDesc.getRecurrenceKind());
10018+
})) {
10019+
LLVM_DEBUG(dbgs() << "LV: Not interleaving without vectorization due "
10020+
<< "to conditional scalar assignments.\n");
10021+
IntDiagMsg = {
10022+
"ConditionalAssignmentPreventsScalarInterleaving",
10023+
"Unable to interleave without vectorization due to conditional "
10024+
"assignments"};
10025+
InterleaveLoop = false;
10026+
IC = 1;
10027+
}
10028+
999610029
// Emit diagnostic messages, if any.
999710030
const char *VAPassName = Hints.vectorizeAnalysisPassName();
999810031
if (!VectorizeLoop && !InterleaveLoop) {

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25103,6 +25103,7 @@ class HorizontalReduction {
2510325103
case RecurKind::FindFirstIVUMin:
2510425104
case RecurKind::FindLastIVSMax:
2510525105
case RecurKind::FindLastIVUMax:
25106+
case RecurKind::FindLast:
2510625107
case RecurKind::FMaxNum:
2510725108
case RecurKind::FMinNum:
2510825109
case RecurKind::FMaximumNum:
@@ -25244,6 +25245,7 @@ class HorizontalReduction {
2524425245
case RecurKind::FindFirstIVUMin:
2524525246
case RecurKind::FindLastIVSMax:
2524625247
case RecurKind::FindLastIVUMax:
25248+
case RecurKind::FindLast:
2524725249
case RecurKind::FMaxNum:
2524825250
case RecurKind::FMinNum:
2524925251
case RecurKind::FMaximumNum:
@@ -25350,6 +25352,7 @@ class HorizontalReduction {
2535025352
case RecurKind::FindFirstIVUMin:
2535125353
case RecurKind::FindLastIVSMax:
2535225354
case RecurKind::FindLastIVUMax:
25355+
case RecurKind::FindLast:
2535325356
case RecurKind::FMaxNum:
2535425357
case RecurKind::FMinNum:
2535525358
case RecurKind::FMaximumNum:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
544544
case VPRecipeBase::VPWidenIntrinsicSC:
545545
case VPRecipeBase::VPWidenSC:
546546
case VPRecipeBase::VPWidenSelectSC:
547+
case VPRecipeBase::VPWidenSelectVectorSC:
547548
case VPRecipeBase::VPBlendSC:
548549
case VPRecipeBase::VPPredInstPHISC:
549550
case VPRecipeBase::VPCanonicalIVPHISC:
@@ -1067,6 +1068,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10671068
/// Returns the value for vscale.
10681069
VScale,
10691070
OpsEnd = VScale,
1071+
/// Extracts the last active lane based on a predicate vector operand.
1072+
ExtractLastActive,
10701073
};
10711074

10721075
/// Returns true if this VPInstruction generates scalar values for all lanes.
@@ -1769,6 +1772,47 @@ struct LLVM_ABI_FOR_TEST VPWidenSelectRecipe : public VPRecipeWithIRFlags,
17691772
}
17701773
};
17711774

1775+
/// A recipe for selecting whole vector values.
1776+
struct VPWidenSelectVectorRecipe : public VPRecipeWithIRFlags {
1777+
VPWidenSelectVectorRecipe(ArrayRef<VPValue *> Operands)
1778+
: VPRecipeWithIRFlags(VPDef::VPWidenSelectVectorSC, Operands) {}
1779+
1780+
~VPWidenSelectVectorRecipe() override = default;
1781+
1782+
VPWidenSelectVectorRecipe *clone() override {
1783+
SmallVector<VPValue *, 3> Operands(operands());
1784+
return new VPWidenSelectVectorRecipe(Operands);
1785+
}
1786+
1787+
VP_CLASSOF_IMPL(VPDef::VPWidenSelectVectorSC)
1788+
1789+
/// Produce a widened version of the select instruction.
1790+
void execute(VPTransformState &State) override;
1791+
1792+
/// Return the cost of this VPWidenSelectVectorRecipe.
1793+
InstructionCost computeCost(ElementCount VF,
1794+
VPCostContext &Ctx) const override;
1795+
1796+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1797+
/// Print the recipe.
1798+
void print(raw_ostream &O, const Twine &Indent,
1799+
VPSlotTracker &SlotTracker) const override;
1800+
#endif
1801+
1802+
VPValue *getCond() const { return getOperand(0); }
1803+
1804+
bool isInvariantCond() const {
1805+
return getCond()->isDefinedOutsideLoopRegions();
1806+
}
1807+
1808+
/// Returns true if the recipe only uses the first lane of operand \p Op.
1809+
bool onlyFirstLaneUsed(const VPValue *Op) const override {
1810+
assert(is_contained(operands(), Op) &&
1811+
"Op must be an operand of the recipe");
1812+
return Op == getCond() && isInvariantCond();
1813+
}
1814+
};
1815+
17721816
/// A recipe for handling GEP instructions.
17731817
class LLVM_ABI_FOR_TEST VPWidenGEPRecipe : public VPRecipeWithIRFlags {
17741818
Type *SourceElementTy;

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
118118
return Type::getIntNTy(Ctx, 64);
119119
case VPInstruction::ExtractLastElement:
120120
case VPInstruction::ExtractLastLanePerPart:
121-
case VPInstruction::ExtractPenultimateElement: {
121+
case VPInstruction::ExtractPenultimateElement:
122+
case VPInstruction::ExtractLastActive: {
122123
Type *BaseTy = inferScalarType(R->getOperand(0));
123124
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
124125
return VecTy->getElementType();
@@ -311,7 +312,11 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
311312
})
312313
.Case<VPExpressionRecipe>([this](const auto *R) {
313314
return inferScalarType(R->getOperandOfResultType());
314-
});
315+
})
316+
.Case<VPWidenSelectVectorRecipe>(
317+
[this](const VPWidenSelectVectorRecipe *R) {
318+
return inferScalarType(R->getOperand(1));
319+
});
315320

316321
assert(ResultTy && "could not infer type for the given VPValue");
317322
CachedTypes[V] = ResultTy;

0 commit comments

Comments
 (0)