Skip to content

Commit 24cec88

Browse files
committed
Enable fixed-length SVE lowering of PARTIAL_REDUCE_FMLA, and fix the expansion of partial.reduce.fadd so that it generates FMUL instead of MUL instructions
1 parent 41d845f commit 24cec88

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12083,7 +12083,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1208312083
C->isExactlyValue(1.0)) &&
1208412084
!(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
1208512085
ConstantOne.isOne()))
12086-
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12086+
switch (N->getOpcode()) {
12087+
case ISD::PARTIAL_REDUCE_FMLA:
12088+
Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12089+
break;
12090+
default:
12091+
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12092+
break;
12093+
};
1208712094

1208812095
unsigned Stride = AccVT.getVectorMinNumElements();
1208912096
unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,6 +2290,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22902290
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
22912291
}
22922292

2293+
if (Subtarget->hasSVE2p1()) {
2294+
if (VT.getVectorElementType() == MVT::f32)
2295+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT,
2296+
MVT::getVectorVT(MVT::f16, NumElts * 2),
2297+
Custom);
2298+
}
2299+
22932300
// Lower fixed length vector operations to scalable equivalents.
22942301
setOperationAction(ISD::ABDS, VT, Default);
22952302
setOperationAction(ISD::ABDU, VT, Default);
@@ -7911,6 +7918,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
79117918
case ISD::PARTIAL_REDUCE_SMLA:
79127919
case ISD::PARTIAL_REDUCE_UMLA:
79137920
case ISD::PARTIAL_REDUCE_SUMLA:
7921+
case ISD::PARTIAL_REDUCE_FMLA:
79147922
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
79157923
}
79167924
}

llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,10 @@ entry:
2929
define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
3030
; CHECK-LABEL: fdot_wide_vl256:
3131
; CHECK: // %bb.0: // %entry
32-
; CHECK-NEXT: ptrue p0.s
33-
; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1]
34-
; CHECK-NEXT: ld1h { z1.s }, p0/z, [x2]
35-
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl]
36-
; CHECK-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl]
37-
; CHECK-NEXT: fcvt z0.s, p0/m, z0.h
38-
; CHECK-NEXT: fcvt z1.s, p0/m, z1.h
39-
; CHECK-NEXT: fcvt z2.s, p0/m, z2.h
40-
; CHECK-NEXT: fcvt z3.s, p0/m, z3.h
41-
; CHECK-NEXT: fmul z0.s, z0.s, z1.s
42-
; CHECK-NEXT: ldr z1, [x0]
43-
; CHECK-NEXT: fmul z2.s, z2.s, z3.s
44-
; CHECK-NEXT: fadd z0.s, z1.s, z0.s
45-
; CHECK-NEXT: fadd z0.s, z0.s, z2.s
32+
; CHECK-NEXT: ldr z0, [x0]
33+
; CHECK-NEXT: ldr z1, [x1]
34+
; CHECK-NEXT: ldr z2, [x2]
35+
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
4636
; CHECK-NEXT: str z0, [x0]
4737
; CHECK-NEXT: ret
4838
entry:

0 commit comments

Comments
 (0)