@@ -3526,30 +3526,25 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35263526 auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
35273527 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35283528
3529+ InstructionCost ExtRedCost;
35293530 InstructionCost ExtCost =
35303531 cast<VPWidenCastRecipe>(VecOp)->computeCost (VF, Ctx);
35313532 InstructionCost RedCost = Red->computeCost (VF, Ctx);
3532- InstructionCost BaseCost = ExtCost + RedCost;
35333533
35343534 if (isa<VPPartialReductionRecipe>(Red)) {
35353535 TargetTransformInfo::PartialReductionExtendKind ExtKind =
35363536 TargetTransformInfo::getPartialReductionExtendKind (ExtOpc);
35373537 // FIXME: Move partial reduction creation, costing and clamping
35383538 // here from LoopVectorize.cpp.
3539- InstructionCost PartialReductionCost =
3540- Ctx.TTI .getPartialReductionCost (
3541- Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3542- llvm::TargetTransformInfo::PR_None, std::nullopt ,
3543- Ctx.CostKind );
3544- assert (PartialReductionCost.isValid () &&
3545- " A partial reduction should have a valid cost" );
3546- return true ;
3539+ ExtRedCost = Ctx.TTI .getPartialReductionCost (
3540+ Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
3541+ llvm::TargetTransformInfo::PR_None, std::nullopt , Ctx.CostKind );
3542+ } else {
3543+ ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3544+ Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3545+ Red->getFastMathFlags (), CostKind);
35473546 }
3548-
3549- InstructionCost ExtRedCost = Ctx.TTI .getExtendedReductionCost (
3550- Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3551- Red->getFastMathFlags (), CostKind);
3552- return ExtRedCost.isValid () && ExtRedCost < BaseCost;
3547+ return ExtRedCost.isValid () && ExtRedCost < ExtCost + RedCost;
35533548 },
35543549 Range);
35553550 };
@@ -3594,6 +3589,33 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35943589 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35953590 Type *SrcTy =
35963591 Ext0 ? Ctx.Types .inferScalarType (Ext0->getOperand (0 )) : RedTy;
3592+ InstructionCost MulAccCost;
3593+
3594+ if (IsPartialReduction) {
3595+ Type *SrcTy2 =
3596+ Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
3597+ // FIXME: Move partial reduction creation, costing and clamping
3598+ // here from LoopVectorize.cpp.
3599+ MulAccCost = Ctx.TTI .getPartialReductionCost (
3600+ Opcode, SrcTy, SrcTy2, RedTy, VF,
3601+ Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3602+ Ext0->getOpcode ())
3603+ : TargetTransformInfo::PR_None,
3604+ Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3605+ Ext1->getOpcode ())
3606+ : TargetTransformInfo::PR_None,
3607+ Mul->getOpcode (), CostKind);
3608+ } else {
3609+ // Only partial reductions support mixed extends at the moment.
3610+ if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
3611+ return false ;
3612+
3613+ bool IsZExt =
3614+ !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
3615+ auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3616+ MulAccCost = Ctx.TTI .getMulAccReductionCost (IsZExt, Opcode, RedTy,
3617+ SrcVecTy, CostKind);
3618+ }
35973619
35983620 InstructionCost MulCost = Mul->computeCost (VF, Ctx);
35993621 InstructionCost RedCost = Red->computeCost (VF, Ctx);
@@ -3604,38 +3626,9 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36043626 ExtCost += Ext1->computeCost (VF, Ctx);
36053627 if (OuterExt)
36063628 ExtCost += OuterExt->computeCost (VF, Ctx);
3607- InstructionCost BaseCost = ExtCost + MulCost + RedCost;
3608-
3609- if (IsPartialReduction) {
3610- Type *SrcTy2 =
3611- Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
3612- // FIXME: Move partial reduction creation, costing and clamping
3613- // here from LoopVectorize.cpp.
3614- InstructionCost PartialReductionCost =
3615- Ctx.TTI .getPartialReductionCost (
3616- Opcode, SrcTy, SrcTy2, RedTy, VF,
3617- Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
3618- Ext0->getOpcode ())
3619- : TargetTransformInfo::PR_None,
3620- Ext1 ? TargetTransformInfo::getPartialReductionExtendKind (
3621- Ext1->getOpcode ())
3622- : TargetTransformInfo::PR_None,
3623- Mul->getOpcode (), CostKind);
3624- assert (PartialReductionCost.isValid () &&
3625- " A partial reduction should have a valid cost" );
3626- return true ;
3627- }
3628- // Only partial reductions support mixed extends at the moment.
3629- if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
3630- return false ;
3631-
3632- bool IsZExt =
3633- !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
3634- auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3635- InstructionCost MulAccCost = Ctx.TTI .getMulAccReductionCost (
3636- IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
36373629
3638- return MulAccCost.isValid () && MulAccCost < BaseCost;
3630+ return MulAccCost.isValid () &&
3631+ MulAccCost < ExtCost + MulCost + RedCost;
36393632 },
36403633 Range);
36413634 };
0 commit comments