Skip to content

Commit 0920112

Browse files
committed
Use assertion in isMulAccValidAndClampRange
1 parent e571aff commit 0920112

File tree

1 file changed

+45
-35
lines changed

1 file changed

+45
-35
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 45 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -3526,25 +3526,31 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35263526
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
35273527
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35283528

3529-
InstructionCost ExtRedCost;
35303529
InstructionCost ExtCost =
35313530
cast<VPWidenCastRecipe>(VecOp)->computeCost(VF, Ctx);
35323531
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3532+
InstructionCost BaseCost = ExtCost + RedCost;
35333533

35343534
if (isa<VPPartialReductionRecipe>(Red)) {
35353535
TargetTransformInfo::PartialReductionExtendKind ExtKind =
35363536
TargetTransformInfo::getPartialReductionExtendKind(ExtOpc);
35373537
// FIXME: Move partial reduction creation, costing and clamping
35383538
// here from LoopVectorize.cpp.
3539-
ExtRedCost = Ctx.TTI.getPartialReductionCost(
3540-
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
3541-
llvm::TargetTransformInfo::PR_None, std::nullopt, Ctx.CostKind);
3542-
} else {
3543-
ExtRedCost = Ctx.TTI.getExtendedReductionCost(
3544-
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3545-
Red->getFastMathFlags(), CostKind);
3539+
InstructionCost PartialReductionCost =
3540+
Ctx.TTI.getPartialReductionCost(
3541+
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
3542+
llvm::TargetTransformInfo::PR_None, std::nullopt,
3543+
Ctx.CostKind);
3544+
assert(PartialReductionCost <= BaseCost &&
3545+
"A partial reduction should have a lower cost than the "
3546+
"extend + add");
3547+
return true;
35463548
}
3547-
return ExtRedCost.isValid() && ExtRedCost < ExtCost + RedCost;
3549+
3550+
InstructionCost ExtRedCost = Ctx.TTI.getExtendedReductionCost(
3551+
Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
3552+
Red->getFastMathFlags(), CostKind);
3553+
return ExtRedCost.isValid() && ExtRedCost < BaseCost;
35483554
},
35493555
Range);
35503556
};
@@ -3589,46 +3595,50 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35893595
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35903596
Type *SrcTy =
35913597
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
3592-
InstructionCost MulAccCost;
3598+
3599+
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
3600+
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3601+
InstructionCost ExtCost = 0;
3602+
if (Ext0)
3603+
ExtCost += Ext0->computeCost(VF, Ctx);
3604+
if (Ext1)
3605+
ExtCost += Ext1->computeCost(VF, Ctx);
3606+
if (OuterExt)
3607+
ExtCost += OuterExt->computeCost(VF, Ctx);
3608+
InstructionCost BaseCost = ExtCost + MulCost + RedCost;
35933609

35943610
if (IsPartialReduction) {
35953611
Type *SrcTy2 =
35963612
Ext1 ? Ctx.Types.inferScalarType(Ext1->getOperand(0)) : nullptr;
35973613
// FIXME: Move partial reduction creation, costing and clamping
35983614
// here from LoopVectorize.cpp.
3599-
MulAccCost = Ctx.TTI.getPartialReductionCost(
3600-
Opcode, SrcTy, SrcTy2, RedTy, VF,
3601-
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
3602-
Ext0->getOpcode())
3603-
: TargetTransformInfo::PR_None,
3604-
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
3605-
Ext1->getOpcode())
3606-
: TargetTransformInfo::PR_None,
3607-
Mul->getOpcode(), CostKind);
3608-
} else {
3615+
InstructionCost PartialReductionCost =
3616+
Ctx.TTI.getPartialReductionCost(
3617+
Opcode, SrcTy, SrcTy2, RedTy, VF,
3618+
Ext0 ? TargetTransformInfo::getPartialReductionExtendKind(
3619+
Ext0->getOpcode())
3620+
: TargetTransformInfo::PR_None,
3621+
Ext1 ? TargetTransformInfo::getPartialReductionExtendKind(
3622+
Ext1->getOpcode())
3623+
: TargetTransformInfo::PR_None,
3624+
Mul->getOpcode(), CostKind);
3625+
assert(PartialReductionCost <= BaseCost &&
3626+
"A partial reduction should have a lower cost than the "
3627+
"extend + mul + add");
3628+
return true;
3629+
}
36093630
// Only partial reductions support mixed extends at the moment.
36103631
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
36113632
return false;
36123633

36133634
bool IsZExt =
36143635
!Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
36153636
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
3616-
MulAccCost = Ctx.TTI.getMulAccReductionCost(IsZExt, Opcode, RedTy,
3617-
SrcVecTy, CostKind);
3618-
}
3619-
3620-
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
3621-
InstructionCost RedCost = Red->computeCost(VF, Ctx);
3622-
InstructionCost ExtCost = 0;
3623-
if (Ext0)
3624-
ExtCost += Ext0->computeCost(VF, Ctx);
3625-
if (Ext1)
3626-
ExtCost += Ext1->computeCost(VF, Ctx);
3627-
if (OuterExt)
3628-
ExtCost += OuterExt->computeCost(VF, Ctx);
3637+
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
3638+
IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
36293639

3630-
return MulAccCost.isValid() &&
3631-
MulAccCost < ExtCost + MulCost + RedCost;
3640+
return MulAccCost.isValid() &&
3641+
MulAccCost < ExtCost + MulCost + RedCost;
36323642
},
36333643
Range);
36343644
};

0 commit comments

Comments
 (0)