Skip to content

Commit 1ff7778

Browse files
committed
[LV] Vectorize conditional scalar assignments
Based on Michael Maitland's previous work: #121222 This PR uses the existing recurrences code instead of introducing a new pass just for CSA autovec. I've also made recipes that are more generic. I've enabled it by default to see the impact on tests; if there are regressions we can put it behind a cli option.
1 parent 683e2bf commit 1ff7778

19 files changed

+1958
-257
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ enum class RecurKind {
7070
FindLastIVUMax, ///< FindLast reduction with select(cmp(),x,y) where one of
7171
///< (x,y) is increasing loop induction, and both x and y
7272
///< are integer type, producing a UMax reduction.
73+
FindLast, ///< FindLast reduction with select(cmp(),x,y) where x and y
74+
///< are an integer type, one is the current recurrence value,
75+
///< and the other is an arbitrary value.
7376
// clang-format on
7477
// TODO: Any_of and FindLast reduction need not be restricted to integer type
7578
// only.
@@ -183,6 +186,12 @@ class RecurrenceDescriptor {
183186
PHINode *OrigPhi, Instruction *I,
184187
ScalarEvolution &SE);
185188

189+
/// Returns a struct describing whether the instruction is of the form
190+
/// Select(Cmp(A, B), X, Y)
191+
/// where one of (X, Y) is the Phi value and the other is an arbitrary value.
192+
LLVM_ABI static InstDesc isFindLastPattern(Instruction *I, PHINode *Phi,
193+
Loop *TheLoop);
194+
186195
/// Returns a struct describing if the instruction is a
187196
/// Select(FCmp(X, Y), (Z = X op PHINode), PHINode) instruction pattern.
188197
LLVM_ABI static InstDesc isConditionalRdxPattern(Instruction *I);
@@ -305,6 +314,12 @@ class RecurrenceDescriptor {
305314
isFindLastIVRecurrenceKind(Kind);
306315
}
307316

317+
/// Returns true if the recurrence kind is of the form
318+
/// select(cmp(),x,y) where one of (x,y) is an arbitrary value.
319+
static bool isFindLastRecurrenceKind(RecurKind Kind) {
320+
return Kind == RecurKind::FindLast;
321+
}
322+
308323
/// Returns the type of the recurrence. This type can be narrower than the
309324
/// actual type of the Phi if the recurrence has been type-promoted.
310325
Type *getRecurrenceType() const { return RecurrenceType; }

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurKind Kind) {
5656
case RecurKind::FindFirstIVUMin:
5757
case RecurKind::FindLastIVSMax:
5858
case RecurKind::FindLastIVUMax:
59+
// TODO: Make type-agnostic.
60+
case RecurKind::FindLast:
5961
return true;
6062
}
6163
return false;
@@ -712,6 +714,27 @@ RecurrenceDescriptor::isFindIVPattern(RecurKind Kind, Loop *TheLoop,
712714
m_Value(NonRdxPhi)))))
713715
return InstDesc(false, I);
714716

