Skip to content

Commit 07de738

Browse files
committed
Extend the constant operand for ext(mul(a, b)) as well
1 parent 5734bdc commit 07de738

File tree

2 files changed

+326
-33
lines changed

2 files changed

+326
-33
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 50 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3649,39 +3649,52 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36493649
Sub = VecOp->getDefiningRecipe();
36503650
VecOp = Tmp;
36513651
}
3652-
// Try to match reduce.add(mul(...)).
3653-
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
3654-
auto *RecipeA =
3655-
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3656-
auto *RecipeB =
3657-
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3658-
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
36593652

3660-
// Match reduce.add(mul(ext, const)) and convert it to
3661-
// reduce.add(mul(ext, ext(const)))
3662-
if (RecipeA && !RecipeB && B->isLiveIn()) {
3663-
Type *NarrowTy = Ctx.Types.inferScalarType(RecipeA->getOperand(0));
3664-
Instruction::CastOps ExtOpc = RecipeA->getOpcode();
3665-
auto *Const = dyn_cast<ConstantInt>(B->getLiveInIRValue());
3653+
// If ValB is a constant and can be safely extended, truncate it to the same
3654+
// type as ExtA's operand, then extend it to the same type as ExtA. This
3655+
// creates two uniform extends that can more easily be matched by the rest of
3656+
// the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
3657+
// replaced with the new extend of the constant.
3658+
auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
3659+
VPWidenCastRecipe *&ExtB,
3660+
VPValue *&ValB, VPWidenRecipe *Mul) {
3661+
if (ExtA && !ExtB && ValB->isLiveIn()) {
3662+
Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
3663+
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
3664+
Instruction::CastOps ExtOpc = ExtA->getOpcode();
3665+
auto *Const = dyn_cast<ConstantInt>(ValB->getLiveInIRValue());
36663666
if (Const &&
36673667
llvm::canConstantBeExtended(
36683668
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc))) {
36693669
// The truncate ensures that the type of each extended operand is the
36703670
// same, and it's been proven that the constant can be extended from
3671-
// NarrowTy safely. Necessary since RecipeA's extended operand would be
3671+
// NarrowTy safely. Necessary since ExtA's extended operand would be
36723672
// e.g. an i8, while the const will likely be an i32. This will be
36733673
// elided by later optimisations.
36743674
auto *Trunc =
3675-
new VPWidenCastRecipe(Instruction::CastOps::Trunc, B, NarrowTy);
3676-
Trunc->insertBefore(*RecipeA->getParent(),
3677-
std::next(RecipeA->getIterator()));
3678-
3679-
Type *WideTy = Ctx.Types.inferScalarType(RecipeA);
3680-
RecipeB = new VPWidenCastRecipe(ExtOpc, Trunc, WideTy);
3681-
RecipeB->insertAfter(Trunc);
3682-
Mul->setOperand(1, RecipeB);
3675+
new VPWidenCastRecipe(Instruction::CastOps::Trunc, ValB, NarrowTy);
3676+
Trunc->insertBefore(*ExtA->getParent(), std::next(ExtA->getIterator()));
3677+
3678+
VPWidenCastRecipe *NewCast =
3679+
new VPWidenCastRecipe(ExtOpc, Trunc, WideTy);
3680+
NewCast->insertAfter(Trunc);
3681+
ExtB = NewCast;
3682+
ValB = NewCast;
3683+
Mul->setOperand(1, NewCast);
36833684
}
36843685
}
3686+
};
3687+
3688+
// Try to match reduce.add(mul(...)).
3689+
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
3690+
auto *RecipeA =
3691+
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3692+
auto *RecipeB =
3693+
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3694+
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
3695+
3696+
// Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
3697+
ExtendAndReplaceConstantOp(RecipeA, RecipeB, B, Mul);
36853698

36863699
// Match reduce.add/sub(mul(ext, ext)).
36873700
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
@@ -3692,7 +3705,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36923705
cast<VPWidenRecipe>(Sub), Red);
36933706
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
36943707
}
3695-
// Match reduce.add(mul).
36963708
// TODO: Add an expression type for this variant with a negated mul
36973709
if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
36983710
return new VPExpressionRecipe(Mul, Red);
@@ -3701,18 +3713,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
37013713
// variants.
37023714
if (Sub)
37033715
return nullptr;
3704-
// Match reduce.add(ext(mul(ext(A), ext(B)))).
3705-
// All extend recipes must have same opcode or A == B
3706-
// which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
3707-
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
3708-
m_ZExtOrSExt(m_VPValue()))))) {
3716+
3717+
// Match reduce.add(ext(mul(A, B))).
3718+
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
37093719
auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
37103720
auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
3711-
auto *Ext0 =
3712-
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
3713-
auto *Ext1 =
3714-
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
3715-
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
3721+
auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3722+
auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3723+
3724+
// Convert reduce.add(ext(mul(ext, const))) to reduce.add(ext(mul(ext,
3725+
// ext(const))))
3726+
ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);
3727+
3728+
// Match reduce.add(ext(mul(ext(A), ext(B))))
3729+
// All extend recipes must have same opcode or A == B
3730+
// which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
3731+
if (Ext0 && Ext1 &&
3732+
(Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
37163733
Ext0->getOpcode() == Ext1->getOpcode() &&
37173734
IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext) && Mul->hasOneUse()) {
37183735
auto *NewExt0 = new VPWidenCastRecipe(

0 commit comments

Comments
 (0)