[AArch64] Improve lowering for scalable masked deinterleaving loads #154338
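This patch teaches the AArch64 DAG combiner (AArch64ISelLowering.cpp) to recognise a scalable VECTOR_DEINTERLEAVE whose operands are extracts from a single masked load, and to rewrite the pair into one SVE structured load (via the aarch64_sve_ld2_sret/ld4_sret intrinsics) driven by a narrowed predicate. As a rough illustration of where the pattern comes from (a hypothetical example, not taken from the patch or its tests), a loop such as the one below, when vectorised for SVE with predication/tail folding, typically becomes a single wide masked load whose results are deinterleaved into even and odd elements, which is the shape the new combine matches:

// Hypothetical source loop (illustrative only): consecutive pairs are loaded
// and combined, so an SVE-vectorised, predicated build produces one wide
// masked load followed by a two-way vector deinterleave of its results.
void sum_pairs(const int *in, int *out, int n) {
  for (int i = 0; i < n; i++)
    out[i] = in[2 * i] + in[2 * i + 1];
}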
@@ -1179,6 +1179,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
  setTargetDAGCombine(ISD::SCALAR_TO_VECTOR);
  setTargetDAGCombine(ISD::SHL);
  setTargetDAGCombine(ISD::VECTOR_DEINTERLEAVE);

  // In case of strict alignment, avoid an excessive number of byte wide stores.
  MaxStoresPerMemsetOptSize = 8;

@@ -27015,6 +27016,115 @@ performScalarToVectorCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
  return NVCAST;
}

static SDValue performVectorDeinterleaveCombine(
    SDNode *N, TargetLowering::DAGCombinerInfo &DCI, SelectionDAG &DAG) {
  if (!DCI.isBeforeLegalize())
    return SDValue();

  unsigned NumParts = N->getNumOperands();
  if (NumParts != 2 && NumParts != 4)
    return SDValue();

  EVT SubVecTy = N->getValueType(0);

  // At the moment we're unlikely to see a fixed-width vector deinterleave as
  // we usually generate shuffles instead.
  unsigned MinNumElements = SubVecTy.getVectorMinNumElements();
  if (!SubVecTy.isScalableVector() ||
      SubVecTy.getSizeInBits().getKnownMinValue() != 128 ||
      !DAG.getTargetLoweringInfo().isTypeLegal(SubVecTy))
    return SDValue();

  // Make sure each input operand is the correct extract_subvector of the same
  // wider vector.
  SDValue Op0 = N->getOperand(0);
  for (unsigned I = 0; I < NumParts; I++) {
    SDValue OpI = N->getOperand(I);
    if (OpI->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
        OpI->getOperand(0) != Op0->getOperand(0))
      return SDValue();
    if (OpI->getConstantOperandVal(1) != (I * MinNumElements))
      return SDValue();
  }

  // Normal loads are currently already handled by the InterleavedAccessPass so
  // we don't expect to see them here. Bail out if the masked load has an
  // unexpected number of uses, since we want to avoid a situation where we have
  // both deinterleaving loads and normal loads in the same block. Also, discard
  // masked loads that are extending, indexed, have an unexpected offset or have
  // an unsupported passthru value until we find a valid use case.
  auto MaskedLoad = dyn_cast<MaskedLoadSDNode>(Op0->getOperand(0));
  if (!MaskedLoad || !MaskedLoad->hasNUsesOfValue(NumParts, 0) ||
      !MaskedLoad->isSimple() || !ISD::isNormalMaskedLoad(MaskedLoad) ||
      !MaskedLoad->getOffset().isUndef() ||
      (!MaskedLoad->getPassThru()->isUndef() &&
       !isZerosVector(MaskedLoad->getPassThru().getNode())))
    return SDValue();

  // Now prove that the mask is an interleave of identical masks.
  SDValue Mask = MaskedLoad->getMask();
  if (Mask->getOpcode() != ISD::SPLAT_VECTOR &&
      Mask->getOpcode() != ISD::CONCAT_VECTORS)
    return SDValue();

  SDValue NarrowMask;
  SDLoc DL(N);
  if (Mask->getOpcode() == ISD::CONCAT_VECTORS) {
    if (Mask->getNumOperands() != NumParts)
      return SDValue();

    // We should be concatenating each sequential result from a
    // VECTOR_INTERLEAVE.
    SDNode *InterleaveOp = Mask->getOperand(0).getNode();
    if (InterleaveOp->getOpcode() != ISD::VECTOR_INTERLEAVE ||
        InterleaveOp->getNumOperands() != NumParts)
      return SDValue();

    for (unsigned I = 0; I < NumParts; I++) {
      if (Mask.getOperand(I) != SDValue(InterleaveOp, I))
        return SDValue();
    }

    // Make sure the inputs to the vector interleave are identical.
    if (!llvm::all_equal(InterleaveOp->op_values()))
      return SDValue();

    NarrowMask = InterleaveOp->getOperand(0);
  } else { // ISD::SPLAT_VECTOR
    ElementCount EC = Mask.getValueType().getVectorElementCount();
    assert(EC.isKnownMultipleOf(NumParts) &&
           "Expected element count divisible by number of parts");
    EC = EC.divideCoefficientBy(NumParts);
    NarrowMask =
        DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::getVectorVT(MVT::i1, EC),
                    Mask->getOperand(0));
  }

  const Intrinsic::ID IID = NumParts == 2 ? Intrinsic::aarch64_sve_ld2_sret
                                          : Intrinsic::aarch64_sve_ld4_sret;
  SDValue NewLdOps[] = {MaskedLoad->getChain(),
                        DAG.getConstant(IID, DL, MVT::i32), NarrowMask,
                        MaskedLoad->getBasePtr()};
  SDValue Res;
  if (NumParts == 2)
    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
                      {SubVecTy, SubVecTy, MVT::Other}, NewLdOps);
  else
    Res = DAG.getNode(ISD::INTRINSIC_W_CHAIN, DL,
                      {SubVecTy, SubVecTy, SubVecTy, SubVecTy, MVT::Other},
                      NewLdOps);

  // We can now generate a structured load!
  SmallVector<SDValue, 4> ResOps(NumParts);
  for (unsigned Idx = 0; Idx < NumParts; Idx++)
    ResOps[Idx] = SDValue(Res.getNode(), Idx);

  // Replace uses of the original chain result with the new chain result.
  DAG.ReplaceAllUsesOfValueWith(SDValue(MaskedLoad, 1),
                                SDValue(Res.getNode(), NumParts));
  return DCI.CombineTo(N, ResOps, false);
}

/// If the operand is a bitwise AND with a constant RHS, and the shift has a
/// constant RHS and is the only use, we can pull it out of the shift, i.e.
///

@@ -27083,6 +27193,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
  default:
    LLVM_DEBUG(dbgs() << "Custom combining: skipping\n");
    break;
  case ISD::VECTOR_DEINTERLEAVE:
    return performVectorDeinterleaveCombine(N, DCI, DAG);
  case ISD::VECREDUCE_AND:
  case ISD::VECREDUCE_OR:
  case ISD::VECREDUCE_XOR:
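For intuition about the effect of the rewrite, the following ACLE-level sketch (hand-written, not code produced or tested by this patch; names are illustrative) shows the two-way case: instead of a wide masked load followed by a deinterleave, the loop body uses a single predicated structured load, the operation the aarch64_sve_ld2_sret node lowers to (ld2w), driven by the narrow, un-interleaved predicate.

#include <arm_sve.h>
#include <stdint.h>

// Sketch only: ACLE equivalent of the combined form for NumParts == 2.
void sum_pairs_sve(const int32_t *in, int32_t *out, int64_t n) {
  for (int64_t i = 0; i < n; i += svcntw()) {
    svbool_t pg = svwhilelt_b32_s64(i, n);        // narrow predicate, one lane per pair
    svint32x2_t pair = svld2_s32(pg, &in[2 * i]); // one structured load (ld2w)
    svint32_t sum = svadd_s32_x(pg, svget2_s32(pair, 0), svget2_s32(pair, 1));
    svst1_s32(pg, &out[i], sum);
  }
}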