717+
if (isFindLastRecurrenceKind(Kind)) {
718+
// Must be an integer scalar.
719+
Type *Type = OrigPhi->getType();
720+
if (!Type->isIntegerTy() && !Type->isPointerTy())
721+
return InstDesc(false, I);
722+
723+
// FIXME: Support more complex patterns, including multiple selects.
724+
// Phi or Select must be used only outside the loop,
725+
// except for each other.
726+
if (!all_of(I->users(), [OrigPhi, TheLoop](User *U) {
727+
if (U == OrigPhi)
728+
return true;
729+
if (auto *UI = dyn_cast<Instruction>(U))
730+
return !TheLoop->contains(UI);
731+
return false;
732+
}))
733+
return InstDesc(false, I);
734+
735+
return InstDesc(I, RecurKind::FindLast);
736+
}
737+
715738
// Returns either FindFirstIV/FindLastIV, if such a pattern is found, or
716739
// std::nullopt.
717740
auto GetRecurKind = [&](Value *V) -> std::optional<RecurKind> {
@@ -920,7 +943,7 @@ RecurrenceDescriptor::InstDesc RecurrenceDescriptor::isRecurrenceInstr(
920943
Kind == RecurKind::Add || Kind == RecurKind::Mul ||
921944
Kind == RecurKind::Sub || Kind == RecurKind::AddChainWithSubs)
922945
return isConditionalRdxPattern(I);
923-
if (isFindIVRecurrenceKind(Kind) && SE)
946+
if ((isFindIVRecurrenceKind(Kind) || isFindLastRecurrenceKind(Kind)) && SE)
924947
return isFindIVPattern(Kind, L, OrigPhi, I, *SE);
925948
[[fallthrough]];
926949
case Instruction::FCmp:
@@ -1118,7 +1141,11 @@ bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop,
11181141
<< "\n");
11191142
return true;
11201143
}
1121-
1144+
if (AddReductionVar(Phi, RecurKind::FindLast, TheLoop, FMF, RedDes, DB, AC,
1145+
DT, SE)) {
1146+
LLVM_DEBUG(dbgs() << "Found a FindLast reduction PHI." << *Phi << "\n");
1147+
return true;
1148+
}
11221149
// Not a reduction of known type.
11231150
return false;
11241151
}
@@ -1248,6 +1275,8 @@ unsigned RecurrenceDescriptor::getOpcode(RecurKind Kind) {
12481275
case RecurKind::FMaximumNum:
12491276
case RecurKind::FMinimumNum:
12501277
return Instruction::FCmp;
1278+
case RecurKind::FindLast:
1279+
return Instruction::Select;
12511280
default:
12521281
llvm_unreachable("Unknown recurrence operation");
12531282
}

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,13 @@ AArch64TTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10041004
}
10051005
break;
10061006
}
1007+
case Intrinsic::experimental_vector_extract_last_active:
1008+
if (ST->isSVEAvailable()) {
1009+
auto [LegalCost, _] = getTypeLegalizationCost(ICA.getArgTypes()[0]);
1010+
// This should turn into chained clastb instructions.
1011+
return LegalCost;
1012+
}
1013+
break;
10071014
default:
10081015
break;
10091016
}
@@ -5330,6 +5337,7 @@ bool AArch64TTIImpl::isLegalToVectorizeReduction(
53305337
case RecurKind::FMax:
53315338
case RecurKind::FMulAdd:
53325339
case RecurKind::AnyOf:
5340+
case RecurKind::FindLast:
53335341
return true;
53345342
default:
53355343
return false;

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4052,6 +4052,7 @@ static bool willGenerateVectors(VPlan &Plan, ElementCount VF,
40524052
case VPDef::VPWidenIntrinsicSC:
40534053
case VPDef::VPWidenSC:
40544054
case VPDef::VPWidenSelectSC:
4055+
case VPDef::VPWidenSelectVectorSC:
40554056
case VPDef::VPBlendSC:
40564057
case VPDef::VPFirstOrderRecurrencePHISC:
40574058
case VPDef::VPHistogramSC:
@@ -4546,6 +4547,12 @@ LoopVectorizationPlanner::selectInterleaveCount(VPlan &Plan, ElementCount VF,
45464547
any_of(Plan.getVectorLoopRegion()->getEntryBasicBlock()->phis(),
45474548
IsaPred<VPReductionPHIRecipe>);
45484549

4550+
// FIXME: implement interleaving for FindLast transform correctly.
4551+
for (auto &[_, RdxDesc] : Legal->getReductionVars())
4552+
if (RecurrenceDescriptor::isFindLastRecurrenceKind(
4553+
RdxDesc.getRecurrenceKind()))
4554+
return 1;
4555+
45494556
// If we did not calculate the cost for VF (because the user selected the VF)
45504557
// then we calculate the cost of VF here.
45514558
if (LoopCost == 0) {
@@ -8459,6 +8466,10 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84598466
*Plan, Builder))
84608467
return nullptr;
84618468

8469+
// Create whole-vector selects for find-last recurrences.
8470+
VPlanTransforms::runPass(VPlanTransforms::convertFindLastRecurrences, *Plan,
8471+
RecipeBuilder, Legal);
8472+
84628473
if (useActiveLaneMask(Style)) {
84638474
// TODO: Move checks to VPlanTransforms::addActiveLaneMask once
84648475
// TailFoldingStyle is visible there.
@@ -8553,6 +8564,7 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
85538564

85548565
RecurKind Kind = PhiR->getRecurrenceKind();
85558566
assert(
8567+
!RecurrenceDescriptor::isFindLastRecurrenceKind(Kind) &&
85568568
!RecurrenceDescriptor::isAnyOfRecurrenceKind(Kind) &&
85578569
!RecurrenceDescriptor::isFindIVRecurrenceKind(Kind) &&
85588570
"AnyOf and FindIV reductions are not allowed for in-loop reductions");
@@ -8761,6 +8773,10 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
87618773
FinalReductionResult =
87628774
Builder.createNaryOp(VPInstruction::ComputeAnyOfResult,
87638775
{PhiR, Start, NewExitingVPV}, ExitDL);
8776+
} else if (RecurrenceDescriptor::isFindLastRecurrenceKind(
8777+
RdxDesc.getRecurrenceKind())) {
8778+
FinalReductionResult = Builder.createNaryOp(
8779+
VPInstruction::ExtractLastActive, {NewExitingVPV}, ExitDL);
87648780
} else {
87658781
VPIRFlags Flags =
87668782
RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind)
@@ -8856,7 +8872,8 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
88568872
RecurKind RK = RdxDesc.getRecurrenceKind();
88578873
if ((!RecurrenceDescriptor::isAnyOfRecurrenceKind(RK) &&
88588874
!RecurrenceDescriptor::isFindIVRecurrenceKind(RK) &&
8859-
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK))) {
8875+
!RecurrenceDescriptor::isMinMaxRecurrenceKind(RK) &&
8876+
!RecurrenceDescriptor::isFindLastRecurrenceKind(RK))) {
88608877
VPBuilder PHBuilder(Plan->getVectorPreheader());
88618878
VPValue *Iden = Plan->getOrAddLiveIn(
88628879
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
@@ -9970,6 +9987,22 @@ bool LoopVectorizePass::processLoop(Loop *L) {
99709987
// Override IC if user provided an interleave count.
99719988
IC = UserIC > 0 ? UserIC : IC;
99729989

9990+
// FIXME: Enable interleaving for last_active reductions.
9991+
if (any_of(LVL.getReductionVars(), [&](auto &Reduction) -> bool {
9992+
const RecurrenceDescriptor &RdxDesc = Reduction.second;
9993+
return RecurrenceDescriptor::isFindLastRecurrenceKind(
9994+
RdxDesc.getRecurrenceKind());
9995+
})) {
9996+
LLVM_DEBUG(dbgs() << "LV: Not interleaving without vectorization due "
9997+
<< "to conditional scalar assignments.\n");
9998+
IntDiagMsg = {
9999+
"ConditionalAssignmentPreventsScalarInterleaving",
10000+
"Unable to interleave without vectorization due to conditional "
10001+
"assignments"};
10002+
InterleaveLoop = false;
10003+
IC = 1;
10004+
}
10005+
997310006
// Emit diagnostic messages, if any.
997410007
const char *VAPassName = Hints.vectorizeAnalysisPassName();
997510008
if (!VectorizeLoop && !InterleaveLoop) {

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25087,6 +25087,7 @@ class HorizontalReduction {
2508725087
case RecurKind::FindFirstIVUMin:
2508825088
case RecurKind::FindLastIVSMax:
2508925089
case RecurKind::FindLastIVUMax:
25090+
case RecurKind::FindLast:
2509025091
case RecurKind::FMaxNum:
2509125092
case RecurKind::FMinNum:
2509225093
case RecurKind::FMaximumNum:
@@ -25228,6 +25229,7 @@ class HorizontalReduction {
2522825229
case RecurKind::FindFirstIVUMin:
2522925230
case RecurKind::FindLastIVSMax:
2523025231
case RecurKind::FindLastIVUMax:
25232+
case RecurKind::FindLast:
2523125233
case RecurKind::FMaxNum:
2523225234
case RecurKind::FMinNum:
2523325235
case RecurKind::FMaximumNum:
@@ -25334,6 +25336,7 @@ class HorizontalReduction {
2533425336
case RecurKind::FindFirstIVUMin:
2533525337
case RecurKind::FindLastIVSMax:
2533625338
case RecurKind::FindLastIVUMax:
25339+
case RecurKind::FindLast:
2533725340
case RecurKind::FMaxNum:
2533825341
case RecurKind::FMinNum:
2533925342
case RecurKind::FMaximumNum:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ class VPSingleDefRecipe : public VPRecipeBase, public VPValue {
544544
case VPRecipeBase::VPWidenIntrinsicSC:
545545
case VPRecipeBase::VPWidenSC:
546546
case VPRecipeBase::VPWidenSelectSC:
547+
case VPRecipeBase::VPWidenSelectVectorSC:
547548
case VPRecipeBase::VPBlendSC:
548549
case VPRecipeBase::VPPredInstPHISC:
549550
case VPRecipeBase::VPCanonicalIVPHISC:
@@ -1067,6 +1068,8 @@ class LLVM_ABI_FOR_TEST VPInstruction : public VPRecipeWithIRFlags,
10671068
/// Returns the value for vscale.
10681069
VScale,
10691070
OpsEnd = VScale,
1071+
/// Extracts the last active lane based on a predicate vector operand.
1072+
ExtractLastActive,
10701073
};
10711074

10721075
/// Returns true if this VPInstruction generates scalar values for all lanes.
@@ -1769,6 +1772,47 @@ struct LLVM_ABI_FOR_TEST VPWidenSelectRecipe : public VPRecipeWithIRFlags,
17691772
}
17701773
};
17711774

1775+
/// A recipe for selecting whole vector values.
1776+
struct VPWidenSelectVectorRecipe : public VPRecipeWithIRFlags {
1777+
VPWidenSelectVectorRecipe(ArrayRef<VPValue *> Operands)
1778+
: VPRecipeWithIRFlags(VPDef::VPWidenSelectVectorSC, Operands) {}
1779+
1780+
~VPWidenSelectVectorRecipe() override = default;
1781+
1782+
VPWidenSelectVectorRecipe *clone() override {
1783+
SmallVector<VPValue *, 3> Operands(operands());
1784+
return new VPWidenSelectVectorRecipe(Operands);
1785+
}
1786+
1787+
VP_CLASSOF_IMPL(VPDef::VPWidenSelectVectorSC)
1788+
1789+
/// Produce a widened version of the select instruction.
1790+
void execute(VPTransformState &State) override;
1791+
1792+
/// Return the cost of this VPWidenSelectVectorRecipe.
1793+
InstructionCost computeCost(ElementCount VF,
1794+
VPCostContext &Ctx) const override;
1795+
1796+
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
1797+
/// Print the recipe.
1798+
void print(raw_ostream &O, const Twine &Indent,
1799+
VPSlotTracker &SlotTracker) const override;
1800+
#endif
1801+
1802+
VPValue *getCond() const { return getOperand(0); }
1803+
1804+
bool isInvariantCond() const {
1805+
return getCond()->isDefinedOutsideLoopRegions();
1806+
}
1807+
1808+
/// Returns true if the recipe only uses the first lane of operand \p Op.
1809+
bool onlyFirstLaneUsed(const VPValue *Op) const override {
1810+
assert(is_contained(operands(), Op) &&
1811+
"Op must be an operand of the recipe");
1812+
return Op == getCond() && isInvariantCond();
1813+
}
1814+
};
1815+
17721816
/// A recipe for handling GEP instructions.
17731817
class LLVM_ABI_FOR_TEST VPWidenGEPRecipe : public VPRecipeWithIRFlags {
17741818
Type *SourceElementTy;

llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ Type *VPTypeAnalysis::inferScalarTypeForRecipe(const VPInstruction *R) {
118118
return Type::getIntNTy(Ctx, 64);
119119
case VPInstruction::ExtractLastElement:
120120
case VPInstruction::ExtractLastLanePerPart:
121-
case VPInstruction::ExtractPenultimateElement: {
121+
case VPInstruction::ExtractPenultimateElement:
122+
case VPInstruction::ExtractLastActive: {
122123
Type *BaseTy = inferScalarType(R->getOperand(0));
123124
if (auto *VecTy = dyn_cast<VectorType>(BaseTy))
124125
return VecTy->getElementType();
@@ -311,7 +312,11 @@ Type *VPTypeAnalysis::inferScalarType(const VPValue *V) {
311312
})
312313
.Case<VPExpressionRecipe>([this](const auto *R) {
313314
return inferScalarType(R->getOperandOfResultType());
314-
});
315+
})
316+
.Case<VPWidenSelectVectorRecipe>(
317+
[this](const VPWidenSelectVectorRecipe *R) {
318+
return inferScalarType(R->getOperand(1));
319+
});
315320

316321
assert(ResultTy && "could not infer type for the given VPValue");
317322
CachedTypes[V] = ResultTy;

0 commit comments

Comments
 (0)