Skip to content

Commit 6fad720

Browse files
committed
Extend the constant operand for ext(mul(a, b)) as well
1 parent fb1dd8d commit 6fad720

File tree

2 files changed

+326
-33
lines changed

2 files changed

+326
-33
lines changed

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 50 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -3596,39 +3596,52 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35963596
Sub = VecOp->getDefiningRecipe();
35973597
VecOp = Tmp;
35983598
}
3599-
// Try to match reduce.add(mul(...)).
3600-
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
3601-
auto *RecipeA =
3602-
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3603-
auto *RecipeB =
3604-
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3605-
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
36063599

3607-
// Match reduce.add(mul(ext, const)) and convert it to
3608-
// reduce.add(mul(ext, ext(const)))
3609-
if (RecipeA && !RecipeB && B->isLiveIn()) {
3610-
Type *NarrowTy = Ctx.Types.inferScalarType(RecipeA->getOperand(0));
3611-
Instruction::CastOps ExtOpc = RecipeA->getOpcode();
3612-
auto *Const = dyn_cast<ConstantInt>(B->getLiveInIRValue());
3600+
// If ValB is a constant and can be safely extended, truncate it to the same
3601+
// type as ExtA's operand, then extend it to the same type as ExtA. This
3602+
// creates two uniform extends that can more easily be matched by the rest of
3603+
// the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
3604+
// replaced with the new extend of the constant.
3605+
auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
3606+
VPWidenCastRecipe *&ExtB,
3607+
VPValue *&ValB, VPWidenRecipe *Mul) {
3608+
if (ExtA && !ExtB && ValB->isLiveIn()) {
3609+
Type *NarrowTy = Ctx.Types.inferScalarType(ExtA->getOperand(0));
3610+
Type *WideTy = Ctx.Types.inferScalarType(ExtA);
3611+
Instruction::CastOps ExtOpc = ExtA->getOpcode();
3612+
auto *Const = dyn_cast<ConstantInt>(ValB->getLiveInIRValue());
36133613
if (Const &&
36143614
llvm::canConstantBeExtended(
36153615
Const, NarrowTy, TTI::getPartialReductionExtendKind(ExtOpc))) {
36163616
// The truncate ensures that the type of each extended operand is the
36173617
// same, and it's been proven that the constant can be extended from
3618-
// NarrowTy safely. Necessary since RecipeA's extended operand would be
3618+
// NarrowTy safely. Necessary since ExtA's extended operand would be
36193619
// e.g. an i8, while the const will likely be an i32. This will be
36203620
// elided by later optimisations.
36213621
auto *Trunc =
3622-
new VPWidenCastRecipe(Instruction::CastOps::Trunc, B, NarrowTy);
3623-
Trunc->insertBefore(*RecipeA->getParent(),
3624-
std::next(RecipeA->getIterator()));
3625-
3626-
Type *WideTy = Ctx.Types.inferScalarType(RecipeA);
3627-
RecipeB = new VPWidenCastRecipe(ExtOpc, Trunc, WideTy);
3628-
RecipeB->insertAfter(Trunc);
3629-
Mul->setOperand(1, RecipeB);
3622+
new VPWidenCastRecipe(Instruction::CastOps::Trunc, ValB, NarrowTy);
3623+
Trunc->insertBefore(*ExtA->getParent(), std::next(ExtA->getIterator()));
3624+
3625+
VPWidenCastRecipe *NewCast =
3626+
new VPWidenCastRecipe(ExtOpc, Trunc, WideTy);
3627+
NewCast->insertAfter(Trunc);
3628+
ExtB = NewCast;
3629+
ValB = NewCast;
3630+
Mul->setOperand(1, NewCast);
36303631
}
36313632
}
3633+
};
3634+
3635+
// Try to match reduce.add(mul(...)).
3636+
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
3637+
auto *RecipeA =
3638+
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3639+
auto *RecipeB =
3640+
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3641+
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
3642+
3643+
// Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
3644+
ExtendAndReplaceConstantOp(RecipeA, RecipeB, B, Mul);
36323645

36333646
// Match reduce.add/sub(mul(ext, ext)).
36343647
if (RecipeA && RecipeB && match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
@@ -3639,7 +3652,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36393652
cast<VPWidenRecipe>(Sub), Red);
36403653
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
36413654
}
3642-
// Match reduce.add(mul).
36433655
// TODO: Add an expression type for this variant with a negated mul
36443656
if (!Sub && IsMulAccValidAndClampRange(Mul, nullptr, nullptr, nullptr))
36453657
return new VPExpressionRecipe(Mul, Red);
@@ -3648,18 +3660,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36483660
// variants.
36493661
if (Sub)
36503662
return nullptr;
3651-
// Match reduce.add(ext(mul(ext(A), ext(B)))).
3652-
// All extend recipes must have same opcode or A == B
3653-
// which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
3654-
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
3655-
m_ZExtOrSExt(m_VPValue()))))) {
3663+
3664+
// Match reduce.add(ext(mul(A, B))).
3665+
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_VPValue(A), m_VPValue(B))))) {
36563666
auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
36573667
auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
3658-
auto *Ext0 =
3659-
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
3660-
auto *Ext1 =
3661-
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
3662-
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
3668+
auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
3669+
auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3670+
3671+
// Convert reduce.add(ext(mul(ext, const))) to reduce.add(ext(mul(ext,
3672+
// ext(const))))
3673+
ExtendAndReplaceConstantOp(Ext0, Ext1, B, Mul);
3674+
3675+
// Match reduce.add(ext(mul(ext(A), ext(B))))
3676+
// All extend recipes must have the same opcode or A == B
3677+
// which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
3678+
if (Ext0 && Ext1 &&
3679+
(Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
36633680
Ext0->getOpcode() == Ext1->getOpcode() &&
36643681
IsMulAccValidAndClampRange(Mul, Ext0, Ext1, Ext)) {
36653682
auto *NewExt0 = new VPWidenCastRecipe(

0 commit comments

Comments
 (0)