Skip to content

Commit 615cbaf

Browse files
committed
Simplify conditionals and further refine documentation
1 parent 0809761 commit 615cbaf

File tree

7 files changed

+100
-48
lines changed

7 files changed

+100
-48
lines changed

llvm/docs/LangRef.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20382,7 +20382,9 @@ Semantics:
2038220382

2038320383
As the way in which the arguments to this floating-point intrinsic are reduced
2038420384
is unspecified, this intrinsic will assume floating-point reassociation and
20385-
contraction, which may result in variations to the results.
20385+
contraction can be leveraged to implement the reduction, which may result in
20386+
variations to the results due to reordering or by lowering to different
20387+
instructions (including combining multiple instructions into a single one).
2038620388

2038720389
'``llvm.vector.insert``' Intrinsic
2038820390
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,10 @@ LLVM_ABI bool isNullOrNullSplat(SDValue V, bool AllowUndefs = false);
19381938
/// be zero.
19391939
LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
19401940

1941+
/// Return true if the value is a constant floating-point value, or a splatted
1942+
/// vector of a constant floating-point value, of 1.0 (with no undefs).
1943+
LLVM_ABI bool isOneOrOneSplatFP(SDValue V, bool AllowUndefs = false);
1944+
19411945
/// Return true if the value is a constant -1 integer or a splatted vector of a
19421946
/// constant -1 integer (with no undefs).
19431947
/// Does not permit build vector implicit truncation.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13003,13 +13003,10 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1300313003
SDValue Op1 = N->getOperand(1);
1300413004
SDValue Op2 = N->getOperand(2);
1300513005

13006-
APInt C;
13007-
ConstantFPSDNode *CFP;
1300813006
if (!(Op1->getOpcode() == ISD::MUL &&
13009-
ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
13007+
llvm::isOneOrOneSplat(Op2)) &&
1301013008
!(Op1->getOpcode() == ISD::FMUL &&
13011-
(CFP = llvm::isConstOrConstSplatFP(Op2, false)) &&
13012-
CFP->isExactlyValue(1.0)))
13009+
llvm::isOneOrOneSplatFP(Op2)))
1301313010
return SDValue();
1301413011

1301513012
auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
@@ -13027,6 +13024,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1302713024

1302813025
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1302913026
// -> partial_reduce_*mla(acc, x, C)
13027+
APInt C;
1303013028
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
1303113029
// TODO: Make use of partial_reduce_sumla here
1303213030
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13105,13 +13103,9 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1310513103
SDValue Op1 = N->getOperand(1);
1310613104
SDValue Op2 = N->getOperand(2);
1310713105

13108-
APInt ConstantOne;
13109-
ConstantFPSDNode *C;
1311013106
if (!(N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA &&
13111-
(C = llvm::isConstOrConstSplatFP(Op2, false)) &&
13112-
C->isExactlyValue(1.0)) &&
13113-
!(ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) &&
13114-
ConstantOne.isOne()))
13107+
llvm::isOneOrOneSplatFP(Op2)) &&
13108+
!llvm::isOneOrOneSplat(Op2))
1311513109
return SDValue();
1311613110

1311713111
unsigned Op1Opcode = Op1.getOpcode();

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12993,6 +12993,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
1299312993
return C && C->isOne();
1299412994
}
1299512995

12996+
bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
12997+
ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
12998+
return C && C->isExactlyValue(1.0);
12999+
}
13000+
1299613001
bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
1299713002
N = peekThroughBitcasts(N);
1299813003
unsigned BitWidth = N.getScalarValueSizeInBits();

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8115,16 +8115,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
81158115
return;
81168116
}
81178117
case Intrinsic::vector_partial_reduce_fadd: {
8118-
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
8119-
visitTargetIntrinsic(I, Intrinsic);
8120-
return;
8121-
}
81228118
SDValue Acc = getValue(I.getOperand(0));
81238119
SDValue Input = getValue(I.getOperand(1));
81248120
setValue(&I,
81258121
DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
81268122
Input,
8127-
DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
8123+
DAG.getConstantFP(1.0, sdl, Input.getValueType())));
81288124
return;
81298125
}
81308126
case Intrinsic::experimental_cttz_elts: {

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12062,27 +12062,30 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1206212062
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
1206312063
MulOpVT.getVectorElementCount());
1206412064

12065-
unsigned ExtOpcLHS =
12066-
N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
12067-
: N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
12068-
: ISD::SIGN_EXTEND;
12069-
unsigned ExtOpcRHS =
12070-
N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
12071-
: N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
12072-
: ISD::ZERO_EXTEND;
12065+
unsigned ExtOpcLHS, ExtOpcRHS;
12066+
switch (N->getOpcode()) {
12067+
default:
12068+
llvm_unreachable("Unexpected opcode");
12069+
case ISD::PARTIAL_REDUCE_UMLA:
12070+
ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND;
12071+
break;
12072+
case ISD::PARTIAL_REDUCE_SMLA:
12073+
ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND;
12074+
break;
12075+
case ISD::PARTIAL_REDUCE_FMLA:
12076+
ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND;
12077+
break;
12078+
}
1207312079

