@@ -35231,22 +35231,74 @@ static SDValue combineTargetShuffle(SDValue N, SelectionDAG &DAG,
35231
35231
// Due to isTypeDesirableForOp, we won't always shrink a load truncated to
35232
35232
// i16. So shrink it ourselves if we can make a broadcast_load.
35233
35233
if (SrcVT == MVT::i16 && Src.getOpcode() == ISD::TRUNCATE &&
35234
- Src.hasOneUse() && ISD::isNormalLoad(Src.getOperand(0).getNode()) &&
35235
- Src.getOperand(0).hasOneUse()) {
35234
+ Src.hasOneUse() && Src.getOperand(0).hasOneUse()) {
35236
35235
assert(Subtarget.hasAVX2() && "Expected AVX2");
35237
- LoadSDNode *LN = cast<LoadSDNode>(Src.getOperand(0));
35238
- if (LN->isSimple()) {
35239
- SDVTList Tys = DAG.getVTList(VT, MVT::Other);
35240
- SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
35241
- SDValue BcastLd =
35242
- DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
35243
- MVT::i16, LN->getPointerInfo(),
35244
- LN->getAlignment(),
35245
- LN->getMemOperand()->getFlags());
35246
- DCI.CombineTo(N.getNode(), BcastLd);
35247
- DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
35248
- DCI.recursivelyDeleteUnusedNodes(LN);
35249
- return N; // Return N so it doesn't get rechecked!
35236
+ SDValue TruncIn = Src.getOperand(0);
35237
+
35238
+ // If this is a truncate of a non extending load we can just narrow it to
35239
+ // use a broadcast_load.
35240
+ if (ISD::isNormalLoad(TruncIn.getNode())) {
35241
+ LoadSDNode *LN = cast<LoadSDNode>(TruncIn);
35242
+ // Unless its volatile or atomic.
35243
+ if (LN->isSimple()) {
35244
+ SDVTList Tys = DAG.getVTList(VT, MVT::Other);
35245
+ SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
35246
+ SDValue BcastLd =
35247
+ DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
35248
+ MVT::i16, LN->getPointerInfo(),
35249
+ LN->getAlignment(),
35250
+ LN->getMemOperand()->getFlags());
35251
+ DCI.CombineTo(N.getNode(), BcastLd);
35252
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
35253
+ DCI.recursivelyDeleteUnusedNodes(LN);
35254
+ return N; // Return N so it doesn't get rechecked!
35255
+ }
35256
+ }
35257
+
35258
+ // If this is a truncate of an i16 extload, we can directly replace it.
35259
+ if (ISD::isUNINDEXEDLoad(Src.getOperand(0).getNode()) &&
35260
+ ISD::isEXTLoad(Src.getOperand(0).getNode())) {
35261
+ LoadSDNode *LN = cast<LoadSDNode>(Src.getOperand(0));
35262
+ if (LN->getMemoryVT().getSizeInBits() == 16) {
35263
+ SDVTList Tys = DAG.getVTList(VT, MVT::Other);
35264
+ SDValue Ops[] = { LN->getChain(), LN->getBasePtr() };
35265
+ SDValue BcastLd =
35266
+ DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
35267
+ LN->getMemoryVT(), LN->getMemOperand());
35268
+ DCI.CombineTo(N.getNode(), BcastLd);
35269
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
35270
+ DCI.recursivelyDeleteUnusedNodes(LN);
35271
+ return N; // Return N so it doesn't get rechecked!
35272
+ }
35273
+ }
35274
+
35275
+ // If this is a truncate of load that has been shifted right, we can
35276
+ // offset the pointer and use a narrower load.
35277
+ if (TruncIn.getOpcode() == ISD::SRL &&
35278
+ TruncIn.getOperand(0).hasOneUse() &&
35279
+ isa<ConstantSDNode>(TruncIn.getOperand(1)) &&
35280
+ ISD::isNormalLoad(TruncIn.getOperand(0).getNode())) {
35281
+ LoadSDNode *LN = cast<LoadSDNode>(TruncIn.getOperand(0));
35282
+ unsigned ShiftAmt = TruncIn.getConstantOperandVal(1);
35283
+ // Make sure the shift amount and the load size are divisible by 16.
35284
+ // Don't do this if the load is volatile or atomic.
35285
+ if (ShiftAmt % 16 == 0 && TruncIn.getValueSizeInBits() % 16 == 0 &&
35286
+ LN->isSimple()) {
35287
+ unsigned Offset = ShiftAmt / 8;
35288
+ SDVTList Tys = DAG.getVTList(VT, MVT::Other);
35289
+ SDValue Ptr = DAG.getMemBasePlusOffset(LN->getBasePtr(), Offset, DL);
35290
+ SDValue Ops[] = { LN->getChain(), Ptr };
35291
+ SDValue BcastLd =
35292
+ DAG.getMemIntrinsicNode(X86ISD::VBROADCAST_LOAD, DL, Tys, Ops,
35293
+ MVT::i16,
35294
+ LN->getPointerInfo().getWithOffset(Offset),
35295
+ MinAlign(LN->getAlignment(), Offset),
35296
+ LN->getMemOperand()->getFlags());
35297
+ DCI.CombineTo(N.getNode(), BcastLd);
35298
+ DAG.ReplaceAllUsesOfValueWith(SDValue(LN, 1), BcastLd.getValue(1));
35299
+ DCI.recursivelyDeleteUnusedNodes(LN);
35300
+ return N; // Return N so it doesn't get rechecked!
35301
+ }
35250
35302
}
35251
35303
}
35252
35304
0 commit comments