8 changes: 8 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
@@ -3338,6 +3338,14 @@ namespace ISD {
  return St && St->getAddressingMode() == ISD::UNINDEXED;
}

/// Returns true if the specified node is a non-extending and unindexed
/// masked load.
inline bool isNormalMaskedLoad(const SDNode *N) {
  auto *Ld = dyn_cast<MaskedLoadSDNode>(N);
  return Ld && Ld->getExtensionType() == ISD::NON_EXTLOAD &&
         Ld->getAddressingMode() == ISD::UNINDEXED;
}

/// Attempt to match a unary predicate against a scalar/splat constant or
/// every element of a constant BUILD_VECTOR.
/// If AllowUndef is true, then UNDEF elements will pass nullptr to Match.
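The new predicate mirrors the existing ISD::isNormalLoad/isNormalStore helpers. Below is a minimal sketch of how a DAG combine might use it; the helper name isCandidateMaskedLoad and the extra checks are illustrative assumptions, not part of this patch:

#include "llvm/CodeGen/SelectionDAGNodes.h"

using namespace llvm;

// Illustrative helper (not part of the patch): pair the new predicate with the
// kind of extra checks the AArch64 combine below performs before rewriting a
// masked load into a structured load.
static bool isCandidateMaskedLoad(const SDNode *N) {
  // Reject extending and pre/post-indexed masked loads.
  if (!ISD::isNormalMaskedLoad(N))
    return false;
  // Require a simple (non-atomic, non-volatile) access with an undef offset.
  auto *MLd = cast<MaskedLoadSDNode>(N);
  return MLd->isSimple() && MLd->getOffset().isUndef();
}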
112 changes: 112 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1179,6 +1179,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
  setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);

  setTargetDAGCombine(ISD::SHL);
  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);

  // In case of strict alignment, avoid an excessive number of byte wide stores.
  MaxStoresPerMemsetOptSize = 8;
@@ -27015,6 +27016,115 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
  return NVCAST;
}

static SDValue performVectorDeinterleaveCombine(
    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
  if (!DCI.isBeforeLegalize())
    return SDValue();

  unsigned NumParts = N->getNumOperands();
  if (NumParts != 2 && NumParts != 4)
    return SDValue();

  EVT SubVecTy = N->getValueType(0);

  // At the moment we're unlikely to see a fixed-width vector deinterleave as
  // we usually generate shuffles instead.
  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
  if (!SubVecTy.isScalableVector() ||
      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
    return SDValue();

  // Make sure each input operand is the correct extract_subvector of the same
  // wider vector.
  SDValue Op0 = N->getOperand(0);
  for (unsigned I = 0; I < NumParts; I++) {
    SDValue OpI = N->getOperand(I);
    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
        OpI->getOperand(0) != Op0->getOperand(0))
      return SDValue();
    if (OpI->getConstantOperandVal(1) != (I * MinNumElements))
      return SDValue();
  }

  // Normal loads are already handled by the InterleavedAccessPass, so we don't
  // expect to see them here. Bail out if the masked load has an unexpected
  // number of uses, since we want to avoid a situation where we have both
  // deinterleaving loads and normal loads in the same block. Also, discard
  // masked loads that are extending, indexed, have an unexpected offset or have
  // an unsupported passthru value until we find a valid use case.
  auto *MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
      !MaskedLoad->isSimple() || !ISD::isNormalMaskedLoad(MaskedLoad) ||
      !MaskedLoad->getOffset().isUndef() ||
      (!MaskedLoad->getPassThru()->isUndef() &&
       !isZerosVector(MaskedLoad->getPassThru().getNode())))
    return SDValue();

Collaborator: missing a test

Contributor Author: I've added foo_ld2_nxv16i8_bad_mask4 for the case where the mask didn't come from a splat or a concat; instead it comes from an insert_subvector node.

  // Now prove that the mask is an interleave of identical masks.
  SDValue Mask = MaskedLoad->getMask();
  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
      Mask->getOpcode() != ISD::CONCAT_VECTORS)
    return SDValue();

  SDValue NarrowMask;
  SDLoc DL(N);
  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
    if (Mask->getNumOperands() != NumParts)
      return SDValue();

    // We should be concatenating each sequential result from a
    // VECTOR_INTERLEAVE.
    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
        InterleaveOp->getNumOperands() != NumParts)
      return SDValue();

    for (unsigned I = 0; I < NumParts; I++) {
      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
        return SDValue();
    }

    // Make sure the inputs to the vector interleave are identical.
    if (!llvm::all_equal(InterleaveOp->op_values()))
      return SDValue();

    NarrowMask = InterleaveOp->getOperand(0);
  } else { // ISD::SPLAT_VECTOR
    ElementCount EC = Mask.getValueType().getVectorElementCount();
    assert(EC.isKnownMultipleOf(NumParts) &&
           "Expected element count divisible by number of parts");
    EC = EC.divideCoefficientBy(NumParts);
    NarrowMask =
        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
                    Mask->getOperand(0));
  }

  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                          : Intrinsic::aarch64_sve_ld4_sret;
  SDValue NewLdOps[] = {MaskedLoad->getChain(),
                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
                        MaskedLoad->getBasePtr()};
  SDValue Res;
  if (NumParts == 2)
    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
  else
    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
                      NewLdOps);

  // We can now generate a structured load!
  SmallVector<SDValue, 4> ResOps(NumParts);
  for (unsigned Idx = 0; Idx < NumParts; Idx++)
    ResOps[Idx] = SDValue(Res.getNode(), Idx);

  // Replace uses of the original chain result with the new chain result.
  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
                                SDValue(Res.getNode(), NumParts));
  return DCI.CombineTo(N, ResOps, false);
}

/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///
@@ -27083,6 +27193,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
  default:
    LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
    break;
  case ISD::VECTOR_DEINTERLEAVE:
    return performVectorDeinterleaveCombine(N, DCI, DAG);
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
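As a worked example of the splat-mask narrowing performed in performVectorDeinterleaveCombine above, the element-count arithmetic for a two-way deinterleave of a <vscale x 16 x i1> mask looks like this (a standalone sketch with assumed values, using only the ElementCount API):

#include "llvm/Support/TypeSize.h"
#include <cassert>

using namespace llvm;

int main() {
  // A splat mask of type <vscale x 16 x i1> feeding a two-way deinterleave...
  ElementCount EC = ElementCount::getScalable(16);
  const unsigned NumParts = 2;
  assert(EC.isKnownMultipleOf(NumParts) &&
         "Expected element count divisible by number of parts");
  // ...is re-splatted at the narrower type <vscale x 8 x i1>, which the combine
  // then uses as the governing predicate of the aarch64_sve_ld2_sret call.
  EC = EC.divideCoefficientBy(NumParts);
  assert(EC == ElementCount::getScalable(8));
  return 0;
}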