Skip to content

Commit 1e26f88

Browse files
committed
[LV] Add ExtNegatedMulAccReduction expression type
This PR adds the ExtNegatedMulAccReduction expression type to VPExpressionRecipe so that extend-multiply-accumulate reductions whose multiply is negated, i.e. reduce.add(sub(0, mul(ext, ext))), can be bundled into a single expression. Stacked PRs: 1. llvm#156976 2. -> This PR 3. llvm#147302
1 parent 2300506 commit 1e26f88

File tree

4 files changed

+192
-17
lines changed

4 files changed

+192
-17
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2989,6 +2989,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
29892989
/// vector operands, performing a reduction.add on the result, and adding
29902990
/// the scalar result to a chain.
29912991
MulAccReduction,
2992+
/// Represent an inloop multiply-accumulate reduction, multiplying the
2993+
/// extended vector operands, negating the product (sub 0, mul),
2994+
/// performing a reduction.add on the result, and adding the scalar
2995+
/// result to a chain. The bundled recipes are, in order,
2996+
/// {Ext0, Ext1, Mul, Sub, Red}.
2997+
ExtNegatedMulAccReduction,
29922998
};
29932999

29943000
/// Type of the expression.
@@ -3012,6 +3018,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
30123018
VPWidenRecipe *Mul, VPReductionRecipe *Red)
30133019
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
30143020
{Ext0, Ext1, Mul, Red}) {}
3021+
VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
3022+
VPWidenRecipe *Mul, VPWidenRecipe *Sub,
3023+
VPReductionRecipe *Red)
3024+
: VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
3025+
{Ext0, Ext1, Mul, Sub, Red}) {}
30153026

30163027
~VPExpressionRecipe() override {
30173028
SmallSet<VPSingleDefRecipe *, 4> ExpressionRecipesSeen;

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2861,12 +2861,17 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
28612861
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
28622862
Ctx.CostKind);
28632863

2864-
case ExpressionTypes::ExtMulAccReduction:
2864+
case ExpressionTypes::ExtNegatedMulAccReduction:
2865+
case ExpressionTypes::ExtMulAccReduction: {
2866+
if (ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction &&
2867+
Opcode == Instruction::Add)
2868+
Opcode = Instruction::Sub;
28652869
return Ctx.TTI.getMulAccReductionCost(
28662870
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
28672871
Instruction::ZExt,
28682872
Opcode, RedTy, SrcVecTy, Ctx.CostKind);
28692873
}
2874+
}
28702875
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
28712876
}
28722877

