From e73df41a901a949dd6bebd97efb2b30d6d8432f0 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 22 Sep 2025 17:16:40 +0100 Subject: [PATCH 1/8] [LV] Add ExtNegatedMulAccReduction expression type This PR adds the ExtNegatedMulAccReduction expression type for VPExpressionRecipe so that extend-multiply-accumulate reductions with a negated multiply can be bundled. Stacked PRs: 1. https://github.com/llvm/llvm-project/pull/156976 2. -> This 3. https://github.com/llvm/llvm-project/pull/147302 --- llvm/lib/Transforms/Vectorize/VPlan.h | 11 ++ .../lib/Transforms/Vectorize/VPlanRecipes.cpp | 31 ++++- .../Transforms/Vectorize/VPlanTransforms.cpp | 46 ++++--- .../vplan-printing-reductions.ll | 121 ++++++++++++++++++ 4 files changed, 192 insertions(+), 17 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 0822511150e9e..ed3d5d7da9352 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2997,6 +2997,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe { /// vector operands, performing a reduction.add on the result, and adding /// the scalar result to a chain. MulAccReduction, + /// Represent an inloop multiply-accumulate reduction, multiplying the + /// extended vector operands, negating the multiplication, performing a + /// reduction.add + /// on the result, and adding + /// the scalar result to a chain. + ExtNegatedMulAccReduction, }; /// Type of the expression. @@ -3020,6 +3026,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe { VPWidenRecipe *Mul, VPReductionRecipe *Red) : VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction, {Ext0, Ext1, Mul, Red}) {} + VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1, + VPWidenRecipe *Mul, VPWidenRecipe *Sub, + VPReductionRecipe *Red) + : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction, + {Ext0, Ext1, Mul, Sub, Red}) {} ~VPExpressionRecipe() override { for (auto *R : reverse(ExpressionRecipes)) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index b5e30cb1fa655..3d9f7ec1d4c7c 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2839,12 +2839,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy, Ctx.CostKind); - case ExpressionTypes::ExtMulAccReduction: + case ExpressionTypes::ExtNegatedMulAccReduction: + case ExpressionTypes::ExtMulAccReduction: { + if (ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction && + Opcode == Instruction::Add) + Opcode = Instruction::Sub; return Ctx.TTI.getMulAccReductionCost( cast(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, Opcode, RedTy, SrcVecTy, Ctx.CostKind); } + } llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum"); } @@ -2890,6 +2895,30 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent, O << ")"; break; } + case ExpressionTypes::ExtNegatedMulAccReduction: { + getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); + O << " + reduce." + << Instruction::getOpcodeName( + RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind())) + << " (sub (0, mul"; + auto *Mul = cast(ExpressionRecipes[2]); + Mul->printFlags(O); + O << "("; + getOperand(0)->printAsOperand(O, SlotTracker); + auto *Ext0 = cast(ExpressionRecipes[0]); + O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to " + << *Ext0->getResultType() << "), ("; + getOperand(1)->printAsOperand(O, SlotTracker); + auto *Ext1 = cast(ExpressionRecipes[1]); + O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to " + << *Ext1->getResultType() << ")"; + if (Red->isConditional()) { + O << ", "; + Red->getCondOp()->printAsOperand(O, SlotTracker); + } + O << "))"; + break; + } case ExpressionTypes::MulAccReduction: case ExpressionTypes::ExtMulAccReduction: { getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker); diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 5252e1f928294..cec05d03a21e6 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3543,14 +3543,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, }; VPValue *VecOp = Red->getVecOp(); + VPValue *Mul = nullptr; + VPValue *Sub = nullptr; VPValue *A, *B; + // Sub reductions could have a sub between the add reduction and vec op. + if (match(VecOp, + m_Binary(m_SpecificInt(0), m_VPValue(Mul)))) + Sub = VecOp; + else + Mul = VecOp; // Try to match reduce.add(mul(...)). - if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) { + if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) { auto *RecipeA = dyn_cast_if_present(A->getDefiningRecipe()); auto *RecipeB = dyn_cast_if_present(B->getDefiningRecipe()); - auto *Mul = cast(VecOp->getDefiningRecipe()); + auto *MulR = cast(Mul->getDefiningRecipe()); // Match reduce.add(mul(ext, ext)). if (RecipeA && RecipeB && @@ -3559,29 +3567,35 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, match(RecipeB, m_ZExtOrSExt(m_VPValue())) && IsMulAccValidAndClampRange(RecipeA->getOpcode() == Instruction::CastOps::ZExt, - Mul, RecipeA, RecipeB, nullptr)) { - return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red); + MulR, RecipeA, RecipeB, nullptr)) { + if (Sub) + return new VPExpressionRecipe( + RecipeA, RecipeB, MulR, + cast(Sub->getDefiningRecipe()), Red); + return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red); } // Match reduce.add(mul). - if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) - return new VPExpressionRecipe(Mul, Red); + // TODO: Add an expression type for this variant with a negated mul + if (!Sub && + IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr)) + return new VPExpressionRecipe(MulR, Red); } // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))). - if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), - m_ZExtOrSExt(m_VPValue()))))) { + if (!Sub && match(Mul, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), + m_ZExtOrSExt(m_VPValue()))))) { auto *Ext = cast(VecOp->getDefiningRecipe()); - auto *Mul = cast(Ext->getOperand(0)->getDefiningRecipe()); + auto *MulR = cast(Ext->getOperand(0)->getDefiningRecipe()); auto *Ext0 = - cast(Mul->getOperand(0)->getDefiningRecipe()); + cast(MulR->getOperand(0)->getDefiningRecipe()); auto *Ext1 = - cast(Mul->getOperand(1)->getDefiningRecipe()); + cast(MulR->getOperand(1)->getDefiningRecipe()); if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) && Ext0->getOpcode() == Ext1->getOpcode() && IsMulAccValidAndClampRange(Ext0->getOpcode() == Instruction::CastOps::ZExt, - Mul, Ext0, Ext1, Ext)) { + MulR, Ext0, Ext1, Ext)) { auto *NewExt0 = new VPWidenCastRecipe( Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0, *Ext0, Ext0->getDebugLoc()); @@ -3594,10 +3608,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, Ext1->getDebugLoc()); NewExt1->insertBefore(Ext1); } - Mul->setOperand(0, NewExt0); - Mul->setOperand(1, NewExt1); - Red->setOperand(1, Mul); - return new VPExpressionRecipe(NewExt0, NewExt1, Mul, Red); + MulR->setOperand(0, NewExt0); + MulR->setOperand(1, NewExt1); + Red->setOperand(1, MulR); + return new VPExpressionRecipe(NewExt0, NewExt1, MulR, Red); } } return nullptr; diff --git a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll index 4e6ef0de6a9ed..5a0c69bf5db1b 100644 --- a/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll +++ b/llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll @@ -580,6 +580,127 @@ exit: ret i32 %add } +define i32 @print_mulacc_negated(ptr %a, ptr %b) { +; CHECK-LABEL: 'print_mulacc_negated' +; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' { +; CHECK-NEXT: Live-in vp<%0> = VF +; CHECK-NEXT: Live-in vp<%1> = VF * UF +; CHECK-NEXT: Live-in vp<%2> = vector-trip-count +; CHECK-NEXT: Live-in ir<1024> = original trip-count +; CHECK-EMPTY: +; CHECK-NEXT: ir-bb: +; CHECK-NEXT: Successor(s): scalar.ph, vector.ph +; CHECK-EMPTY: +; CHECK-NEXT: vector.ph: +; CHECK-NEXT: EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1> +; CHECK-NEXT: Successor(s): vector loop +; CHECK-EMPTY: +; CHECK-NEXT: vector loop: { +; CHECK-NEXT: vector.body: +; CHECK-NEXT: EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next> +; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8> +; CHECK-NEXT: vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0> +; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5> +; CHECK-NEXT: vp<%6> = vector-pointer ir<%gep.a> +; CHECK-NEXT: WIDEN ir<%load.a> = load vp<%6> +; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5> +; CHECK-NEXT: vp<%7> = vector-pointer ir<%gep.b> +; CHECK-NEXT: WIDEN ir<%load.b> = load vp<%7> +; CHECK-NEXT: EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32))) +; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%4>, vp<%1> +; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%2> +; CHECK-NEXT: No successors +; CHECK-NEXT: } +; CHECK-NEXT: Successor(s): middle.block +; CHECK-EMPTY: +; CHECK-NEXT: middle.block: +; CHECK-NEXT: EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8> +; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2> +; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n> +; CHECK-NEXT: Successor(s): ir-bb, scalar.ph +; CHECK-EMPTY: +; CHECK-NEXT: ir-bb: +; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%10> from middle.block) +; CHECK-NEXT: No successors +; CHECK-EMPTY: +; CHECK-NEXT: scalar.ph: +; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb ] +; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb ] +; CHECK-NEXT: Successor(s): ir-bb +; CHECK-EMPTY: +; CHECK-NEXT: ir-bb: +; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] (extra operand: vp<%bc.resume.val> from scalar.ph) +; CHECK-NEXT: IR %accum = phi i32 [ 0, %entry ], [ %add, %loop ] (extra operand: vp<%bc.merge.rdx> from scalar.ph) +; CHECK-NEXT: IR %gep.a = getelementptr i8, ptr %a, i64 %iv +; CHECK-NEXT: IR %load.a = load i8, ptr %gep.a, align 1 +; CHECK-NEXT: IR %ext.a = zext i8 %load.a to i32 +; CHECK-NEXT: IR %gep.b = getelementptr i8, ptr %b, i64 %iv +; CHECK-NEXT: IR %load.b = load i8, ptr %gep.b, align 1 +; CHECK-NEXT: IR %ext.b = zext i8 %load.b to i32 +; CHECK-NEXT: IR %mul = mul i32 %ext.b, %ext.a +; CHECK-NEXT: IR %sub = sub i32 0, %mul +; CHECK-NEXT: IR %add = add i32 %accum, %sub +; CHECK-NEXT: IR %iv.next = add i64 %iv, 1 +; CHECK-NEXT: IR %exitcond.not = icmp eq i64 %iv.next, 1024 +; CHECK-NEXT: No successors +; CHECK-NEXT: } +; CHECK: VPlan 'Final VPlan for VF={4},UF={1}' { +; CHECK-NEXT: Live-in ir<1024> = vector-trip-count +; CHECK-NEXT: Live-in ir<1024> = original trip-count +; CHECK-EMPTY: +; CHECK-NEXT: ir-bb: +; CHECK-NEXT: Successor(s): vector.ph +; CHECK-EMPTY: +; CHECK-NEXT: vector.ph: +; CHECK-NEXT: Successor(s): vector.body +; CHECK-EMPTY: +; CHECK-NEXT: vector.body: +; CHECK-NEXT: EMIT-SCALAR vp<%index> = phi [ ir<0>, vector.ph ], [ vp<%index.next>, vector.body ] +; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add> +; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%index> +; CHECK-NEXT: WIDEN ir<%load.a> = load ir<%gep.a> +; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%index> +; CHECK-NEXT: WIDEN ir<%load.b> = load ir<%gep.b> +; CHECK-NEXT: WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32 +; CHECK-NEXT: WIDEN-CAST ir<%ext.a> = zext ir<%load.a> to i32 +; CHECK-NEXT: WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a> +; CHECK-NEXT: WIDEN ir<%sub> = sub ir<0>, ir<%mul> +; CHECK-NEXT: REDUCE ir<%add> = ir<%accum> + reduce.add (ir<%sub>) +; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%index>, ir<4> +; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, ir<1024> +; CHECK-NEXT: Successor(s): middle.block, vector.body +; CHECK-EMPTY: +; CHECK-NEXT: middle.block: +; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, ir<%add> +; CHECK-NEXT: Successor(s): ir-bb +; CHECK-EMPTY: +; CHECK-NEXT: ir-bb: +; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<[[RED_RESULT]]> from middle.block) +; CHECK-NEXT: No successors +; CHECK-NEXT: } +entry: + br label %loop + +loop: + %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] + %accum = phi i32 [ 0, %entry ], [ %add, %loop ] + %gep.a = getelementptr i8, ptr %a, i64 %iv + %load.a = load i8, ptr %gep.a, align 1 + %ext.a = zext i8 %load.a to i32 + %gep.b = getelementptr i8, ptr %b, i64 %iv + %load.b = load i8, ptr %gep.b, align 1 + %ext.b = zext i8 %load.b to i32 + %mul = mul i32 %ext.b, %ext.a + %sub = sub i32 0, %mul + %add = add i32 %accum, %sub + %iv.next = add i64 %iv, 1 + %exitcond.not = icmp eq i64 %iv.next, 1024 + br i1 %exitcond.not, label %exit, label %loop + +exit: + ret i32 %add +} + define i64 @print_mulacc_sub_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) { ; CHECK-LABEL: 'print_mulacc_sub_extended' ; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' { From afdf8549d2f98f6f7611c17f2e4dca606bc66e23 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Tue, 23 Sep 2025 22:37:43 +0100 Subject: [PATCH 2/8] Fix comment formatting --- llvm/lib/Transforms/Vectorize/VPlan.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index ed3d5d7da9352..2c725368e2265 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -2999,9 +2999,7 @@ class VPExpressionRecipe : public VPSingleDefRecipe { MulAccReduction, /// Represent an inloop multiply-accumulate reduction, multiplying the /// extended vector operands, negating the multiplication, performing a - /// reduction.add - /// on the result, and adding - /// the scalar result to a chain. + /// reduction.add on the result, and adding the scalar result to a chain. ExtNegatedMulAccReduction, }; From 43814c04fc3c59a7b3d55bfc951bddfeaf8f8566 Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Tue, 23 Sep 2025 22:50:12 +0100 Subject: [PATCH 3/8] Remove renaming --- .../Transforms/Vectorize/VPlanTransforms.cpp | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index cec05d03a21e6..214a317b353cd 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3543,22 +3543,20 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, }; VPValue *VecOp = Red->getVecOp(); - VPValue *Mul = nullptr; VPValue *Sub = nullptr; VPValue *A, *B; // Sub reductions could have a sub between the add reduction and vec op. - if (match(VecOp, - m_Binary(m_SpecificInt(0), m_VPValue(Mul)))) + if (match(VecOp, m_Binary(m_SpecificInt(0), m_VPValue()))) { Sub = VecOp; - else - Mul = VecOp; + VecOp = cast(VecOp->getDefiningRecipe())->getOperand(1); + } // Try to match reduce.add(mul(...)). - if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) { + if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) { auto *RecipeA = dyn_cast_if_present(A->getDefiningRecipe()); auto *RecipeB = dyn_cast_if_present(B->getDefiningRecipe()); - auto *MulR = cast(Mul->getDefiningRecipe()); + auto *Mul = cast(VecOp->getDefiningRecipe()); // Match reduce.add(mul(ext, ext)). if (RecipeA && RecipeB && @@ -3567,35 +3565,35 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, match(RecipeB, m_ZExtOrSExt(m_VPValue())) && IsMulAccValidAndClampRange(RecipeA->getOpcode() == Instruction::CastOps::ZExt, - MulR, RecipeA, RecipeB, nullptr)) { + Mul, RecipeA, RecipeB, nullptr)) { if (Sub) return new VPExpressionRecipe( - RecipeA, RecipeB, MulR, + RecipeA, RecipeB, Mul, cast(Sub->getDefiningRecipe()), Red); - return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red); + return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red); } // Match reduce.add(mul). // TODO: Add an expression type for this variant with a negated mul if (!Sub && - IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr)) - return new VPExpressionRecipe(MulR, Red); + IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) + return new VPExpressionRecipe(Mul, Red); } // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))). - if (!Sub && match(Mul, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), - m_ZExtOrSExt(m_VPValue()))))) { + if (!Sub && match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), + m_ZExtOrSExt(m_VPValue()))))) { auto *Ext = cast(VecOp->getDefiningRecipe()); - auto *MulR = cast(Ext->getOperand(0)->getDefiningRecipe()); + auto *Mul = cast(Ext->getOperand(0)->getDefiningRecipe()); auto *Ext0 = - cast(MulR->getOperand(0)->getDefiningRecipe()); + cast(Mul->getOperand(0)->getDefiningRecipe()); auto *Ext1 = - cast(MulR->getOperand(1)->getDefiningRecipe()); + cast(Mul->getOperand(1)->getDefiningRecipe()); if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) && Ext0->getOpcode() == Ext1->getOpcode() && IsMulAccValidAndClampRange(Ext0->getOpcode() == Instruction::CastOps::ZExt, - MulR, Ext0, Ext1, Ext)) { + Mul, Ext0, Ext1, Ext)) { auto *NewExt0 = new VPWidenCastRecipe( Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0, *Ext0, Ext0->getDebugLoc()); @@ -3608,10 +3606,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, Ext1->getDebugLoc()); NewExt1->insertBefore(Ext1); } - MulR->setOperand(0, NewExt0); - MulR->setOperand(1, NewExt1); - Red->setOperand(1, MulR); - return new VPExpressionRecipe(NewExt0, NewExt1, MulR, Red); + Mul->setOperand(0, NewExt0); + Mul->setOperand(1, NewExt1); + Red->setOperand(1, Mul); + return new VPExpressionRecipe(NewExt0, NewExt1, Mul, Red); } } return nullptr; From 1d342204a5d467986eef4b8a85e95ffbfe49f58f Mon Sep 17 00:00:00 2001 From: Samuel Tebbs Date: Tue, 23 Sep 2025 22:52:18 +0100 Subject: [PATCH 4/8] Add missing TODO --- llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 214a317b353cd..887e20f3d3afe 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3581,6 +3581,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))). + // TODO: Add an expression type for this variant with a negated mul if (!Sub && match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), m_ZExtOrSExt(m_VPValue()))))) { auto *Ext = cast(VecOp->getDefiningRecipe()); From 37604c94d0fa333a5e48258ee308759490bb995f Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Fri, 26 Sep 2025 14:41:40 +0100 Subject: [PATCH 5/8] Assert on mul, reduction and sub --- llvm/lib/Transforms/Vectorize/VPlan.h | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h index 2c725368e2265..4c7a083e0d9b7 100644 --- a/llvm/lib/Transforms/Vectorize/VPlan.h +++ b/llvm/lib/Transforms/Vectorize/VPlan.h @@ -3028,7 +3028,15 @@ class VPExpressionRecipe : public VPSingleDefRecipe { VPWidenRecipe *Mul, VPWidenRecipe *Sub, VPReductionRecipe *Red) : VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction, - {Ext0, Ext1, Mul, Sub, Red}) {} + {Ext0, Ext1, Mul, Sub, Red}) { + assert(Mul->getOpcode() == Instruction::Mul && "Expected a mul"); + assert(Red->getRecurrenceKind() == RecurKind::Add && + "Expected an add reduction"); + assert(getNumOperands() >= 3 && "Expected at least three operands"); + auto *SubConst = dyn_cast(getOperand(2)->getLiveInIRValue()); + assert(SubConst && SubConst->getValue() == 0 && + Sub->getOpcode() == Instruction::Sub && "Expected a negating sub"); + } ~VPExpressionRecipe() override { for (auto *R : reverse(ExpressionRecipes)) From 8448fd215d906328a1320ed2b4f32df8410afbe7 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Fri, 26 Sep 2025 16:37:00 +0100 Subject: [PATCH 6/8] Simplify computeCost --- llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 3d9f7ec1d4c7c..5a4ef0ff4532b 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2840,10 +2840,10 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, Ctx.CostKind); case ExpressionTypes::ExtNegatedMulAccReduction: - case ExpressionTypes::ExtMulAccReduction: { - if (ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction && - Opcode == Instruction::Add) + if (Opcode == Instruction::Add) Opcode = Instruction::Sub; + LLVM_FALLTHROUGH; + case ExpressionTypes::ExtMulAccReduction: { return Ctx.TTI.getMulAccReductionCost( cast(ExpressionRecipes.front())->getOpcode() == Instruction::ZExt, From 90a5888e42af195b18c0551729049ce7b105acdf Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Fri, 26 Sep 2025 17:00:22 +0100 Subject: [PATCH 7/8] Use an assert --- llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp index 5a4ef0ff4532b..ee03729f150b2 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp @@ -2840,8 +2840,8 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF, Ctx.CostKind); case ExpressionTypes::ExtNegatedMulAccReduction: - if (Opcode == Instruction::Add) - Opcode = Instruction::Sub; + assert(Opcode == Instruction::Add && "Unexpected opcode"); + Opcode = Instruction::Sub; LLVM_FALLTHROUGH; case ExpressionTypes::ExtMulAccReduction: { return Ctx.TTI.getMulAccReductionCost( From 1987159e177a1fcc8cef5884d662e113597ff5d5 Mon Sep 17 00:00:00 2001 From: Sam Tebbs Date: Mon, 29 Sep 2025 14:07:45 +0100 Subject: [PATCH 8/8] Address review --- .../Transforms/Vectorize/VPlanTransforms.cpp | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp index 887e20f3d3afe..969dce4bc98ae 100644 --- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp +++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp @@ -3543,12 +3543,14 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, }; VPValue *VecOp = Red->getVecOp(); - VPValue *Sub = nullptr; + VPRecipeBase *Sub = nullptr; VPValue *A, *B; + VPValue *Tmp = nullptr; // Sub reductions could have a sub between the add reduction and vec op. - if (match(VecOp, m_Binary(m_SpecificInt(0), m_VPValue()))) { - Sub = VecOp; - VecOp = cast(VecOp->getDefiningRecipe())->getOperand(1); + if (match(VecOp, + m_Binary(m_SpecificInt(0), m_VPValue(Tmp)))) { + Sub = VecOp->getDefiningRecipe(); + VecOp = Tmp; } // Try to match reduce.add(mul(...)). if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) { @@ -3567,9 +3569,8 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, Instruction::CastOps::ZExt, Mul, RecipeA, RecipeB, nullptr)) { if (Sub) - return new VPExpressionRecipe( - RecipeA, RecipeB, Mul, - cast(Sub->getDefiningRecipe()), Red); + return new VPExpressionRecipe(RecipeA, RecipeB, Mul, + cast(Sub), Red); return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red); } // Match reduce.add(mul). @@ -3578,12 +3579,15 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red, IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr)) return new VPExpressionRecipe(Mul, Red); } + // TODO: Add an expression type for negated versions of other expression + // variants. + if (Sub) + return nullptr; // Match reduce.add(ext(mul(ext(A), ext(B)))). // All extend recipes must have same opcode or A == B // which can be transform to reduce.add(zext(mul(sext(A), sext(B)))). - // TODO: Add an expression type for this variant with a negated mul - if (!Sub && match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), - m_ZExtOrSExt(m_VPValue()))))) { + if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()), + m_ZExtOrSExt(m_VPValue()))))) { auto *Ext = cast(VecOp->getDefiningRecipe()); auto *Mul = cast(Ext->getOperand(0)->getDefiningRecipe()); auto *Ext0 =