Skip to content

Commit 023337b

Browse files
committed
Add fixed-length SVE lowering for PARTIAL_REDUCE_FMLA, and fix the expansion of partial.reduce.fadd so it generates FMUL instead of MUL
1 parent 63bd521 commit 023337b

File tree

3 files changed

+20
-15
lines changed

3 files changed

+20
-15
lines changed

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12082,7 +12082,14 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1208212082
C->isExactlyValue(1.0)) &&
1208312083
!(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
1208412084
ConstantOne.isOne()))
12085-
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12085+
switch (N->getOpcode()) {
12086+
case ISD::PARTIAL_REDUCE_FMLA:
12087+
Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12088+
break;
12089+
default:
12090+
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12091+
break;
12092+
};
1208612093

1208712094
unsigned Stride = AccVT.getVectorMinNumElements();
1208812095
unsigned ScaleFactor = MulOpVT.getVectorMinNumElements() / Stride;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2294,6 +2294,13 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
22942294
MVT::getVectorVT(MVT::i8, NumElts * 8), Custom);
22952295
}
22962296

2297+
if (Subtarget->hasSVE2p1()) {
2298+
if (VT.getVectorElementType() == MVT::f32)
2299+
setPartialReduceMLAAction(ISD::PARTIAL_REDUCE_FMLA, VT,
2300+
MVT::getVectorVT(MVT::f16, NumElts * 2),
2301+
Custom);
2302+
}
2303+
22972304
// Lower fixed length vector operations to scalable equivalents.
22982305
setOperationAction(ISD::ABDS, VT, Default);
22992306
setOperationAction(ISD::ABDU, VT, Default);
@@ -7917,6 +7924,7 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
79177924
case ISD::PARTIAL_REDUCE_SMLA:
79187925
case ISD::PARTIAL_REDUCE_UMLA:
79197926
case ISD::PARTIAL_REDUCE_SUMLA:
7927+
case ISD::PARTIAL_REDUCE_FMLA:
79207928
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
79217929
}
79227930
}

llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,20 +29,10 @@ entry:
2929
define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
3030
; CHECK-LABEL: fdot_wide_vl256:
3131
; CHECK: // %bb.0: // %entry
32-
; CHECK-NEXT: ptrue p0.s
33-
; CHECK-NEXT: ld1h { z0.s }, p0/z, [x1]
34-
; CHECK-NEXT: ld1h { z1.s }, p0/z, [x2]
35-
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl]
36-
; CHECK-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl]
37-
; CHECK-NEXT: fcvt z0.s, p0/m, z0.h
38-
; CHECK-NEXT: fcvt z1.s, p0/m, z1.h
39-
; CHECK-NEXT: fcvt z2.s, p0/m, z2.h
40-
; CHECK-NEXT: fcvt z3.s, p0/m, z3.h
41-
; CHECK-NEXT: fmul z0.s, z0.s, z1.s
42-
; CHECK-NEXT: ldr z1, [x0]
43-
; CHECK-NEXT: fmul z2.s, z2.s, z3.s
44-
; CHECK-NEXT: fadd z0.s, z1.s, z0.s
45-
; CHECK-NEXT: fadd z0.s, z0.s, z2.s
32+
; CHECK-NEXT: ldr z0, [x0]
33+
; CHECK-NEXT: ldr z1, [x1]
34+
; CHECK-NEXT: ldr z2, [x2]
35+
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
4636
; CHECK-NEXT: str z0, [x0]
4737
; CHECK-NEXT: ret
4838
entry:

0 commit comments

Comments
 (0)