@@ -2912,6 +2917,30 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
29122917
O << ")";
29132918
break;
29142919
}
2920+
case ExpressionTypes::ExtNegatedMulAccReduction: {
2921+
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
2922+
O << " + reduce."
2923+
<< Instruction::getOpcodeName(
2924+
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
2925+
<< " (sub (0, mul";
2926+
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
2927+
Mul->printFlags(O);
2928+
O << "(";
2929+
getOperand(0)->printAsOperand(O, SlotTracker);
2930+
auto *Ext0 = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
2931+
O << " " << Instruction::getOpcodeName(Ext0->getOpcode()) << " to "
2932+
<< *Ext0->getResultType() << "), (";
2933+
getOperand(1)->printAsOperand(O, SlotTracker);
2934+
auto *Ext1 = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
2935+
O << " " << Instruction::getOpcodeName(Ext1->getOpcode()) << " to "
2936+
<< *Ext1->getResultType() << ")";
2937+
if (Red->isConditional()) {
2938+
O << ", ";
2939+
Red->getCondOp()->printAsOperand(O, SlotTracker);
2940+
}
2941+
O << "))";
2942+
break;
2943+
}
29152944
case ExpressionTypes::MulAccReduction:
29162945
case ExpressionTypes::ExtMulAccReduction: {
29172946
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3524,14 +3524,22 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35243524
};
35253525

35263526
VPValue *VecOp = Red->getVecOp();
3527+
VPValue *Mul = nullptr;
3528+
VPValue *Sub = nullptr;
35273529
VPValue *A, *B;
3530+
// A negated multiply-accumulate has a sub(0, mul) between the add reduction
// and the mul; if present, remember the sub and look through it to the mul.
3531+
if (match(VecOp,
3532+
m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(Mul))))
3533+
Sub = VecOp;
3534+
else
3535+
Mul = VecOp;
35283536
// Try to match reduce.add(mul(...)).
3529-
if (match(VecOp, m_Mul(m_VPValue(A), m_VPValue(B)))) {
3537+
if (match(Mul, m_Mul(m_VPValue(A), m_VPValue(B)))) {
35303538
auto *RecipeA =
35313539
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
35323540
auto *RecipeB =
35333541
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
3534-
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
3542+
auto *MulR = cast<VPWidenRecipe>(Mul->getDefiningRecipe());
35353543

35363544
// Match reduce.add(mul(ext, ext)).
35373545
if (RecipeA && RecipeB &&
@@ -3540,29 +3548,35 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35403548
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
35413549
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
35423550
Instruction::CastOps::ZExt,
3543-
Mul, RecipeA, RecipeB, nullptr)) {
3544-
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
3551+
MulR, RecipeA, RecipeB, nullptr)) {
3552+
if (Sub)
3553+
return new VPExpressionRecipe(
3554+
RecipeA, RecipeB, MulR,
3555+
cast<VPWidenRecipe>(Sub->getDefiningRecipe()), Red);
3556+
return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
35453557
}
35463558
// Match reduce.add(mul).
3547-
if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
3548-
return new VPExpressionRecipe(Mul, Red);
3559+
// TODO: Add an expression type for the negated-mul variant without extends.
3560+
if (!Sub &&
3561+
IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr))
3562+
return new VPExpressionRecipe(MulR, Red);
35493563
}
35503564
// Match reduce.add(ext(mul(ext(A), ext(B)))).
35513565
// All extend recipes must have same opcode or A == B
35523566
// which can be transformed to reduce.add(zext(mul(sext(A), sext(B)))).
3553-
if (match(VecOp, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
3554-
m_ZExtOrSExt(m_VPValue()))))) {
3567+
if (!Sub && match(Mul, m_ZExtOrSExt(m_Mul(m_ZExtOrSExt(m_VPValue()),
3568+
m_ZExtOrSExt(m_VPValue()))))) {
35553569
auto *Ext = cast<VPWidenCastRecipe>(VecOp->getDefiningRecipe());
3556-
auto *Mul = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
3570+
auto *MulR = cast<VPWidenRecipe>(Ext->getOperand(0)->getDefiningRecipe());
35573571
auto *Ext0 =
3558-
cast<VPWidenCastRecipe>(Mul->getOperand(0)->getDefiningRecipe());
3572+
cast<VPWidenCastRecipe>(MulR->getOperand(0)->getDefiningRecipe());
35593573
auto *Ext1 =
3560-
cast<VPWidenCastRecipe>(Mul->getOperand(1)->getDefiningRecipe());
3574+
cast<VPWidenCastRecipe>(MulR->getOperand(1)->getDefiningRecipe());
35613575
if ((Ext->getOpcode() == Ext0->getOpcode() || Ext0 == Ext1) &&
35623576
Ext0->getOpcode() == Ext1->getOpcode() &&
35633577
IsMulAccValidAndClampRange(Ext0->getOpcode() ==
35643578
Instruction::CastOps::ZExt,
3565-
Mul, Ext0, Ext1, Ext)) {
3579+
MulR, Ext0, Ext1, Ext)) {
35663580
auto *NewExt0 = new VPWidenCastRecipe(
35673581
Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
35683582
Ext0->getDebugLoc());
@@ -3575,10 +3589,10 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35753589
Ext1->getDebugLoc());
35763590
NewExt1->insertBefore(Ext1);
35773591
}
3578-
Mul->setOperand(0, NewExt0);
3579-
Mul->setOperand(1, NewExt1);
3580-
Red->setOperand(1, Mul);
3581-
return new VPExpressionRecipe(NewExt0, NewExt1, Mul, Red);
3592+
MulR->setOperand(0, NewExt0);
3593+
MulR->setOperand(1, NewExt1);
3594+
Red->setOperand(1, MulR);
3595+
return new VPExpressionRecipe(NewExt0, NewExt1, MulR, Red);
35823596
}
35833597
}
35843598
return nullptr;

llvm/test/Transforms/LoopVectorize/vplan-printing-reductions.ll

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,127 @@ exit:
580580
ret i32 %add
581581
}
582582

