@@ -43999,18 +43999,18 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
43999
43999
// integer, that requires a potentially expensive XMM -> GPR transfer.
44000
44000
// Additionally, if we can convert to a scalar integer load, that will likely
44001
44001
// be folded into a subsequent integer op.
44002
+ // Note: SrcVec need not have type VecVT, but it must have the same bit size.
44002
44003
// Note: Unlike the related fold for this in DAGCombiner, this is not limited
44003
44004
// to a single-use of the loaded vector. For the reasons above, we
44004
44005
// expect this to be profitable even if it creates an extra load.
44005
44006
static SDValue
44006
- combineExtractFromVectorLoad(SDNode *N, SDValue InputVector , uint64_t Idx,
44007
+ combineExtractFromVectorLoad(SDNode *N, EVT VecVT, SDValue SrcVec , uint64_t Idx,
44007
44008
const SDLoc &dl, SelectionDAG &DAG,
44008
44009
TargetLowering::DAGCombinerInfo &DCI) {
44009
44010
assert(N->getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
44010
44011
"Only EXTRACT_VECTOR_ELT supported so far");
44011
44012
44012
44013
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
44013
- EVT SrcVT = InputVector.getValueType();
44014
44014
EVT VT = N->getValueType(0);
44015
44015
44016
44016
bool LikelyUsedAsVector = any_of(N->uses(), [](SDNode *Use) {
@@ -44019,12 +44019,13 @@ combineExtractFromVectorLoad(SDNode *N, SDValue InputVector, uint64_t Idx,
44019
44019
Use->getOpcode() == ISD::SCALAR_TO_VECTOR;
44020
44020
});
44021
44021
44022
- auto *LoadVec = dyn_cast<LoadSDNode>(InputVector );
44022
+ auto *LoadVec = dyn_cast<LoadSDNode>(SrcVec );
44023
44023
if (LoadVec && ISD::isNormalLoad(LoadVec) && VT.isInteger() &&
44024
- SrcVT.getVectorElementType() == VT && DCI.isAfterLegalizeDAG() &&
44025
- !LikelyUsedAsVector && LoadVec->isSimple()) {
44024
+ VecVT.getVectorElementType() == VT &&
44025
+ VecVT.getSizeInBits() == SrcVec.getValueSizeInBits() &&
44026
+ DCI.isAfterLegalizeDAG() && !LikelyUsedAsVector && LoadVec->isSimple()) {
44026
44027
SDValue NewPtr = TLI.getVectorElementPointer(
44027
- DAG, LoadVec->getBasePtr(), SrcVT , DAG.getVectorIdxConstant(Idx, dl));
44028
+ DAG, LoadVec->getBasePtr(), VecVT , DAG.getVectorIdxConstant(Idx, dl));
44028
44029
unsigned PtrOff = VT.getSizeInBits() * Idx / 8;
44029
44030
MachinePointerInfo MPI = LoadVec->getPointerInfo().getWithOffset(PtrOff);
44030
44031
Align Alignment = commonAlignment(LoadVec->getAlign(), PtrOff);
@@ -44234,10 +44235,9 @@ static SDValue combineExtractWithShuffle(SDNode *N, SelectionDAG &DAG,
44234
44235
if (SDValue V = GetLegalExtract(SrcOp, ExtractVT, ExtractIdx))
44235
44236
return DAG.getZExtOrTrunc(V, dl, VT);
44236
44237
44237
- if (N->getOpcode() == ISD::EXTRACT_VECTOR_ELT && ExtractVT == SrcVT &&
44238
- SrcOp.getValueType() == SrcVT)
44239
- if (SDValue V =
44240
- combineExtractFromVectorLoad(N, SrcOp, ExtractIdx, dl, DAG, DCI))
44238
+ if (N->getOpcode() == ISD::EXTRACT_VECTOR_ELT && ExtractVT == SrcVT)
44239
+ if (SDValue V = combineExtractFromVectorLoad(
44240
+ N, SrcVT, peekThroughBitcasts(SrcOp), ExtractIdx, dl, DAG, DCI))
44241
44241
return V;
44242
44242
44243
44243
return SDValue();
@@ -44651,7 +44651,8 @@ static SDValue combineExtractVectorElt(SDNode *N, SelectionDAG &DAG,
44651
44651
44652
44652
if (CIdx)
44653
44653
if (SDValue V = combineExtractFromVectorLoad(
44654
- N, InputVector, CIdx->getZExtValue(), dl, DAG, DCI))
44654
+ N, InputVector.getValueType(), InputVector, CIdx->getZExtValue(),
44655
+ dl, DAG, DCI))
44655
44656
return V;
44656
44657
44657
44658
// Attempt to extract a i1 element by using MOVMSK to extract the signbits
0 commit comments