@@ -3649,39 +3649,52 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36493649 Sub = VecOp->getDefiningRecipe ();
36503650 VecOp = Tmp;
36513651 }
3652- // Try to match reduce.add(mul(...)).
3653- if (match (VecOp, m_Mul (m_VPValue (A), m_VPValue (B)))) {
3654- auto *RecipeA =
3655- dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3656- auto *RecipeB =
3657- dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3658- auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe ());
36593652
3660- // Match reduce.add(mul(ext, const)) and convert it to
3661- // reduce.add(mul(ext, ext(const)))
3662- if (RecipeA && !RecipeB && B->isLiveIn ()) {
3663- Type *NarrowTy = Ctx.Types .inferScalarType (RecipeA->getOperand (0 ));
3664- Instruction::CastOps ExtOpc = RecipeA->getOpcode ();
3665- auto *Const = dyn_cast<ConstantInt>(B->getLiveInIRValue ());
3653+ // If ValB is a constant and can be safely extended, truncate it to the same
3654+ // type as ExtA's operand, then extend it to the same type as ExtA. This
3655+ // creates two uniform extends that can more easily be matched by the rest of
3656+ // the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
3657+ // replaced with the new extend of the constant.
3658+ auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
3659+ VPWidenCastRecipe *&ExtB,
3660+ VPValue *&ValB, VPWidenRecipe *Mul) {
3661+ if (ExtA && !ExtB && ValB->isLiveIn ()) {
3662+ Type *NarrowTy = Ctx.Types .inferScalarType (ExtA->getOperand (0 ));
3663+ Type *WideTy = Ctx.Types .inferScalarType (ExtA);
3664+ Instruction::CastOps ExtOpc = ExtA->getOpcode ();
3665+ auto *Const = dyn_cast<ConstantInt>(ValB->getLiveInIRValue ());
36663666 if (Const &&
36673667 llvm::canConstantBeExtended (
36683668 Const, NarrowTy, TTI::getPartialReductionExtendKind (ExtOpc))) {
36693669 // The truncate ensures that the type of each extended operand is the
36703670 // same, and it's been proven that the constant can be extended from
3671- // NarrowTy safely. Necessary since RecipeA 's extended operand would be
3671+ // NarrowTy safely. Necessary since ExtA 's extended operand would be
36723672 // e.g. an i8, while the const will likely be an i32. This will be
36733673 // elided by later optimisations.
36743674 auto *Trunc =
3675- new VPWidenCastRecipe (Instruction::CastOps::Trunc, B, NarrowTy);
3676- Trunc->insertBefore (*RecipeA->getParent (),
3677- std::next (RecipeA->getIterator ()));
3678-
3679- Type *WideTy = Ctx.Types .inferScalarType (RecipeA);
3680- RecipeB = new VPWidenCastRecipe (ExtOpc, Trunc, WideTy);
3681- RecipeB->insertAfter (Trunc);
3682- Mul->setOperand (1 , RecipeB);
3675+ new VPWidenCastRecipe (Instruction::CastOps::Trunc, ValB, NarrowTy);
3676+ Trunc->insertBefore (*ExtA->getParent (), std::next (ExtA->getIterator ()));
3677+
3678+ VPWidenCastRecipe *NewCast =
3679+ new VPWidenCastRecipe (ExtOpc, Trunc, WideTy);
3680+ NewCast->insertAfter (Trunc);
3681+ ExtB = NewCast;
3682+ ValB = NewCast;
3683+ Mul->setOperand (1 , NewCast);
36833684 }
36843685 }
3686+ };
3687+
3688+ // Try to match reduce.add(mul(...)).
3689+ if (match (VecOp, m_Mul (m_VPValue (A), m_VPValue (B)))) {
3690+ auto *RecipeA =
3691+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3692+ auto *RecipeB =
3693+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3694+ auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe ());
3695+
3696+ // Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
3697+ ExtendAndReplaceConstantOp (RecipeA, RecipeB, B, Mul);
36853698
36863699 // Match reduce.add/sub(mul(ext, ext)).
36873700 if (RecipeA && RecipeB && match (RecipeA, m_ZExtOrSExt (m_VPValue ())) &&
@@ -3692,7 +3705,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36923705 cast<VPWidenRecipe>(Sub), Red);
36933706 return new VPExpressionRecipe (RecipeA, RecipeB, Mul, Red);
36943707 }
3695- // Match reduce.add(mul).
36963708 // TODO: Add an expression type for this variant with a negated mul
36973709 if (!Sub && IsMulAccValidAndClampRange (Mul, nullptr , nullptr , nullptr ))
36983710 return new VPExpressionRecipe (Mul, Red);
@@ -3701,18 +3713,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
37013713 // variants.
37023714 if (Sub)
37033715 return nullptr ;
3704- // Match reduce.add(ext(mul(ext(A), ext(B)))).
3705- // All extend recipes must have same opcode or A == B
3706- // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
3707- if (match (VecOp, m_ZExtOrSExt (m_Mul (m_ZExtOrSExt (m_VPValue ()),
3708- m_ZExtOrSExt (m_VPValue ()))))) {
3716+
3717+ // Match reduce.add(ext(mul(A, B))).
3718+ if (match (VecOp, m_ZExtOrSExt (m_Mul (m_VPValue (A), m_VPValue (B))))) {
37093719 auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe ());
37103720 auto *Mul = cast<VPWidenRecipe>(Ext->getOperand (0 )->getDefiningRecipe ());
3711- auto *Ext0 =
3712- cast<VPWidenCastRecipe>(Mul->getOperand (0 )->getDefiningRecipe ());
3713- auto *Ext1 =
3714- cast<VPWidenCastRecipe>(Mul->getOperand (1 )->getDefiningRecipe ());
3715- if ((Ext->getOpcode () == Ext0->getOpcode () || Ext0 == Ext1) &&
3721+ auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3722+ auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3723+
3724+ // Convert reduce.add(ext(mul(ext, const))) to reduce.add(ext(mul(ext,
3725+ // ext(const))))
3726+ ExtendAndReplaceConstantOp (Ext0, Ext1, B, Mul);
3727+
3728+ // Match reduce.add(ext(mul(ext(A), ext(B))))
3729+ // All extend recipes must have same opcode or A == B
3730+ // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
3731+ if (Ext0 && Ext1 &&
3732+ (Ext->getOpcode () == Ext0->getOpcode () || Ext0 == Ext1) &&
37163733 Ext0->getOpcode () == Ext1->getOpcode () &&
37173734 IsMulAccValidAndClampRange (Mul, Ext0, Ext1, Ext) && Mul->hasOneUse ()) {
37183735 auto *NewExt0 = new VPWidenCastRecipe (
0 commit comments