@@ -1395,22 +1395,46 @@ class BoUpSLP {
     return VectorizableTree.front()->Scalars;
   }
 
+  /// Returns the type/is-signed info for the root node in the graph without
+  /// casting.
+  std::optional<std::pair<Type *, bool>> getRootNodeTypeWithNoCast() const {
+    const TreeEntry &Root = *VectorizableTree.front().get();
+    if (Root.State != TreeEntry::Vectorize || Root.isAltShuffle() ||
+        !Root.Scalars.front()->getType()->isIntegerTy())
+      return std::nullopt;
+    auto It = MinBWs.find(&Root);
+    if (It != MinBWs.end())
+      return std::make_pair(IntegerType::get(Root.Scalars.front()->getContext(),
+                                             It->second.first),
+                            It->second.second);
+    if (Root.getOpcode() == Instruction::ZExt ||
+        Root.getOpcode() == Instruction::SExt)
+      return std::make_pair(cast<CastInst>(Root.getMainOp())->getSrcTy(),
+                            Root.getOpcode() == Instruction::SExt);
+    return std::nullopt;
+  }
+
   /// Checks if the root graph node can be emitted with narrower bitwidth at
   /// codegen and returns its signedness, if so.
   bool isSignedMinBitwidthRootNode() const {
     return MinBWs.at(VectorizableTree.front().get()).second;
   }
 
-  /// Returns reduction bitwidth and signedness, if it does not match the
-  /// original requested size.
-  std::optional<std::pair<unsigned, bool>> getReductionBitWidthAndSign() const {
+  /// Returns the reduction type after min-bitwidth analysis.
+  FixedVectorType *getReductionType() const {
     if (ReductionBitWidth == 0 ||
+        !VectorizableTree.front()->Scalars.front()->getType()->isIntegerTy() ||
         ReductionBitWidth >=
             DL->getTypeSizeInBits(
                 VectorizableTree.front()->Scalars.front()->getType()))
-      return std::nullopt;
-    return std::make_pair(ReductionBitWidth,
-                          MinBWs.at(VectorizableTree.front().get()).second);
+      return getWidenedType(
+          VectorizableTree.front()->Scalars.front()->getType(),
+          VectorizableTree.front()->getVectorFactor());
+    return getWidenedType(
+        IntegerType::get(
+            VectorizableTree.front()->Scalars.front()->getContext(),
+            ReductionBitWidth),
+        VectorizableTree.front()->getVectorFactor());
   }
 
   /// Builds external uses of the vectorized scalars, i.e. the list of
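
To make the contract of the two new helpers concrete, here is a minimal
standalone C++17 sketch, assuming simplified stand-in types rather than LLVM's
(RootNode and rootTypeWithNoCast are invented names): getRootNodeTypeWithNoCast()
reports the pre-extension type and signedness of the root, and callers fall
back to the reduction element type via value_or, the same pattern the reduction
cost code below uses.

#include <cstdio>
#include <optional>
#include <utility>

enum class Op { Add, ZExt, SExt };

struct RootNode {
  Op Opcode;        // opcode of the graph's root node
  unsigned SrcBits; // source element width when the root is a cast
};

// Simplified analogue of getRootNodeTypeWithNoCast(): report the
// pre-extension width and signedness, or nothing when the root is not a
// plain zext/sext.
std::optional<std::pair<unsigned, bool>>
rootTypeWithNoCast(const RootNode &R) {
  if (R.Opcode == Op::ZExt)
    return std::make_pair(R.SrcBits, false);
  if (R.Opcode == Op::SExt)
    return std::make_pair(R.SrcBits, true);
  return std::nullopt;
}

int main() {
  RootNode Ext{Op::SExt, 8};  // sext i8 -> i32 feeding a reduction
  RootNode Plain{Op::Add, 0}; // plain i32 add tree, no extension
  auto [Bits, Signed] =
      rootTypeWithNoCast(Ext).value_or(std::make_pair(32u, true));
  std::printf("ext root:   i%u, signed=%d\n", Bits, (int)Signed);
  auto [B2, S2] =
      rootTypeWithNoCast(Plain).value_or(std::make_pair(32u, true));
  std::printf("plain root: i%u, signed=%d\n", B2, (int)S2);
}
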
@@ -11384,6 +11408,20 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
         return CommonCost;
       auto *VI = VL0->getOpcode() == Opcode ? VL0 : nullptr;
       TTI::CastContextHint CCH = GetCastContextHint(VL0->getOperand(0));
+
+      bool IsArithmeticExtendedReduction =
+          E->Idx == 0 && UserIgnoreList &&
+          all_of(*UserIgnoreList, [](Value *V) {
+            auto *I = cast<Instruction>(V);
+            return is_contained({Instruction::Add, Instruction::FAdd,
+                                 Instruction::Mul, Instruction::FMul,
+                                 Instruction::And, Instruction::Or,
+                                 Instruction::Xor},
+                                I->getOpcode());
+          });
+      if (IsArithmeticExtendedReduction &&
+          (VecOpcode == Instruction::ZExt || VecOpcode == Instruction::SExt))
+        return CommonCost;
       return CommonCost +
              TTI->getCastInstrCost(VecOpcode, VecTy, SrcVecTy, CCH, CostKind,
                                    VecOpcode == Opcode ? VI : nullptr);
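
The hunk above makes a root-feeding zext/sext cost only CommonCost when every
reduction operation is a plain arithmetic kind, because the extension is priced
later as part of the fused reduction (see the getExtendedReductionCost call in
getReductionCost further down). A self-contained sketch of that accounting,
with invented cost numbers:

#include <cstdio>

struct NodeCosts {
  int Common; // shuffle/permute overhead charged to every entry
  int Cast;   // what getCastInstrCost would normally add
};

// An extension feeding an arithmetic reduction root is treated as free here
// to avoid double-charging it against the fused reduction cost.
int castEntryCost(NodeCosts C, bool IsArithmeticExtendedReduction, bool IsExt) {
  if (IsArithmeticExtendedReduction && IsExt)
    return C.Common; // extension cost deferred to the reduction itself
  return C.Common + C.Cast;
}

int main() {
  NodeCosts C{0, 3};
  std::printf("fused: %d, standalone: %d\n",
              castEntryCost(C, /*IsArithmeticExtendedReduction=*/true, true),
              castEntryCost(C, /*IsArithmeticExtendedReduction=*/false, true));
}
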
@@ -12748,32 +12786,48 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
       unsigned SrcSize = It->second.first;
       unsigned DstSize = ReductionBitWidth;
       unsigned Opcode = Instruction::Trunc;
-      if (SrcSize < DstSize)
-        Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
-      auto *SrcVecTy =
-          getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
-      auto *DstVecTy =
-          getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
-      TTI::CastContextHint CCH = getCastContextHint(E);
-      InstructionCost CastCost;
-      switch (E.getOpcode()) {
-      case Instruction::SExt:
-      case Instruction::ZExt:
-      case Instruction::Trunc: {
-        const TreeEntry *OpTE = getOperandEntry(&E, 0);
-        CCH = getCastContextHint(*OpTE);
-        break;
-      }
-      default:
-        break;
+      if (SrcSize < DstSize) {
+        bool IsArithmeticExtendedReduction =
+            all_of(*UserIgnoreList, [](Value *V) {
+              auto *I = cast<Instruction>(V);
+              return is_contained({Instruction::Add, Instruction::FAdd,
+                                   Instruction::Mul, Instruction::FMul,
+                                   Instruction::And, Instruction::Or,
+                                   Instruction::Xor},
+                                  I->getOpcode());
+            });
+        if (IsArithmeticExtendedReduction)
+          Opcode =
+              Instruction::BitCast; // Handled via getExtendedReductionCost.
+        else
+          Opcode = It->second.second ? Instruction::SExt : Instruction::ZExt;
+      }
+      if (Opcode != Instruction::BitCast) {
+        auto *SrcVecTy =
+            getWidenedType(Builder.getIntNTy(SrcSize), E.getVectorFactor());
+        auto *DstVecTy =
+            getWidenedType(Builder.getIntNTy(DstSize), E.getVectorFactor());
+        TTI::CastContextHint CCH = getCastContextHint(E);
+        InstructionCost CastCost;
+        switch (E.getOpcode()) {
+        case Instruction::SExt:
+        case Instruction::ZExt:
+        case Instruction::Trunc: {
+          const TreeEntry *OpTE = getOperandEntry(&E, 0);
+          CCH = getCastContextHint(*OpTE);
+          break;
+        }
+        default:
+          break;
+        }
+        CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
+                                          TTI::TCK_RecipThroughput);
+        Cost += CastCost;
+        LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
+                          << " for final resize for reduction from "
+                          << SrcVecTy << " to " << DstVecTy << "\n";
+                   dbgs() << "SLP: Current total cost = " << Cost << "\n");
       }
-      CastCost += TTI->getCastInstrCost(Opcode, DstVecTy, SrcVecTy, CCH,
-                                        TTI::TCK_RecipThroughput);
-      Cost += CastCost;
-      LLVM_DEBUG(dbgs() << "SLP: Adding cost " << CastCost
-                        << " for final resize for reduction from " << SrcVecTy
-                        << " to " << DstVecTy << "\n";
-                 dbgs() << "SLP: Current total cost = " << Cost << "\n");
     }
   }
 
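
In the hunk above, Instruction::BitCast acts as a sentinel meaning "emit no
separate resize cost; the widening is folded into the extended-reduction
cost". A standalone sketch of the opcode selection, with a stand-in enum
rather than LLVM's opcodes:

#include <cstdio>

enum class ResizeOp { Trunc, SExt, ZExt, BitCast };

ResizeOp pickResizeOpcode(unsigned SrcSize, unsigned DstSize, bool IsSigned,
                          bool IsArithmeticExtendedReduction) {
  if (SrcSize >= DstSize)
    return ResizeOp::Trunc; // narrow the reduced value back down
  if (IsArithmeticExtendedReduction)
    return ResizeOp::BitCast; // sentinel: extended-reduction cost covers it
  return IsSigned ? ResizeOp::SExt : ResizeOp::ZExt;
}

int main() {
  // i8 sources widened to an i32 add reduction: no explicit resize is costed.
  bool Skipped = pickResizeOpcode(8, 32, /*IsSigned=*/true,
                                  /*IsArithmeticExtendedReduction=*/true) ==
                 ResizeOp::BitCast;
  std::printf("resize cost skipped: %s\n", Skipped ? "yes" : "no");
}
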
@@ -19951,8 +20005,8 @@ class HorizontalReduction {
 
       // Estimate cost.
       InstructionCost TreeCost = V.getTreeCost(VL);
-      InstructionCost ReductionCost = getReductionCost(
-          TTI, VL, IsCmpSelMinMax, RdxFMF, V.getReductionBitWidthAndSign());
+      InstructionCost ReductionCost =
+          getReductionCost(TTI, VL, IsCmpSelMinMax, RdxFMF, V);
       InstructionCost Cost = TreeCost + ReductionCost;
       LLVM_DEBUG(dbgs() << "SLP: Found cost = " << Cost
                         << " for reduction\n");
@@ -20243,14 +20297,14 @@ class HorizontalReduction {
 
 private:
   /// Calculate the cost of a reduction.
-  InstructionCost getReductionCost(
-      TargetTransformInfo *TTI, ArrayRef<Value *> ReducedVals,
-      bool IsCmpSelMinMax, FastMathFlags FMF,
-      const std::optional<std::pair<unsigned, bool>> BitwidthAndSign) {
+  InstructionCost getReductionCost(TargetTransformInfo *TTI,
+                                   ArrayRef<Value *> ReducedVals,
+                                   bool IsCmpSelMinMax, FastMathFlags FMF,
+                                   const BoUpSLP &R) {
     TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
     Type *ScalarTy = ReducedVals.front()->getType();
     unsigned ReduxWidth = ReducedVals.size();
-    FixedVectorType *VectorTy = getWidenedType(ScalarTy, ReduxWidth);
+    FixedVectorType *VectorTy = R.getReductionType();
     InstructionCost VectorCost = 0, ScalarCost;
     // If all of the reduced values are constant, the vector cost is 0, since
     // the reduction value can be calculated at compile time.
@@ -20308,21 +20362,16 @@ class HorizontalReduction {
             VecTy, APInt::getAllOnes(ScalarTyNumElements), /*Insert*/ true,
             /*Extract*/ false, TTI::TCK_RecipThroughput);
       } else {
-        auto [Bitwidth, IsSigned] =
-            BitwidthAndSign.value_or(std::make_pair(0u, false));
-        if (RdxKind == RecurKind::Add && Bitwidth == 1) {
-          // Represent vector_reduce_add(ZExt(<n x i1>)) to
-          // ZExtOrTrunc(ctpop(bitcast <n x i1> to in)).
-          auto *IntTy = IntegerType::get(ScalarTy->getContext(), ReduxWidth);
-          IntrinsicCostAttributes ICA(Intrinsic::ctpop, IntTy, {IntTy}, FMF);
-          VectorCost =
-              TTI->getCastInstrCost(Instruction::BitCast, IntTy,
-                                    getWidenedType(ScalarTy, ReduxWidth),
-                                    TTI::CastContextHint::None, CostKind) +
-              TTI->getIntrinsicInstrCost(ICA, CostKind);
-        } else {
+        Type *RedTy = VectorTy->getElementType();
+        auto [RType, IsSigned] = R.getRootNodeTypeWithNoCast().value_or(
+            std::make_pair(RedTy, true));
+        if (RType == RedTy) {
           VectorCost = TTI->getArithmeticReductionCost(RdxOpcode, VectorTy,
                                                        FMF, CostKind);
+        } else {
+          VectorCost = TTI->getExtendedReductionCost(
+              RdxOpcode, !IsSigned, RedTy, getWidenedType(RType, ReduxWidth),
+              FMF, CostKind);
         }
       }
     }
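
The point of routing through TTI::getExtendedReductionCost is that a target
can price reduce(ext(x)) as a single operation (AArch64's saddlv/uaddlv
extending reductions are the classic case), so the fused query may come in
below a cast plus getArithmeticReductionCost. A toy comparison, with made-up
cost values standing in for real TTI answers:

#include <cstdio>

int castCost() { return 2; }           // sext v16i8 -> v16i32, if split (assumed)
int plainReductionCost() { return 4; } // add-reduce of v16i32 (assumed)
int fusedReductionCost() { return 4; } // saddlv-style extending reduction (assumed)

int main() {
  int Split = castCost() + plainReductionCost();
  int Fused = fusedReductionCost();
  std::printf("split=%d fused=%d -> prefer %s\n", Split, Fused,
              Fused <= Split ? "fused" : "split");
}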