Skip to content

Commit 80a233b

Browse files
committed
[AArch64][GlobalISel] Perfect Shuffles
This is a port of the existing perfect shuffle generation code from SDAG, genericized to work for both SDAG and GISel. I wrote it a while ago and it has been sitting on my machine. It brings the codegen for certain shuffles inline and avoids the need for generating a tbl and constant pool load.
1 parent 31bde71 commit 80a233b

15 files changed

+465
-351
lines changed

llvm/include/llvm/CodeGen/GlobalISel/MachineIRBuilder.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1360,6 +1360,24 @@ class MachineIRBuilder {
13601360
const SrcOp &Elt,
13611361
const SrcOp &Idx);
13621362

1363+
/// Build and insert \p Res = G_INSERT_VECTOR_ELT \p Val, \p Elt, \p Idx
1364+
/// where \p Idx is an integer constant that is wrapped with buildConstant in a scalar as wide as the target's vector index type.
1365+
/// \pre setBasicBlock or setMI must have been called.
1366+
/// \pre \p Res must be a generic virtual register with vector type.
1367+
/// \pre \p Val must be a generic virtual register with vector type.
1368+
/// \pre \p Elt must be a generic virtual register with scalar type.
1369+
///
1370+
/// \return The newly created instruction.
1371+
MachineInstrBuilder buildInsertVectorElementConstant(const DstOp &Res,
1372+
const SrcOp &Val,
1373+
const SrcOp &Elt,
1374+
const int Idx) {
1375+
auto TLI = getMF().getSubtarget().getTargetLowering();
1376+
unsigned VecIdxWidth = TLI->getVectorIdxTy(getDataLayout()).getSizeInBits(); // bit width of the target's vector index type
1377+
return buildInsertVectorElement(
1378+
Res, Val, Elt, buildConstant(LLT::scalar(VecIdxWidth), Idx));
1379+
}
1380+
13631381
/// Build and insert \p Res = G_EXTRACT_VECTOR_ELT \p Val, \p Idx
13641382
///
13651383
/// \pre setBasicBlock or setMI must have been called.

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,13 @@ def shuf_to_ins: GICombineRule <
146146
(apply [{ applyINS(*${root}, MRI, B, ${matchinfo}); }])
147147
>;
148148

149+
// Combine rule rooted at a G_SHUFFLE_VECTOR: fires when matchPerfectShuffle
// accepts the instruction and rewrites it through applyPerfectShuffle.
def perfect_shuffle: GICombineRule <
150+
(defs root:$root),
151+
(match (G_SHUFFLE_VECTOR $dst, $src1, $src2, $mask):$root,
152+
[{ return matchPerfectShuffle(*${root}, MRI); }]),
153+
(apply [{ applyPerfectShuffle(*${root}, MRI, B); }])
154+
>;
155+
149156
def vashr_vlshr_imm_matchdata : GIDefMatchData<"int64_t">;
150157
def vashr_vlshr_imm : GICombineRule<
151158
(defs root:$root, vashr_vlshr_imm_matchdata:$matchinfo),
@@ -164,7 +171,8 @@ def form_duplane : GICombineRule <
164171
>;
165172

166173
def shuffle_vector_lowering : GICombineGroup<[dup, rev, ext, zip, uzp, trn,
167-
form_duplane, shuf_to_ins]>;
174+
form_duplane, shuf_to_ins,
175+
perfect_shuffle]>;
168176

169177
// Turn G_UNMERGE_VALUES -> G_EXTRACT_VECTOR_ELT's
170178
def vector_unmerge_lowering : GICombineRule <

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 89 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -13135,172 +13135,6 @@ static SDValue tryFormConcatFromShuffle(SDValue Op, SelectionDAG &DAG) {
1313513135
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, V0, V1);
1313613136
}
1313713137

