Skip to content

Commit 5734bdc

Browse files
committed
Address review
1 parent 0920112 commit 5734bdc

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3541,9 +3541,8 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
35413541
Opcode, SrcTy, nullptr, RedTy, VF, ExtKind,
35423542
llvm::TargetTransformInfo::PR_None, std::nullopt,
35433543
Ctx.CostKind);
3544-
assert(PartialReductionCost <= BaseCost &&
3545-
"A partial reduction should have a lower cost than the "
3546-
"extend + add");
3544+
assert(PartialReductionCost.isValid() &&
3545+
"A partial reduction should have a valid cost");
35473546
return true;
35483547
}
35493548

@@ -3622,23 +3621,21 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36223621
Ext1->getOpcode())
36233622
: TargetTransformInfo::PR_None,
36243623
Mul->getOpcode(), CostKind);
3625-
assert(PartialReductionCost <= BaseCost &&
3626-
"A partial reduction should have a lower cost than the "
3627-
"extend + mul + add");
3624+
assert(PartialReductionCost.isValid() &&
3625+
"A partial reduction should have a valid cost");
36283626
return true;
36293627
}
3630-
// Only partial reductions support mixed extends at the moment.
3631-
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
3632-
return false;
3633-
3634-
bool IsZExt =
3635-
!Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
3636-
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
3637-
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
3638-
IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
3639-
3640-
return MulAccCost.isValid() &&
3641-
MulAccCost < ExtCost + MulCost + RedCost;
3628+
// Only partial reductions support mixed extends at the moment.
3629+
if (Ext0 && Ext1 && Ext0->getOpcode() != Ext1->getOpcode())
3630+
return false;
3631+
3632+
bool IsZExt =
3633+
!Ext0 || Ext0->getOpcode() == Instruction::CastOps::ZExt;
3634+
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
3635+
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
3636+
IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
3637+
3638+
return MulAccCost.isValid() && MulAccCost < BaseCost;
36423639
},
36433640
Range);
36443641
};

0 commit comments

Comments
 (0)