1207412080
if (ExtMulOpVT != MulOpVT) {
1207512081
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
1207612082
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
1207712083
}
1207812084
SDValue Input = MulLHS;
12079-
APInt ConstantOne;
1208012085
if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
12081-
ConstantFPSDNode *C = llvm::isConstOrConstSplatFP(MulRHS, false);
12082-
if (!(C && C->isExactlyValue(1.0)))
12086+
if (!llvm::isOneOrOneSplatFP(MulRHS))
1208312087
Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12084-
} else if (!(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
12085-
ConstantOne.isOne())) {
12088+
} else if (!llvm::isOneOrOneSplat(MulRHS)) {
1208612089
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
1208712090
}
1208812091

llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2-
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
3+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
34

45
define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
5-
; CHECK-LABEL: fdot_wide_vl128:
6-
; CHECK: // %bb.0: // %entry
7-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
8-
; CHECK-NEXT: ret
6+
; SVE2-LABEL: fdot_wide_vl128:
7+
; SVE2: // %bb.0: // %entry
8+
; SVE2-NEXT: uunpklo z3.s, z1.h
9+
; SVE2-NEXT: uunpklo z4.s, z2.h
10+
; SVE2-NEXT: ptrue p0.s
11+
; SVE2-NEXT: uunpkhi z1.s, z1.h
12+
; SVE2-NEXT: uunpkhi z2.s, z2.h
13+
; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
14+
; SVE2-NEXT: fcvt z4.s, p0/m, z4.h
15+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
16+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
17+
; SVE2-NEXT: fmul z3.s, z3.s, z4.s
18+
; SVE2-NEXT: fmul z1.s, z1.s, z2.s
19+
; SVE2-NEXT: fadd z0.s, z0.s, z3.s
20+
; SVE2-NEXT: fadd z0.s, z0.s, z1.s
21+
; SVE2-NEXT: ret
22+
;
23+
; SVE2P1-LABEL: fdot_wide_vl128:
24+
; SVE2P1: // %bb.0: // %entry
25+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
26+
; SVE2P1-NEXT: ret
927
entry:
1028
%a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
1129
%b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
@@ -15,26 +33,56 @@ entry:
1533
}
1634

1735
define <vscale x 4 x float> @fdot_splat_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a) {
18-
; CHECK-LABEL: fdot_splat_vl128:
19-
; CHECK: // %bb.0: // %entry
20-
; CHECK-NEXT: fmov z2.h, #1.00000000
21-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
22-
; CHECK-NEXT: ret
36+
; SVE2-LABEL: fdot_splat_vl128:
37+
; SVE2: // %bb.0: // %entry
38+
; SVE2-NEXT: uunpklo z2.s, z1.h
39+
; SVE2-NEXT: ptrue p0.s
40+
; SVE2-NEXT: uunpkhi z1.s, z1.h
41+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
42+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
43+
; SVE2-NEXT: fadd z0.s, z0.s, z2.s
44+
; SVE2-NEXT: fadd z0.s, z0.s, z1.s
45+
; SVE2-NEXT: ret
46+
;
47+
; SVE2P1-LABEL: fdot_splat_vl128:
48+
; SVE2P1: // %bb.0: // %entry
49+
; SVE2P1-NEXT: fmov z2.h, #1.00000000
50+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
51+
; SVE2P1-NEXT: ret
2352
entry:
2453
%a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
2554
%partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a.wide)
2655
ret <vscale x 4 x float> %partial.reduce
2756
}
2857

2958
define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
30-
; CHECK-LABEL: fdot_wide_vl256:
31-
; CHECK: // %bb.0: // %entry
32-
; CHECK-NEXT: ldr z0, [x0]
33-
; CHECK-NEXT: ldr z1, [x1]
34-
; CHECK-NEXT: ldr z2, [x2]
35-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
36-
; CHECK-NEXT: str z0, [x0]
37-
; CHECK-NEXT: ret
59+
; SVE2-LABEL: fdot_wide_vl256:
60+
; SVE2: // %bb.0: // %entry
61+
; SVE2-NEXT: ptrue p0.s
62+
; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
63+
; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
64+
; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl]
65+
; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl]
66+
; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
67+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
68+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
69+
; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
70+
; SVE2-NEXT: fmul z0.s, z0.s, z1.s
71+
; SVE2-NEXT: ldr z1, [x0]
72+
; SVE2-NEXT: fmul z2.s, z2.s, z3.s
73+
; SVE2-NEXT: fadd z0.s, z1.s, z0.s
74+
; SVE2-NEXT: fadd z0.s, z0.s, z2.s
75+
; SVE2-NEXT: str z0, [x0]
76+
; SVE2-NEXT: ret
77+
;
78+
; SVE2P1-LABEL: fdot_wide_vl256:
79+
; SVE2P1: // %bb.0: // %entry
80+
; SVE2P1-NEXT: ldr z0, [x0]
81+
; SVE2P1-NEXT: ldr z1, [x1]
82+
; SVE2P1-NEXT: ldr z2, [x2]
83+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
84+
; SVE2P1-NEXT: str z0, [x0]
85+
; SVE2P1-NEXT: ret
3886
entry:
3987
%acc = load <8 x float>, ptr %accptr
4088
%a = load <16 x half>, ptr %aptr

0 commit comments

Comments
 (0)