Skip to content

Commit 40b9c42

Browse files
committed
Simplify conditionals and further refine documentation
1 parent b3c5076 commit 40b9c42

File tree

7 files changed

+100
-50
lines changed

7 files changed

+100
-50
lines changed

llvm/docs/LangRef.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20430,7 +20430,9 @@ Semantics:
2043020430

2043120431
As the way in which the arguments to this floating-point intrinsic are reduced
2043220432
is unspecified, this intrinsic will assume floating-point reassociation and
20433-
contraction, which may result in variations to the results.
20433+
contraction can be leveraged to implement the reduction, which may result in
20434+
variations to the results due to reordering or lowering to different
20435+
instructions (including combining multiple instructions into a single one).
2043420436

2043520437
'``llvm.vector.insert``' Intrinsic
2043620438
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1938,6 +1938,10 @@ LLVM_ABI bool isNullOrNullSplat(SDValue V, bool AllowUndefs = false);
19381938
/// be zero.
19391939
LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
19401940

1941+
/// Return true if the value is a constant floating-point 1.0 or a splatted
1942+
/// vector of a constant floating-point 1.0 (with no undefs).
1943+
LLVM_ABI bool isOneOrOneSplatFP(SDValue V, bool AllowUndefs = false);
1944+
19411945
/// Return true if the value is a constant -1 integer or a splatted vector of a
19421946
/// constant -1 integer (with no undefs).
19431947
/// Does not permit build vector implicit truncation.

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13022,13 +13022,8 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1302213022
Opc = ISD::MUL;
1302313023
}
1302413024

13025-
APInt C;
13026-
ConstantFPSDNode *CFP;
13027-
if (!(Op1->getOpcode() == ISD::MUL &&
13028-
ISD::isConstantSplatVector(Op2.getNode(), C) && C.isOne()) &&
13029-
!(Op1->getOpcode() == ISD::FMUL &&
13030-
(CFP = llvm::isConstOrConstSplatFP(Op2, false)) &&
13031-
CFP->isExactlyValue(1.0)))
13025+
if (!(Opc == ISD::MUL && llvm::isOneOrOneSplat(Op2)) &&
13026+
!(Opc == ISD::FMUL && llvm::isOneOrOneSplatFP(Op2)))
1303213027
return SDValue();
1303313028

1303413029
auto IsIntOrFPExtOpcode = [](unsigned int Opcode) {
@@ -13044,6 +13039,7 @@ SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1304413039

1304513040
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1304613041
// -> partial_reduce_*mla(acc, x, C)
13042+
APInt C;
1304713043
if (ISD::isConstantSplatVector(RHS.getNode(), C)) {
1304813044
// TODO: Make use of partial_reduce_sumla here
1304913045
APInt CTrunc = C.trunc(LHSExtOpVT.getScalarSizeInBits());
@@ -13122,13 +13118,9 @@ SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
1312213118
SDValue Op1 = N->getOperand(1);
1312313119
SDValue Op2 = N->getOperand(2);
1312413120

13125-
APInt ConstantOne;
13126-
ConstantFPSDNode *C;
1312713121
if (!(N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA &&
13128-
(C = llvm::isConstOrConstSplatFP(Op2, false)) &&
13129-
C->isExactlyValue(1.0)) &&
13130-
!(ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) &&
13131-
ConstantOne.isOne()))
13122+
llvm::isOneOrOneSplatFP(Op2)) &&
13123+
!llvm::isOneOrOneSplat(Op2))
1313213124
return SDValue();
1313313125

1313413126
unsigned Op1Opcode = Op1.getOpcode();

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13055,6 +13055,11 @@ bool llvm::isOneOrOneSplat(SDValue N, bool AllowUndefs) {
1305513055
return C && C->isOne();
1305613056
}
1305713057

