@@ -3526,25 +3526,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35263526 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
35273527 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35283528
3529- InstructionCost ExtRedCost;
35303529 InstructionCost ExtCost =
35313530 cast<VPWidenCastRecipe>(VecOp)->computeCost (VF, Ctx);
35323531 InstructionCost RedCost = Red->computeCost (VF, Ctx);
3532+ InstructionCost BaseCost = ExtCost + RedCost;
35333533
35343534 if (isa<VPPartialReductionRecipe>(Red)) {
35353535 TargetTransformInfo::PartialReductionExtendKind ExtKind =
35363536 TargetTransformInfo::getPartialReductionExtendKind (ExtOpc);
35373537 // FIXME: Move partial reduction creation, costing and clamping
35383538 // here from LoopVectorize.cpp.
3539- ExtRedCost = Ctx.TTI .getPartialReductionCost (
3540- Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3541- llvm::TargetTransformInfo::PR_None, std::nullopt , Ctx.CostKind );
3542- } else {
3543- ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3544- Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3545- Red->getFastMathFlags (), CostKind);
3539+ InstructionCost PartialReductionCost =
3540+ Ctx.TTI .getPartialReductionCost (
3541+ Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3542+ llvm::TargetTransformInfo::PR_None, std::nullopt ,
3543+ Ctx.CostKind );
3544+ assert (PartialReductionCost <= BaseCost &&
3545+ " A partial reduction should have a lower cost than the "
3546+ " extend + add" );
3547+ return true ;
35463548 }
3547- return ExtRedCost.isValid () && ExtRedCost < ExtCost + RedCost;
3549+
3550+ InstructionCost ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3551+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3552+ Red->getFastMathFlags (), CostKind);
3553+ return ExtRedCost.isValid () && ExtRedCost < BaseCost;
35483554 },
35493555 Range);
35503556 };
@@ -3589,46 +3595,50 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35893595 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35903596 Type *SrcTy =
35913597 Ext0 ? Ctx.Types .inferScalarType (Ext0->getOperand (0 )) : RedTy;
3592- InstructionCost MulAccCost;
3598+
3599+ InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3600+ InstructionCost RedCost = Red->computeCost (VF, Ctx);
3601+ InstructionCost ExtCost = 0 ;
3602+ if (Ext0)
3603+ ExtCost += Ext0->computeCost (VF, Ctx);
3604+ if (Ext1)
3605+ ExtCost += Ext1->computeCost (VF, Ctx);
3606+ if (OuterExt)
3607+ ExtCost += OuterExt->computeCost (VF, Ctx);
3608+ InstructionCost BaseCost = ExtCost + MulCost + RedCost;
35933609
35943610 if (IsPartialReduction) {
35953611 Type *SrcTy2 =
35963612 Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
35973613 // FIXME: Move partial reduction creation, costing and clamping
35983614 // here from LoopVectorize.cpp.
3599- MulAccCost = Ctx.TTI .getPartialReductionCost (
3600- Opcode, SrcTy, SrcTy2, RedTy, VF,
3601- Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3602- Ext0->getOpcode ())
3603- : TargetTransformInfo::PR_None,
3604- Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3605- Ext1->getOpcode ())
3606- : TargetTransformInfo::PR_None,
3607- Mul->getOpcode (), CostKind);
3608- } else {
3615+ InstructionCost PartialReductionCost =
3616+ Ctx.TTI .getPartialReductionCost (
3617+ Opcode, SrcTy, SrcTy2, RedTy, VF,
3618+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3619+ Ext0->getOpcode ())
3620+ : TargetTransformInfo::PR_None,
3621+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3622+ Ext1->getOpcode ())
3623+ : TargetTransformInfo::PR_None,
3624+ Mul->getOpcode (), CostKind);
3625+ assert (PartialReductionCost <= BaseCost &&
3626+ " A partial reduction should have a lower cost than the "
3627+ " extend + mul + add" );
3628+ return true ;
3629+ }
36093630 // Only partial reductions support mixed extends at the moment.
36103631 if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
36113632 return false ;
36123633
36133634 bool IsZExt =
36143635 !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
36153636 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3616- MulAccCost = Ctx.TTI .getMulAccReductionCost (IsZExt, Opcode, RedTy,
3617- SrcVecTy, CostKind);
3618- }
3619-
3620- InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3621- InstructionCost RedCost = Red->computeCost (VF, Ctx);
3622- InstructionCost ExtCost = 0 ;
3623- if (Ext0)
3624- ExtCost += Ext0->computeCost (VF, Ctx);
3625- if (Ext1)
3626- ExtCost += Ext1->computeCost (VF, Ctx);
3627- if (OuterExt)
3628- ExtCost += OuterExt->computeCost (VF, Ctx);
3637+ InstructionCost MulAccCost = Ctx.TTI .getMulAccReductionCost (
3638+ IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
36293639
3630- return MulAccCost.isValid () &&
3631- MulAccCost < ExtCost + MulCost + RedCost;
3640+ return MulAccCost.isValid () &&
3641+ MulAccCost < ExtCost + MulCost + RedCost;
36323642 },
36333643 Range);
36343644 };
0 commit comments