583+
define i32 @print_mulacc_negated(ptr %a, ptr %b) {
584+
; CHECK-LABEL: 'print_mulacc_negated'
585+
; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {
586+
; CHECK-NEXT: Live-in vp<%0> = VF
587+
; CHECK-NEXT: Live-in vp<%1> = VF * UF
588+
; CHECK-NEXT: Live-in vp<%2> = vector-trip-count
589+
; CHECK-NEXT: Live-in ir<1024> = original trip-count
590+
; CHECK-EMPTY:
591+
; CHECK-NEXT: ir-bb<entry>:
592+
; CHECK-NEXT: Successor(s): scalar.ph, vector.ph
593+
; CHECK-EMPTY:
594+
; CHECK-NEXT: vector.ph:
595+
; CHECK-NEXT: EMIT vp<%3> = reduction-start-vector ir<0>, ir<0>, ir<1>
596+
; CHECK-NEXT: Successor(s): vector loop
597+
; CHECK-EMPTY:
598+
; CHECK-NEXT: <x1> vector loop: {
599+
; CHECK-NEXT: vector.body:
600+
; CHECK-NEXT: EMIT vp<%4> = CANONICAL-INDUCTION ir<0>, vp<%index.next>
601+
; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi vp<%3>, vp<%8>
602+
; CHECK-NEXT: vp<%5> = SCALAR-STEPS vp<%4>, ir<1>, vp<%0>
603+
; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%5>
604+
; CHECK-NEXT: vp<%6> = vector-pointer ir<%gep.a>
605+
; CHECK-NEXT: WIDEN ir<%load.a> = load vp<%6>
606+
; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%5>
607+
; CHECK-NEXT: vp<%7> = vector-pointer ir<%gep.b>
608+
; CHECK-NEXT: WIDEN ir<%load.b> = load vp<%7>
609+
; CHECK-NEXT: EXPRESSION vp<%8> = ir<%accum> + reduce.add (sub (0, mul (ir<%load.b> zext to i32), (ir<%load.a> zext to i32)))
610+
; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%4>, vp<%1>
611+
; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, vp<%2>
612+
; CHECK-NEXT: No successors
613+
; CHECK-NEXT: }
614+
; CHECK-NEXT: Successor(s): middle.block
615+
; CHECK-EMPTY:
616+
; CHECK-NEXT: middle.block:
617+
; CHECK-NEXT: EMIT vp<%10> = compute-reduction-result ir<%accum>, vp<%8>
618+
; CHECK-NEXT: EMIT vp<%cmp.n> = icmp eq ir<1024>, vp<%2>
619+
; CHECK-NEXT: EMIT branch-on-cond vp<%cmp.n>
620+
; CHECK-NEXT: Successor(s): ir-bb<exit>, scalar.ph
621+
; CHECK-EMPTY:
622+
; CHECK-NEXT: ir-bb<exit>:
623+
; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<%10> from middle.block)
624+
; CHECK-NEXT: No successors
625+
; CHECK-EMPTY:
626+
; CHECK-NEXT: scalar.ph:
627+
; CHECK-NEXT: EMIT-SCALAR vp<%bc.resume.val> = phi [ vp<%2>, middle.block ], [ ir<0>, ir-bb<entry> ]
628+
; CHECK-NEXT: EMIT-SCALAR vp<%bc.merge.rdx> = phi [ vp<%10>, middle.block ], [ ir<0>, ir-bb<entry> ]
629+
; CHECK-NEXT: Successor(s): ir-bb<loop>
630+
; CHECK-EMPTY:
631+
; CHECK-NEXT: ir-bb<loop>:
632+
; CHECK-NEXT: IR %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ] (extra operand: vp<%bc.resume.val> from scalar.ph)
633+
; CHECK-NEXT: IR %accum = phi i32 [ 0, %entry ], [ %add, %loop ] (extra operand: vp<%bc.merge.rdx> from scalar.ph)
634+
; CHECK-NEXT: IR %gep.a = getelementptr i8, ptr %a, i64 %iv
635+
; CHECK-NEXT: IR %load.a = load i8, ptr %gep.a, align 1
636+
; CHECK-NEXT: IR %ext.a = zext i8 %load.a to i32
637+
; CHECK-NEXT: IR %gep.b = getelementptr i8, ptr %b, i64 %iv
638+
; CHECK-NEXT: IR %load.b = load i8, ptr %gep.b, align 1
639+
; CHECK-NEXT: IR %ext.b = zext i8 %load.b to i32
640+
; CHECK-NEXT: IR %mul = mul i32 %ext.b, %ext.a
641+
; CHECK-NEXT: IR %sub = sub i32 0, %mul
642+
; CHECK-NEXT: IR %add = add i32 %accum, %sub
643+
; CHECK-NEXT: IR %iv.next = add i64 %iv, 1
644+
; CHECK-NEXT: IR %exitcond.not = icmp eq i64 %iv.next, 1024
645+
; CHECK-NEXT: No successors
646+
; CHECK-NEXT: }
647+
; CHECK: VPlan 'Final VPlan for VF={4},UF={1}' {
648+
; CHECK-NEXT: Live-in ir<1024> = vector-trip-count
649+
; CHECK-NEXT: Live-in ir<1024> = original trip-count
650+
; CHECK-EMPTY:
651+
; CHECK-NEXT: ir-bb<entry>:
652+
; CHECK-NEXT: Successor(s): vector.ph
653+
; CHECK-EMPTY:
654+
; CHECK-NEXT: vector.ph:
655+
; CHECK-NEXT: Successor(s): vector.body
656+
; CHECK-EMPTY:
657+
; CHECK-NEXT: vector.body:
658+
; CHECK-NEXT: EMIT-SCALAR vp<%index> = phi [ ir<0>, vector.ph ], [ vp<%index.next>, vector.body ]
659+
; CHECK-NEXT: WIDEN-REDUCTION-PHI ir<%accum> = phi ir<0>, ir<%add>
660+
; CHECK-NEXT: CLONE ir<%gep.a> = getelementptr ir<%a>, vp<%index>
661+
; CHECK-NEXT: WIDEN ir<%load.a> = load ir<%gep.a>
662+
; CHECK-NEXT: CLONE ir<%gep.b> = getelementptr ir<%b>, vp<%index>
663+
; CHECK-NEXT: WIDEN ir<%load.b> = load ir<%gep.b>
664+
; CHECK-NEXT: WIDEN-CAST ir<%ext.b> = zext ir<%load.b> to i32
665+
; CHECK-NEXT: WIDEN-CAST ir<%ext.a> = zext ir<%load.a> to i32
666+
; CHECK-NEXT: WIDEN ir<%mul> = mul ir<%ext.b>, ir<%ext.a>
667+
; CHECK-NEXT: WIDEN ir<%sub> = sub ir<0>, ir<%mul>
668+
; CHECK-NEXT: REDUCE ir<%add> = ir<%accum> + reduce.add (ir<%sub>)
669+
; CHECK-NEXT: EMIT vp<%index.next> = add nuw vp<%index>, ir<4>
670+
; CHECK-NEXT: EMIT branch-on-count vp<%index.next>, ir<1024>
671+
; CHECK-NEXT: Successor(s): middle.block, vector.body
672+
; CHECK-EMPTY:
673+
; CHECK-NEXT: middle.block:
674+
; CHECK-NEXT: EMIT vp<[[RED_RESULT:%.+]]> = compute-reduction-result ir<%accum>, ir<%add>
675+
; CHECK-NEXT: Successor(s): ir-bb<exit>
676+
; CHECK-EMPTY:
677+
; CHECK-NEXT: ir-bb<exit>:
678+
; CHECK-NEXT: IR %add.lcssa = phi i32 [ %add, %loop ] (extra operand: vp<[[RED_RESULT]]> from middle.block)
679+
; CHECK-NEXT: No successors
680+
; CHECK-NEXT: }
681+
entry:
682+
br label %loop
683+
684+
loop:
685+
%iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
686+
%accum = phi i32 [ 0, %entry ], [ %add, %loop ]
687+
%gep.a = getelementptr i8, ptr %a, i64 %iv
688+
%load.a = load i8, ptr %gep.a, align 1
689+
%ext.a = zext i8 %load.a to i32
690+
%gep.b = getelementptr i8, ptr %b, i64 %iv
691+
%load.b = load i8, ptr %gep.b, align 1
692+
%ext.b = zext i8 %load.b to i32
693+
%mul = mul i32 %ext.b, %ext.a
694+
%sub = sub i32 0, %mul
695+
%add = add i32 %accum, %sub
696+
%iv.next = add i64 %iv, 1
697+
%exitcond.not = icmp eq i64 %iv.next, 1024
698+
br i1 %exitcond.not, label %exit, label %loop
699+
700+
exit:
701+
ret i32 %add
702+
}
703+
583704
define i64 @print_mulacc_sub_extended(ptr nocapture readonly %x, ptr nocapture readonly %y, i32 %n) {
584705
; CHECK-LABEL: 'print_mulacc_sub_extended'
585706
; CHECK: VPlan 'Initial VPlan for VF={4},UF>=1' {

0 commit comments

Comments
 (0)