13058+
bool llvm::isOneOrOneSplatFP(SDValue N, bool AllowUndefs) {
13059+
ConstantFPSDNode *C = isConstOrConstSplatFP(N, AllowUndefs);
13060+
return C && C->isExactlyValue(1.0);
13061+
}
13062+
1305813063
bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
1305913064
N = peekThroughBitcasts(N);
1306013065
unsigned BitWidth = N.getScalarValueSizeInBits();

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8082,16 +8082,12 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
80828082
return;
80838083
}
80848084
case Intrinsic::vector_partial_reduce_fadd: {
8085-
if (!TLI.shouldExpandPartialReductionIntrinsic(cast<IntrinsicInst>(&I))) {
8086-
visitTargetIntrinsic(I, Intrinsic);
8087-
return;
8088-
}
80898085
SDValue Acc = getValue(I.getOperand(0));
80908086
SDValue Input = getValue(I.getOperand(1));
80918087
setValue(&I,
80928088
DAG.getNode(ISD::PARTIAL_REDUCE_FMLA, sdl, Acc.getValueType(), Acc,
80938089
Input,
8094-
DAG.getConstantFP(1.0f, sdl, Input.getValueType())));
8090+
DAG.getConstantFP(1.0, sdl, Input.getValueType())));
80958091
return;
80968092
}
80978093
case Intrinsic::experimental_cttz_elts: {

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12061,27 +12061,30 @@ SDValue TargetLowering::expandPartialReduceMLA(SDNode *N,
1206112061
EVT::getVectorVT(*DAG.getContext(), AccVT.getVectorElementType(),
1206212062
MulOpVT.getVectorElementCount());
1206312063

12064-
unsigned ExtOpcLHS =
12065-
N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
12066-
: N->getOpcode() == ISD::PARTIAL_REDUCE_UMLA ? ISD::ZERO_EXTEND
12067-
: ISD::SIGN_EXTEND;
12068-
unsigned ExtOpcRHS =
12069-
N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA ? ISD::FP_EXTEND
12070-
: N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA ? ISD::SIGN_EXTEND
12071-
: ISD::ZERO_EXTEND;
12064+
unsigned ExtOpcLHS, ExtOpcRHS;
12065+
switch (N->getOpcode()) {
12066+
default:
12067+
llvm_unreachable("Unexpected opcode");
12068+
case ISD::PARTIAL_REDUCE_UMLA:
12069+
ExtOpcLHS = ExtOpcRHS = ISD::ZERO_EXTEND;
12070+
break;
12071+
case ISD::PARTIAL_REDUCE_SMLA:
12072+
ExtOpcLHS = ExtOpcRHS = ISD::SIGN_EXTEND;
12073+
break;
12074+
case ISD::PARTIAL_REDUCE_FMLA:
12075+
ExtOpcLHS = ExtOpcRHS = ISD::FP_EXTEND;
12076+
break;
12077+
}
1207212078

1207312079
if (ExtMulOpVT != MulOpVT) {
1207412080
MulLHS = DAG.getNode(ExtOpcLHS, DL, ExtMulOpVT, MulLHS);
1207512081
MulRHS = DAG.getNode(ExtOpcRHS, DL, ExtMulOpVT, MulRHS);
1207612082
}
1207712083
SDValue Input = MulLHS;
12078-
APInt ConstantOne;
1207912084
if (N->getOpcode() == ISD::PARTIAL_REDUCE_FMLA) {
12080-
ConstantFPSDNode *C = llvm::isConstOrConstSplatFP(MulRHS, false);
12081-
if (!(C && C->isExactlyValue(1.0)))
12085+
if (!llvm::isOneOrOneSplatFP(MulRHS))
1208212086
Input = DAG.getNode(ISD::FMUL, DL, ExtMulOpVT, MulLHS, MulRHS);
12083-
} else if (!(ISD::isConstantSplatVector(MulRHS.getNode(), ConstantOne) &&
12084-
ConstantOne.isOne())) {
12087+
} else if (!llvm::isOneOrOneSplat(MulRHS)) {
1208512088
Input = DAG.getNode(ISD::MUL, DL, ExtMulOpVT, MulLHS, MulRHS);
1208612089
}
1208712090

llvm/test/CodeGen/AArch64/sve2p1-fdot.ll

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,29 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
2-
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s
2+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2 < %s | FileCheck %s --check-prefixes=CHECK,SVE2
3+
; RUN: llc -mtriple=aarch64-linux-gnu -mattr=+sve2p1 < %s | FileCheck %s --check-prefixes=CHECK,SVE2P1
34

45
define <vscale x 4 x float> @fdot_wide_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a, <vscale x 8 x half> %b) {
5-
; CHECK-LABEL: fdot_wide_vl128:
6-
; CHECK: // %bb.0: // %entry
7-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
8-
; CHECK-NEXT: ret
6+
; SVE2-LABEL: fdot_wide_vl128:
7+
; SVE2: // %bb.0: // %entry
8+
; SVE2-NEXT: uunpklo z3.s, z1.h
9+
; SVE2-NEXT: uunpklo z4.s, z2.h
10+
; SVE2-NEXT: ptrue p0.s
11+
; SVE2-NEXT: uunpkhi z1.s, z1.h
12+
; SVE2-NEXT: uunpkhi z2.s, z2.h
13+
; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
14+
; SVE2-NEXT: fcvt z4.s, p0/m, z4.h
15+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
16+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
17+
; SVE2-NEXT: fmul z3.s, z3.s, z4.s
18+
; SVE2-NEXT: fmul z1.s, z1.s, z2.s
19+
; SVE2-NEXT: fadd z0.s, z0.s, z3.s
20+
; SVE2-NEXT: fadd z0.s, z0.s, z1.s
21+
; SVE2-NEXT: ret
22+
;
23+
; SVE2P1-LABEL: fdot_wide_vl128:
24+
; SVE2P1: // %bb.0: // %entry
25+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
26+
; SVE2P1-NEXT: ret
927
entry:
1028
%a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
1129
%b.wide = fpext <vscale x 8 x half> %b to <vscale x 8 x float>
@@ -15,26 +33,56 @@ entry:
1533
}
1634

1735
define <vscale x 4 x float> @fdot_splat_vl128(<vscale x 4 x float> %acc, <vscale x 8 x half> %a) {
18-
; CHECK-LABEL: fdot_splat_vl128:
19-
; CHECK: // %bb.0: // %entry
20-
; CHECK-NEXT: fmov z2.h, #1.00000000
21-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
22-
; CHECK-NEXT: ret
36+
; SVE2-LABEL: fdot_splat_vl128:
37+
; SVE2: // %bb.0: // %entry
38+
; SVE2-NEXT: uunpklo z2.s, z1.h
39+
; SVE2-NEXT: ptrue p0.s
40+
; SVE2-NEXT: uunpkhi z1.s, z1.h
41+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
42+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
43+
; SVE2-NEXT: fadd z0.s, z0.s, z2.s
44+
; SVE2-NEXT: fadd z0.s, z0.s, z1.s
45+
; SVE2-NEXT: ret
46+
;
47+
; SVE2P1-LABEL: fdot_splat_vl128:
48+
; SVE2P1: // %bb.0: // %entry
49+
; SVE2P1-NEXT: fmov z2.h, #1.00000000
50+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
51+
; SVE2P1-NEXT: ret
2352
entry:
2453
%a.wide = fpext <vscale x 8 x half> %a to <vscale x 8 x float>
2554
%partial.reduce = call <vscale x 4 x float> @llvm.vector.partial.reduce.fadd(<vscale x 4 x float> %acc, <vscale x 8 x float> %a.wide)
2655
ret <vscale x 4 x float> %partial.reduce
2756
}
2857

2958
define void @fdot_wide_vl256(ptr %accptr, ptr %aptr, ptr %bptr) vscale_range(2,2) {
30-
; CHECK-LABEL: fdot_wide_vl256:
31-
; CHECK: // %bb.0: // %entry
32-
; CHECK-NEXT: ldr z0, [x0]
33-
; CHECK-NEXT: ldr z1, [x1]
34-
; CHECK-NEXT: ldr z2, [x2]
35-
; CHECK-NEXT: fdot z0.s, z1.h, z2.h
36-
; CHECK-NEXT: str z0, [x0]
37-
; CHECK-NEXT: ret
59+
; SVE2-LABEL: fdot_wide_vl256:
60+
; SVE2: // %bb.0: // %entry
61+
; SVE2-NEXT: ptrue p0.s
62+
; SVE2-NEXT: ld1h { z0.s }, p0/z, [x1]
63+
; SVE2-NEXT: ld1h { z1.s }, p0/z, [x2]
64+
; SVE2-NEXT: ld1h { z2.s }, p0/z, [x1, #1, mul vl]
65+
; SVE2-NEXT: ld1h { z3.s }, p0/z, [x2, #1, mul vl]
66+
; SVE2-NEXT: fcvt z0.s, p0/m, z0.h
67+
; SVE2-NEXT: fcvt z1.s, p0/m, z1.h
68+
; SVE2-NEXT: fcvt z2.s, p0/m, z2.h
69+
; SVE2-NEXT: fcvt z3.s, p0/m, z3.h
70+
; SVE2-NEXT: fmul z0.s, z0.s, z1.s
71+
; SVE2-NEXT: ldr z1, [x0]
72+
; SVE2-NEXT: fmul z2.s, z2.s, z3.s
73+
; SVE2-NEXT: fadd z0.s, z1.s, z0.s
74+
; SVE2-NEXT: fadd z0.s, z0.s, z2.s
75+
; SVE2-NEXT: str z0, [x0]
76+
; SVE2-NEXT: ret
77+
;
78+
; SVE2P1-LABEL: fdot_wide_vl256:
79+
; SVE2P1: // %bb.0: // %entry
80+
; SVE2P1-NEXT: ldr z0, [x0]
81+
; SVE2P1-NEXT: ldr z1, [x1]
82+
; SVE2P1-NEXT: ldr z2, [x2]
83+
; SVE2P1-NEXT: fdot z0.s, z1.h, z2.h
84+
; SVE2P1-NEXT: str z0, [x0]
85+
; SVE2P1-NEXT: ret
3886
entry:
3987
%acc = load <8 x float>, ptr %accptr
4088
%a = load <16 x half>, ptr %aptr

0 commit comments

Comments
 (0)