@@ -7531,6 +7531,10 @@ static bool planContainsAdditionalSimplifications(VPlan &Plan,
75317531 }
75327532 continue ;
75337533 }
7534+ // The VPlan-based cost model is more accurate for partial reduction and
7535+ // comparing against the legacy cost isn't desirable.
7536+ if (isa<VPPartialReductionRecipe>(&R))
7537+ return true ;
75347538 if (Instruction *UI = GetInstructionForCost (&R))
75357539 SeenInstrs.insert (UI);
75367540 }
@@ -8751,6 +8755,105 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(Instruction *I,
87518755 return Recipe;
87528756}
87538757
8758+ // / Find all possible partial reductions in the loop and track all of those that
8759+ // / are valid so recipes can be formed later.
8760+ void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8761+ // Find all possible partial reductions.
8762+ SmallVector<std::pair<PartialReductionChain, unsigned >, 1 >
8763+ PartialReductionChains;
8764+ for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ())
8765+ if (std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8766+ getScaledReduction (Phi, RdxDesc, Range))
8767+ PartialReductionChains.push_back (*Pair);
8768+
8769+ // A partial reduction is invalid if any of its extends are used by
8770+ // something that isn't another partial reduction. This is because the
8771+ // extends are intended to be lowered along with the reduction itself.
8772+
8773+ // Build up a set of partial reduction bin ops for efficient use checking.
8774+ SmallSet<User *, 4 > PartialReductionBinOps;
8775+ for (const auto &[PartialRdx, _] : PartialReductionChains)
8776+ PartialReductionBinOps.insert (PartialRdx.BinOp );
8777+
8778+ auto ExtendIsOnlyUsedByPartialReductions =
8779+ [&PartialReductionBinOps](Instruction *Extend) {
8780+ return all_of (Extend->users (), [&](const User *U) {
8781+ return PartialReductionBinOps.contains (U);
8782+ });
8783+ };
8784+
8785+ // Check if each use of a chain's two extends is a partial reduction
8786+ // and only add those that don't have non-partial reduction users.
8787+ for (auto Pair : PartialReductionChains) {
8788+ PartialReductionChain Chain = Pair.first ;
8789+ if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8790+ ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB ))
8791+ ScaledReductionExitInstrs.insert (std::make_pair (Chain.Reduction , Pair));
8792+ }
8793+ }
8794+
8795+ std::optional<std::pair<PartialReductionChain, unsigned >>
8796+ VPRecipeBuilder::getScaledReduction (PHINode *PHI,
8797+ const RecurrenceDescriptor &Rdx,
8798+ VFRange &Range) {
8799+ // TODO: Allow scaling reductions when predicating. The select at
8800+ // the end of the loop chooses between the phi value and most recent
8801+ // reduction result, both of which have different VFs to the active lane
8802+ // mask when scaling.
8803+ if (CM.blockNeedsPredicationForAnyReason (Rdx.getLoopExitInstr ()->getParent ()))
8804+ return std::nullopt ;
8805+
8806+ auto *Update = dyn_cast<BinaryOperator>(Rdx.getLoopExitInstr ());
8807+ if (!Update)
8808+ return std::nullopt ;
8809+
8810+ Value *Op = Update->getOperand (0 );
8811+ Value *PhiOp = Update->getOperand (1 );
8812+ if (Op == PHI) {
8813+ Op = Update->getOperand (1 );
8814+ PhiOp = Update->getOperand (0 );
8815+ }
8816+ if (PhiOp != PHI)
8817+ return std::nullopt ;
8818+
8819+ auto *BinOp = dyn_cast<BinaryOperator>(Op);
8820+ if (!BinOp || !BinOp->hasOneUse ())
8821+ return std::nullopt ;
8822+
8823+ using namespace llvm ::PatternMatch;
8824+ Value *A, *B;
8825+ if (!match (BinOp->getOperand (0 ), m_ZExtOrSExt (m_Value (A))) ||
8826+ !match (BinOp->getOperand (1 ), m_ZExtOrSExt (m_Value (B))))
8827+ return std::nullopt ;
8828+
8829+ Instruction *ExtA = cast<Instruction>(BinOp->getOperand (0 ));
8830+ Instruction *ExtB = cast<Instruction>(BinOp->getOperand (1 ));
8831+
8832+ TTI::PartialReductionExtendKind OpAExtend =
8833+ TargetTransformInfo::getPartialReductionExtendKind (ExtA);
8834+ TTI::PartialReductionExtendKind OpBExtend =
8835+ TargetTransformInfo::getPartialReductionExtendKind (ExtB);
8836+
8837+ PartialReductionChain Chain (Rdx.getLoopExitInstr (), ExtA, ExtB, BinOp);
8838+
8839+ unsigned TargetScaleFactor =
8840+ PHI->getType ()->getPrimitiveSizeInBits ().getKnownScalarFactor (
8841+ A->getType ()->getPrimitiveSizeInBits ());
8842+
8843+ if (LoopVectorizationPlanner::getDecisionAndClampRange (
8844+ [&](ElementCount VF) {
8845+ InstructionCost Cost = TTI->getPartialReductionCost (
8846+ Update->getOpcode (), A->getType (), B->getType (), PHI->getType (),
8847+ VF, OpAExtend, OpBExtend,
8848+ std::make_optional (BinOp->getOpcode ()));
8849+ return Cost.isValid ();
8850+ },
8851+ Range))
8852+ return std::make_pair (Chain, TargetScaleFactor);
8853+
8854+ return std::nullopt ;
8855+ }
8856+
87548857VPRecipeBase *
87558858VPRecipeBuilder::tryToCreateWidenRecipe (Instruction *Instr,
87568859 ArrayRef<VPValue *> Operands,
@@ -8775,9 +8878,14 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
87758878 Legal->getReductionVars ().find (Phi)->second ;
87768879 assert (RdxDesc.getRecurrenceStartValue () ==
87778880 Phi->getIncomingValueForBlock (OrigLoop->getLoopPreheader ()));
8778- PhiRecipe = new VPReductionPHIRecipe (Phi, RdxDesc, *StartV,
8779- CM.isInLoopReduction (Phi),
8780- CM.useOrderedReductions (RdxDesc));
8881+
8882+ // If the PHI is used by a partial reduction, set the scale factor.
8883+ std::optional<std::pair<PartialReductionChain, unsigned >> Pair =
8884+ getScaledReductionForInstr (RdxDesc.getLoopExitInstr ());
8885+ unsigned ScaleFactor = Pair ? Pair->second : 1 ;
8886+ PhiRecipe = new VPReductionPHIRecipe (
8887+ Phi, RdxDesc, *StartV, CM.isInLoopReduction (Phi),
8888+ CM.useOrderedReductions (RdxDesc), ScaleFactor);
87818889 } else {
87828890 // TODO: Currently fixed-order recurrences are modeled as chains of
87838891 // first-order recurrences. If there are no users of the intermediate
@@ -8809,6 +8917,9 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88098917 if (isa<LoadInst>(Instr) || isa<StoreInst>(Instr))
88108918 return tryToWidenMemory (Instr, Operands, Range);
88118919
8920+ if (getScaledReductionForInstr (Instr))
8921+ return tryToCreatePartialReduction (Instr, Operands);
8922+
88128923 if (!shouldWiden (Instr, Range))
88138924 return nullptr ;
88148925
@@ -8829,6 +8940,21 @@ VPRecipeBuilder::tryToCreateWidenRecipe(Instruction *Instr,
88298940 return tryToWiden (Instr, Operands, VPBB);
88308941}
88318942
8943+ VPRecipeBase *
8944+ VPRecipeBuilder::tryToCreatePartialReduction (Instruction *Reduction,
8945+ ArrayRef<VPValue *> Operands) {
8946+ assert (Operands.size () == 2 &&
8947+ " Unexpected number of operands for partial reduction" );
8948+
8949+ VPValue *BinOp = Operands[0 ];
8950+ VPValue *Phi = Operands[1 ];
8951+ if (isa<VPReductionPHIRecipe>(BinOp->getDefiningRecipe ()))
8952+ std::swap (BinOp, Phi);
8953+
8954+ return new VPPartialReductionRecipe (Reduction->getOpcode (), BinOp, Phi,
8955+ Reduction);
8956+ }
8957+
88328958void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
88338959 ElementCount MaxVF) {
88348960 assert (OrigLoop->isInnermost () && " Inner loop expected." );
@@ -9252,7 +9378,8 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92529378 bool HasNUW = !IVUpdateMayOverflow || Style == TailFoldingStyle::None;
92539379 addCanonicalIVRecipes (*Plan, Legal->getWidestInductionType (), HasNUW, DL);
92549380
9255- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9381+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9382+ Builder);
92569383
92579384 // ---------------------------------------------------------------------------
92589385 // Pre-construction: record ingredients whose recipes we'll need to further
@@ -9298,6 +9425,9 @@ LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(VFRange &Range) {
92989425 bool NeedsBlends = BB != HeaderBB && !BB->phis ().empty ();
92999426 return Legal->blockNeedsPredication (BB) || NeedsBlends;
93009427 });
9428+
9429+ RecipeBuilder.collectScaledReductions (Range);
9430+
93019431 auto *MiddleVPBB = Plan->getMiddleBlock ();
93029432 VPBasicBlock::iterator MBIP = MiddleVPBB->getFirstNonPhi ();
93039433 for (BasicBlock *BB : make_range (DFS.beginRPO (), DFS.endRPO ())) {
@@ -9521,7 +9651,8 @@ VPlanPtr LoopVectorizationPlanner::buildVPlan(VFRange &Range) {
95219651
95229652 // Collect mapping of IR header phis to header phi recipes, to be used in
95239653 // addScalarResumePhis.
9524- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, Legal, CM, PSE, Builder);
9654+ VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, PSE,
9655+ Builder);
95259656 for (auto &R : Plan->getVectorLoopRegion ()->getEntryBasicBlock ()->phis ()) {
95269657 if (isa<VPCanonicalIVPHIRecipe>(&R))
95279658 continue ;
0 commit comments