@@ -3596,39 +3596,52 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35963596 Sub = VecOp->getDefiningRecipe ();
35973597 VecOp = Tmp;
35983598 }
3599- // Try to match reduce.add(mul(...)).
3600- if (match (VecOp, m_Mul (m_VPValue (A), m_VPValue (B)))) {
3601- auto *RecipeA =
3602- dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3603- auto *RecipeB =
3604- dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3605- auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe ());
36063599
3607- // Match reduce.add(mul(ext, const)) and convert it to
3608- // reduce.add(mul(ext, ext(const)))
3609- if (RecipeA && !RecipeB && B->isLiveIn ()) {
3610- Type *NarrowTy = Ctx.Types .inferScalarType (RecipeA->getOperand (0 ));
3611- Instruction::CastOps ExtOpc = RecipeA->getOpcode ();
3612- auto *Const = dyn_cast<ConstantInt>(B->getLiveInIRValue ());
3600+ // If ValB is a constant and can be safely extended, truncate it to the same
3601+ // type as ExtA's operand, then extend it to the same type as ExtA. This
3602+ // creates two uniform extends that can more easily be matched by the rest of
3603+ // the bundling code. The ExtB reference, ValB and operand 1 of Mul are all
3604+ // replaced with the new extend of the constant.
3605+ auto ExtendAndReplaceConstantOp = [&Ctx](VPWidenCastRecipe *ExtA,
3606+ VPWidenCastRecipe *&ExtB,
3607+ VPValue *&ValB, VPWidenRecipe *Mul) {
3608+ if (ExtA && !ExtB && ValB->isLiveIn ()) {
3609+ Type *NarrowTy = Ctx.Types .inferScalarType (ExtA->getOperand (0 ));
3610+ Type *WideTy = Ctx.Types .inferScalarType (ExtA);
3611+ Instruction::CastOps ExtOpc = ExtA->getOpcode ();
3612+ auto *Const = dyn_cast<ConstantInt>(ValB->getLiveInIRValue ());
36133613 if (Const &&
36143614 llvm::canConstantBeExtended (
36153615 Const, NarrowTy, TTI::getPartialReductionExtendKind (ExtOpc))) {
36163616 // The truncate ensures that the type of each extended operand is the
36173617 // same, and it's been proven that the constant can be extended from
3618- // NarrowTy safely. Necessary since RecipeA 's extended operand would be
3618+ // NarrowTy safely. Necessary since ExtA 's extended operand would be
36193619 // e.g. an i8, while the const will likely be an i32. This will be
36203620 // elided by later optimisations.
36213621 auto *Trunc =
3622- new VPWidenCastRecipe (Instruction::CastOps::Trunc, B, NarrowTy);
3623- Trunc->insertBefore (*RecipeA->getParent (),
3624- std::next (RecipeA->getIterator ()));
3625-
3626- Type *WideTy = Ctx.Types .inferScalarType (RecipeA);
3627- RecipeB = new VPWidenCastRecipe (ExtOpc, Trunc, WideTy);
3628- RecipeB->insertAfter (Trunc);
3629- Mul->setOperand (1 , RecipeB);
3622+ new VPWidenCastRecipe (Instruction::CastOps::Trunc, ValB, NarrowTy);
3623+ Trunc->insertBefore (*ExtA->getParent (), std::next (ExtA->getIterator ()));
3624+
3625+ VPWidenCastRecipe *NewCast =
3626+ new VPWidenCastRecipe (ExtOpc, Trunc, WideTy);
3627+ NewCast->insertAfter (Trunc);
3628+ ExtB = NewCast;
3629+ ValB = NewCast;
3630+ Mul->setOperand (1 , NewCast);
36303631 }
36313632 }
3633+ };
3634+
3635+ // Try to match reduce.add(mul(...)).
3636+ if (match (VecOp, m_Mul (m_VPValue (A), m_VPValue (B)))) {
3637+ auto *RecipeA =
3638+ dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3639+ auto *RecipeB =
3640+ dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3641+ auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe ());
3642+
3643+ // Convert reduce.add(mul(ext, const)) to reduce.add(mul(ext, ext(const)))
3644+ ExtendAndReplaceConstantOp (RecipeA, RecipeB, B, Mul);
36323645
36333646 // Match reduce.add/sub(mul(ext, ext)).
36343647 if (RecipeA && RecipeB && match (RecipeA, m_ZExtOrSExt (m_VPValue ())) &&
@@ -3639,7 +3652,6 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36393652 cast<VPWidenRecipe>(Sub), Red);
36403653 return new VPExpressionRecipe (RecipeA, RecipeB, Mul, Red);
36413654 }
3642- // Match reduce.add(mul).
36433655 // TODO: Add an expression type for this variant with a negated mul
36443656 if (!Sub && IsMulAccValidAndClampRange (Mul, nullptr , nullptr , nullptr ))
36453657 return new VPExpressionRecipe (Mul, Red);
@@ -3648,18 +3660,23 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
36483660 // variants.
36493661 if (Sub)
36503662 return nullptr ;
3651- // Match reduce.add(ext(mul(ext(A), ext(B)))).
3652- // All extend recipes must have same opcode or A == B
3653- // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))).
3654- if (match (VecOp, m_ZExtOrSExt (m_Mul (m_ZExtOrSExt (m_VPValue ()),
3655- m_ZExtOrSExt (m_VPValue ()))))) {
3663+
3664+ // Match reduce.add(ext(mul(A, B))).
3665+ if (match (VecOp, m_ZExtOrSExt (m_Mul (m_VPValue (A), m_VPValue (B))))) {
36563666 auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe ());
36573667 auto *Mul = cast<VPWidenRecipe>(Ext->getOperand (0 )->getDefiningRecipe ());
3658- auto *Ext0 =
3659- cast<VPWidenCastRecipe>(Mul->getOperand (0 )->getDefiningRecipe ());
3660- auto *Ext1 =
3661- cast<VPWidenCastRecipe>(Mul->getOperand (1 )->getDefiningRecipe ());
3662- if ((Ext->getOpcode () == Ext0->getOpcode () || Ext0 == Ext1) &&
3668+ auto *Ext0 = dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe ());
3669+ auto *Ext1 = dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe ());
3670+
3671+ // Convert reduce.add(ext(mul(ext, const))) to reduce.add(ext(mul(ext,
3672+ // ext(const))))
3673+ ExtendAndReplaceConstantOp (Ext0, Ext1, B, Mul);
3674+
3675+ // Match reduce.add(ext(mul(ext(A), ext(B))))
3676+ // All extend recipes must have same opcode or A == B
3677+ // which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
3678+ if (Ext0 && Ext1 &&
3679+ (Ext->getOpcode () == Ext0->getOpcode () || Ext0 == Ext1) &&
36633680 Ext0->getOpcode () == Ext1->getOpcode () &&
36643681 IsMulAccValidAndClampRange (Mul, Ext0, Ext1, Ext)) {
36653682 auto *NewExt0 = new VPWidenCastRecipe (
0 commit comments