13138-
/// GeneratePerfectShuffle - Given an entry in the perfect-shuffle table, emit
13139-
/// the specified operations to build the shuffle. ID is the perfect-shuffle
13140-
//ID, V1 and V2 are the original shuffle inputs. PFEntry is the Perfect shuffle
13141-
//table entry and LHS/RHS are the immediate inputs for this stage of the
13142-
//shuffle.
13143-
static SDValue GeneratePerfectShuffle(unsigned ID, SDValue V1,
13144-
SDValue V2, unsigned PFEntry, SDValue LHS,
13145-
SDValue RHS, SelectionDAG &DAG,
13146-
const SDLoc &dl) {
13147-
unsigned OpNum = (PFEntry >> 26) & 0x0F;
13148-
unsigned LHSID = (PFEntry >> 13) & ((1 << 13) - 1);
13149-
unsigned RHSID = (PFEntry >> 0) & ((1 << 13) - 1);
13150-
13151-
enum {
13152-
OP_COPY = 0, // Copy, used for things like <u,u,u,3> to say it is <0,1,2,3>
13153-
OP_VREV,
13154-
OP_VDUP0,
13155-
OP_VDUP1,
13156-
OP_VDUP2,
13157-
OP_VDUP3,
13158-
OP_VEXT1,
13159-
OP_VEXT2,
13160-
OP_VEXT3,
13161-
OP_VUZPL, // VUZP, left result
13162-
OP_VUZPR, // VUZP, right result
13163-
OP_VZIPL, // VZIP, left result
13164-
OP_VZIPR, // VZIP, right result
13165-
OP_VTRNL, // VTRN, left result
13166-
OP_VTRNR, // VTRN, right result
13167-
OP_MOVLANE // Move lane. RHSID is the lane to move into
13168-
};
13169-
13170-
if (OpNum == OP_COPY) {
13171-
if (LHSID == (1 * 9 + 2) * 9 + 3)
13172-
return LHS;
13173-
assert(LHSID == ((4 * 9 + 5) * 9 + 6) * 9 + 7 && "Illegal OP_COPY!");
13174-
return RHS;
13175-
}
13176-
13177-
if (OpNum == OP_MOVLANE) {
13178-
// Decompose a PerfectShuffle ID to get the Mask for lane Elt
13179-
auto getPFIDLane = [](unsigned ID, int Elt) -> int {
13180-
assert(Elt < 4 && "Expected Perfect Lanes to be less than 4");
13181-
Elt = 3 - Elt;
13182-
while (Elt > 0) {
13183-
ID /= 9;
13184-
Elt--;
13185-
}
13186-
return (ID % 9 == 8) ? -1 : ID % 9;
13187-
};
13188-
13189-
// For OP_MOVLANE shuffles, the RHSID represents the lane to move into. We
13190-
// get the lane to move from the PFID, which is always from the
13191-
// original vectors (V1 or V2).
13192-
SDValue OpLHS = GeneratePerfectShuffle(
13193-
LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS, RHS, DAG, dl);
13194-
EVT VT = OpLHS.getValueType();
13195-
assert(RHSID < 8 && "Expected a lane index for RHSID!");
13196-
unsigned ExtLane = 0;
13197-
SDValue Input;
13198-
13199-
// OP_MOVLANE are either D movs (if bit 0x4 is set) or S movs. D movs
13200-
// convert into a higher type.
13201-
if (RHSID & 0x4) {
13202-
int MaskElt = getPFIDLane(ID, (RHSID & 0x01) << 1) >> 1;
13203-
if (MaskElt == -1)
13204-
MaskElt = (getPFIDLane(ID, ((RHSID & 0x01) << 1) + 1) - 1) >> 1;
13205-
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
13206-
ExtLane = MaskElt < 2 ? MaskElt : (MaskElt - 2);
13207-
Input = MaskElt < 2 ? V1 : V2;
13208-
if (VT.getScalarSizeInBits() == 16) {
13209-
Input = DAG.getBitcast(MVT::v2f32, Input);
13210-
OpLHS = DAG.getBitcast(MVT::v2f32, OpLHS);
13211-
} else {
13212-
assert(VT.getScalarSizeInBits() == 32 &&
13213-
"Expected 16 or 32 bit shuffle elemements");
13214-
Input = DAG.getBitcast(MVT::v2f64, Input);
13215-
OpLHS = DAG.getBitcast(MVT::v2f64, OpLHS);
13216-
}
13217-
} else {
13218-
int MaskElt = getPFIDLane(ID, RHSID);
13219-
assert(MaskElt >= 0 && "Didn't expect an undef movlane index!");
13220-
ExtLane = MaskElt < 4 ? MaskElt : (MaskElt - 4);
13221-
Input = MaskElt < 4 ? V1 : V2;
13222-
// Be careful about creating illegal types. Use f16 instead of i16.
13223-
if (VT == MVT::v4i16) {
13224-
Input = DAG.getBitcast(MVT::v4f16, Input);
13225-
OpLHS = DAG.getBitcast(MVT::v4f16, OpLHS);
13226-
}
13227-
}
13228-
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13229-
Input.getValueType().getVectorElementType(),
13230-
Input, DAG.getVectorIdxConstant(ExtLane, dl));
13231-
SDValue Ins =
13232-
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, Input.getValueType(), OpLHS,
13233-
Ext, DAG.getVectorIdxConstant(RHSID & 0x3, dl));
13234-
return DAG.getBitcast(VT, Ins);
13235-
}
13236-
13237-
SDValue OpLHS, OpRHS;
13238-
OpLHS = GeneratePerfectShuffle(LHSID, V1, V2, PerfectShuffleTable[LHSID], LHS,
13239-
RHS, DAG, dl);
13240-
OpRHS = GeneratePerfectShuffle(RHSID, V1, V2, PerfectShuffleTable[RHSID], LHS,
13241-
RHS, DAG, dl);
13242-
EVT VT = OpLHS.getValueType();
13243-
13244-
switch (OpNum) {
13245-
default:
13246-
llvm_unreachable("Unknown shuffle opcode!");
13247-
case OP_VREV:
13248-
// VREV divides the vector in half and swaps within the half.
13249-
if (VT.getVectorElementType() == MVT::i32 ||
13250-
VT.getVectorElementType() == MVT::f32)
13251-
return DAG.getNode(AArch64ISD::REV64, dl, VT, OpLHS);
13252-
// vrev <4 x i16> -> REV32
13253-
if (VT.getVectorElementType() == MVT::i16 ||
13254-
VT.getVectorElementType() == MVT::f16 ||
13255-
VT.getVectorElementType() == MVT::bf16)
13256-
return DAG.getNode(AArch64ISD::REV32, dl, VT, OpLHS);
13257-
// vrev <4 x i8> -> REV16
13258-
assert(VT.getVectorElementType() == MVT::i8);
13259-
return DAG.getNode(AArch64ISD::REV16, dl, VT, OpLHS);
13260-
case OP_VDUP0:
13261-
case OP_VDUP1:
13262-
case OP_VDUP2:
13263-
case OP_VDUP3: {
13264-
EVT EltTy = VT.getVectorElementType();
13265-
unsigned Opcode;
13266-
if (EltTy == MVT::i8)
13267-
Opcode = AArch64ISD::DUPLANE8;
13268-
else if (EltTy == MVT::i16 || EltTy == MVT::f16 || EltTy == MVT::bf16)
13269-
Opcode = AArch64ISD::DUPLANE16;
13270-
else if (EltTy == MVT::i32 || EltTy == MVT::f32)
13271-
Opcode = AArch64ISD::DUPLANE32;
13272-
else if (EltTy == MVT::i64 || EltTy == MVT::f64)
13273-
Opcode = AArch64ISD::DUPLANE64;
13274-
else
13275-
llvm_unreachable("Invalid vector element type?");
13276-
13277-
if (VT.getSizeInBits() == 64)
13278-
OpLHS = WidenVector(OpLHS, DAG);
13279-
SDValue Lane = DAG.getConstant(OpNum - OP_VDUP0, dl, MVT::i64);
13280-
return DAG.getNode(Opcode, dl, VT, OpLHS, Lane);
13281-
}
13282-
case OP_VEXT1:
13283-
case OP_VEXT2:
13284-
case OP_VEXT3: {
13285-
unsigned Imm = (OpNum - OP_VEXT1 + 1) * getExtFactor(OpLHS);
13286-
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
13287-
DAG.getConstant(Imm, dl, MVT::i32));
13288-
}
13289-
case OP_VUZPL:
13290-
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
13291-
case OP_VUZPR:
13292-
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
13293-
case OP_VZIPL:
13294-
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
13295-
case OP_VZIPR:
13296-
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
13297-
case OP_VTRNL:
13298-
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
13299-
case OP_VTRNR:
13300-
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
13301-
}
13302-
}
13303-
1330413138
static SDValue GenerateTBL(SDValue Op, ArrayRef<int> ShuffleMask,
1330513139
SelectionDAG &DAG) {
1330613140
// Check to see if we can use the TBL instruction.
@@ -13762,8 +13596,95 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
1376213596
unsigned PFTableIndex = PFIndexes[0] * 9 * 9 * 9 + PFIndexes[1] * 9 * 9 +
1376313597
PFIndexes[2] * 9 + PFIndexes[3];
1376413598
unsigned PFEntry = PerfectShuffleTable[PFTableIndex];
13765-
return GeneratePerfectShuffle(PFTableIndex, V1, V2, PFEntry, V1, V2, DAG,
13766-
dl);
13599+
13600+
auto BuildRev = [&DAG, &dl](SDValue OpLHS) {
13601+
EVT VT = OpLHS.getValueType();
13602+
unsigned Opcode = VT.getScalarSizeInBits() == 32 ? AArch64ISD::REV64
13603+
: VT.getScalarSizeInBits() == 16 ? AArch64ISD::REV32
13604+
: AArch64ISD::REV16;
13605+
return DAG.getNode(Opcode, dl, VT, OpLHS);
13606+
};
13607+
auto BuildDup = [&DAG, &dl](SDValue OpLHS, unsigned Lane) {
13608+
EVT VT = OpLHS.getValueType();
13609+
unsigned Opcode;
13610+
if (VT.getScalarSizeInBits() == 8)
13611+
Opcode = AArch64ISD::DUPLANE8;
13612+
else if (VT.getScalarSizeInBits() == 16)
13613+
Opcode = AArch64ISD::DUPLANE16;
13614+
else if (VT.getScalarSizeInBits() == 32)
13615+
Opcode = AArch64ISD::DUPLANE32;
13616+
else if (VT.getScalarSizeInBits() == 64)
13617+
Opcode = AArch64ISD::DUPLANE64;
13618+
else
13619+
llvm_unreachable("Invalid vector element type?");
13620+
13621+
if (VT.getSizeInBits() == 64)
13622+
OpLHS = WidenVector(OpLHS, DAG);
13623+
return DAG.getNode(Opcode, dl, VT, OpLHS,
13624+
DAG.getConstant(Lane, dl, MVT::i64));
13625+
};
13626+
auto BuildExt = [&DAG, &dl](SDValue OpLHS, SDValue OpRHS, unsigned Imm) {
13627+
EVT VT = OpLHS.getValueType();
13628+
Imm = Imm * getExtFactor(OpLHS);
13629+
return DAG.getNode(AArch64ISD::EXT, dl, VT, OpLHS, OpRHS,
13630+
DAG.getConstant(Imm, dl, MVT::i32));
13631+
};
13632+
auto BuildZipLike = [&DAG, &dl](unsigned OpNum, SDValue OpLHS,
13633+
SDValue OpRHS) {
13634+
EVT VT = OpLHS.getValueType();
13635+
switch (OpNum) {
13636+
default:
13637+
llvm_unreachable("Unexpected perfect shuffle opcode\n");
13638+
case OP_VUZPL:
13639+
return DAG.getNode(AArch64ISD::UZP1, dl, VT, OpLHS, OpRHS);
13640+
case OP_VUZPR:
13641+
return DAG.getNode(AArch64ISD::UZP2, dl, VT, OpLHS, OpRHS);
13642+
case OP_VZIPL:
13643+
return DAG.getNode(AArch64ISD::ZIP1, dl, VT, OpLHS, OpRHS);
13644+
case OP_VZIPR:
13645+
return DAG.getNode(AArch64ISD::ZIP2, dl, VT, OpLHS, OpRHS);
13646+
case OP_VTRNL:
13647+
return DAG.getNode(AArch64ISD::TRN1, dl, VT, OpLHS, OpRHS);
13648+
case OP_VTRNR:
13649+
return DAG.getNode(AArch64ISD::TRN2, dl, VT, OpLHS, OpRHS);
13650+
}
13651+
};
13652+
auto BuildExtractInsert64 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
13653+
SDValue InsSrc, unsigned InsLane) {
13654+
EVT VT = InsSrc.getValueType();
13655+
if (VT.getScalarSizeInBits() == 16) {
13656+
ExtSrc = DAG.getBitcast(MVT::v2f32, ExtSrc);
13657+
InsSrc = DAG.getBitcast(MVT::v2f32, InsSrc);
13658+
} else if (VT.getScalarSizeInBits() == 32) {
13659+
ExtSrc = DAG.getBitcast(MVT::v2f64, ExtSrc);
13660+
InsSrc = DAG.getBitcast(MVT::v2f64, InsSrc);
13661+
}
13662+
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13663+
ExtSrc.getValueType().getVectorElementType(),
13664+
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
13665+
SDValue Ins =
13666+
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
13667+
Ext, DAG.getVectorIdxConstant(InsLane, dl));
13668+
return DAG.getBitcast(VT, Ins);
13669+
};
13670+
auto BuildExtractInsert32 = [&DAG, &dl](SDValue ExtSrc, unsigned ExtLane,
13671+
SDValue InsSrc, unsigned InsLane) {
13672+
EVT VT = InsSrc.getValueType();
13673+
if (VT.getScalarSizeInBits() == 16) {
13674+
ExtSrc = DAG.getBitcast(MVT::v4f16, ExtSrc);
13675+
InsSrc = DAG.getBitcast(MVT::v4f16, InsSrc);
13676+
}
13677+
SDValue Ext = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl,
13678+
ExtSrc.getValueType().getVectorElementType(),
13679+
ExtSrc, DAG.getVectorIdxConstant(ExtLane, dl));
13680+
SDValue Ins =
13681+
DAG.getNode(ISD::INSERT_VECTOR_ELT, dl, ExtSrc.getValueType(), InsSrc,
13682+
Ext, DAG.getVectorIdxConstant(InsLane, dl));
13683+
return DAG.getBitcast(VT, Ins);
13684+
};
13685+
return generatePerfectShuffle<SDValue, MVT>(
13686+
PFTableIndex, V1, V2, PFEntry, V1, V2, BuildExtractInsert64,
13687+
BuildExtractInsert32, BuildRev, BuildDup, BuildExt, BuildZipLike);
1376713688
}
1376813689

1376913690
return GenerateTBL(Op, ShuffleMask, DAG);

0 commit comments

Comments
 (0)