Skip to content

Commit eb9570d

Browse files
committed
Include support for Add/Mul/Or/And/Xor Binary Operations
1 parent 44a3268 commit eb9570d

File tree

2 files changed

+257
-63
lines changed

2 files changed

+257
-63
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 189 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -2989,21 +2989,72 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
29892989
return foldSelectShuffle(*Shuffle, true);
29902990
}
29912991

2992+
/// For a given chain of patterns of the following form:
2993+
///
2994+
/// ```
2995+
/// %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
2996+
///
2997+
/// %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
2998+
/// ty1> %1)
2999+
/// OR
3000+
/// %2 = add/mul/or/and/xor <n x ty1> %0, %1
3001+
///
3002+
/// %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3003+
/// ...
3004+
/// ...
3005+
/// %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3006+
/// 3), <n x ty1> %(i - 2)
3007+
/// OR
3008+
/// %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3009+
///
3010+
/// %(i) = extractelement <n x ty1> %(i - 1), 0
3011+
/// ```
3012+
///
3013+
/// Where:
3014+
/// `mask` follows a partition pattern:
3015+
///
3016+
/// Ex:
3017+
/// [n = 8, p = poison]
3018+
///
3019+
/// 4 5 6 7 | p p p p
3020+
/// 2 3 | p p p p p p
3021+
/// 1 | p p p p p p p
3022+
///
3023+
/// For powers of 2, there's a consistent pattern, but for other cases
3024+
/// the parity of the current half value at each step decides the
3025+
/// next partition half (see `ExpectedParityMask` for more logical details
3026+
/// in generalising this).
3027+
///
3028+
/// Ex:
3029+
/// [n = 6]
3030+
///
3031+
/// 3 4 5 | p p p
3032+
/// 1 2 | p p p p
3033+
/// 1 | p p p p p
29923034
bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3035+
// Going bottom-up for the pattern.
29933036
auto *EEI = dyn_cast<ExtractElementInst>(&I);
29943037
if (!EEI)
29953038
return false;
29963039

29973040
std::queue<Value *> InstWorklist;
3041+
InstructionCost OrigCost = 0;
3042+
29983043
Value *InitEEV = nullptr;
2999-
Intrinsic::ID CommonOp = 0;
30003044

3001-
bool IsFirstCallInst = true;
3002-
bool ShouldBeCallInst = true;
3045+
// Common instruction operation after each shuffle op.
3046+
unsigned int CommonCallOp = 0;
3047+
Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;
30033048

3049+
bool IsFirstCallOrBinInst = true;
3050+
bool ShouldBeCallOrBinInst = true;
3051+
3052+
// This stores the last used instructions for shuffle/common op.
3053+
//
3054+
// PrevVecV[2] stores the first vector from extract element instruction,
3055+
// while PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3056+
// instructions from either shuffle/common op.
30043057
SmallVector<Value *, 3> PrevVecV(3, nullptr);
3005-
int64_t ShuffleMaskHalf = -1, ExpectedShuffleMaskHalf = 1;
3006-
int64_t VecSize = -1;
30073058

30083059
Value *VecOp;
30093060
if (!match(&I, m_ExtractElt(m_Value(VecOp), m_Zero())))
@@ -3013,11 +3064,29 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
30133064
if (!FVT)
30143065
return false;
30153066

3016-
VecSize = FVT->getNumElements();
3017-
if (VecSize < 2 || (VecSize % 2) != 0)
3067+
int64_t VecSize = FVT->getNumElements();
3068+
if (VecSize < 2)
30183069
return false;
30193070

3020-
ShuffleMaskHalf = 1;
3071+
// Number of levels would be ~log2(n), considering we always partition
3072+
// by half for this fold pattern.
3073+
unsigned int NumLevels = Log2_64_Ceil(VecSize), VisitedCnt = 0;
3074+
int64_t ShuffleMaskHalf = 1, ExpectedParityMask = 0;
3075+
3076+
// This is how we generalise for all element sizes.
3077+
// At each step, if vector size is odd, we need non-poison
3078+
// values to cover the dominant half so we don't miss out on any element.
3079+
//
3080+
// This mask will help us retrieve this as we go from bottom to top:
3081+
//
3082+
// Mask Set -> N = N * 2 - 1
3083+
// Mask Unset -> N = N * 2
3084+
for (int Cur = VecSize, Mask = NumLevels - 1; Cur > 1;
3085+
Cur = (Cur + 1) / 2, --Mask) {
3086+
if (Cur & 1)
3087+
ExpectedParityMask |= (1ll << Mask);
3088+
}
3089+
30213090
PrevVecV[2] = VecOp;
30223091
InitEEV = EEI;
30233092

@@ -3031,49 +3100,100 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
30313100
if (!CI)
30323101
return false;
30333102

3034-
if (auto *CallI = dyn_cast<CallInst>(CI)) {
3035-
if (!ShouldBeCallInst || !PrevVecV[2])
3103+
if (auto *II = dyn_cast<IntrinsicInst>(CI)) {
3104+
if (!ShouldBeCallOrBinInst || !PrevVecV[2])
30363105
return false;
30373106

3038-
if (!IsFirstCallInst &&
3107+
if (!IsFirstCallOrBinInst &&
30393108
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
30403109
return false;
30413110

3042-
if (CallI != (IsFirstCallInst ? PrevVecV[2] : PrevVecV[0]))
3043-
return false;
3044-
IsFirstCallInst = false;
3045-
3046-
auto *II = dyn_cast<IntrinsicInst>(CallI);
3047-
if (!II)
3111+
// For the first found call/bin op, the vector has to come from the
3112+
// extract element op.
3113+
if (II != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
30483114
return false;
3115+
IsFirstCallOrBinInst = false;
30493116

3050-
if (!CommonOp)
3051-
CommonOp = II->getIntrinsicID();
3052-
if (II->getIntrinsicID() != CommonOp)
3117+
if (!CommonCallOp)
3118+
CommonCallOp = II->getIntrinsicID();
3119+
if (II->getIntrinsicID() != CommonCallOp)
30533120
return false;
30543121

30553122
switch (II->getIntrinsicID()) {
30563123
case Intrinsic::umin:
30573124
case Intrinsic::umax:
30583125
case Intrinsic::smin:
30593126
case Intrinsic::smax: {
3060-
auto *Op0 = CallI->getOperand(0);
3061-
auto *Op1 = CallI->getOperand(1);
3127+
auto *Op0 = II->getOperand(0);
3128+
auto *Op1 = II->getOperand(1);
3129+
PrevVecV[0] = Op0;
3130+
PrevVecV[1] = Op1;
3131+
break;
3132+
}
3133+
default:
3134+
return false;
3135+
}
3136+
ShouldBeCallOrBinInst ^= 1;
3137+
3138+
IntrinsicCostAttributes ICA(
3139+
CommonCallOp, II->getType(),
3140+
{PrevVecV[0]->getType(), PrevVecV[1]->getType()});
3141+
OrigCost += TTI.getIntrinsicInstrCost(ICA, CostKind);
3142+
3143+
// We may need a swap here since it can be (a, b) or (b, a)
3144+
// and accordinly change as we go up.
3145+
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
3146+
std::swap(PrevVecV[0], PrevVecV[1]);
3147+
InstWorklist.push(PrevVecV[1]);
3148+
InstWorklist.push(PrevVecV[0]);
3149+
} else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
3150+
// Similar logic for bin ops.
3151+
3152+
if (!ShouldBeCallOrBinInst || !PrevVecV[2])
3153+
return false;
3154+
3155+
if (!IsFirstCallOrBinInst &&
3156+
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
3157+
return false;
3158+
3159+
if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2] : PrevVecV[0]))
3160+
return false;
3161+
IsFirstCallOrBinInst = false;
3162+
3163+
if (CommonBinOp == Instruction::BinaryOpsEnd)
3164+
CommonBinOp = BinOp->getOpcode();
3165+
3166+
if (BinOp->getOpcode() != CommonBinOp)
3167+
return false;
3168+
3169+
switch (CommonBinOp) {
3170+
case BinaryOperator::Add:
3171+
case BinaryOperator::Mul:
3172+
case BinaryOperator::Or:
3173+
case BinaryOperator::And:
3174+
case BinaryOperator::Xor: {
3175+
auto *Op0 = BinOp->getOperand(0);
3176+
auto *Op1 = BinOp->getOperand(1);
30623177
PrevVecV[0] = Op0;
30633178
PrevVecV[1] = Op1;
30643179
break;
30653180
}
30663181
default:
30673182
return false;
30683183
}
3069-
ShouldBeCallInst ^= 1;
3184+
ShouldBeCallOrBinInst ^= 1;
3185+
3186+
OrigCost +=
3187+
TTI.getArithmeticInstrCost(CommonBinOp, BinOp->getType(), CostKind);
30703188

30713189
if (!isa<ShuffleVectorInst>(PrevVecV[1]))
30723190
std::swap(PrevVecV[0], PrevVecV[1]);
30733191
InstWorklist.push(PrevVecV[1]);
30743192
InstWorklist.push(PrevVecV[0]);
30753193
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3076-
if (ShouldBeCallInst ||
3194+
// We shouldn't have any null values in the previous vectors,
3195+
// is so, there was a mismatch in pattern.
3196+
if (ShouldBeCallOrBinInst ||
30773197
any_of(PrevVecV, [](Value *VecV) { return VecV == nullptr; }))
30783198
return false;
30793199

@@ -3084,70 +3204,76 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
30843204
if (!ShuffleVec || ShuffleVec != PrevVecV[0])
30853205
return false;
30863206

3087-
SmallVector<int> CurMask;
3088-
SVInst->getShuffleMask(CurMask);
3089-
3090-
if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
3207+
if (!isa<PoisonValue>(SVInst->getOperand(1)))
30913208
return false;
3092-
ExpectedShuffleMaskHalf *= 2;
30933209

3210+
ArrayRef<int> CurMask = SVInst->getShuffleMask();
3211+
3212+
// Subtract the parity mask when checking the condition.
30943213
for (int Mask = 0, MaskSize = CurMask.size(); Mask != MaskSize; ++Mask) {
3095-
if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
3214+
if (Mask < ShuffleMaskHalf &&
3215+
CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1))
30963216
return false;
30973217
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1)
30983218
return false;
30993219
}
3220+
3221+
// Update mask values.
31003222
ShuffleMaskHalf *= 2;
3101-
if (ExpectedShuffleMaskHalf == VecSize)
3223+
ShuffleMaskHalf -= (ExpectedParityMask & 1);
3224+
ExpectedParityMask >>= 1;
3225+
3226+
OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3227+
SVInst->getType(), SVInst->getType(),
3228+
CurMask, CostKind);
3229+
3230+
VisitedCnt += 1;
3231+
if (!ExpectedParityMask && VisitedCnt == NumLevels)
31023232
break;
3103-
ShouldBeCallInst ^= 1;
3233+
3234+
ShouldBeCallOrBinInst ^= 1;
31043235
} else {
31053236
return false;
31063237
}
31073238
}
31083239

3109-
if (ShouldBeCallInst)
3240+
// Pattern should end with a shuffle op.
3241+
if (ShouldBeCallOrBinInst)
31103242
return false;
31113243

3112-
assert(VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
3113-
"Expected Match for Vector Size and Mask Half");
3244+
assert(VecSize != -1 && "Expected Match for Vector Size");
31143245

31153246
Value *FinalVecV = PrevVecV[0];
3116-
auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());
3117-
31183247
if (!InitEEV || !FinalVecV)
31193248
return false;
31203249

3250+
auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType());
3251+
31213252
assert(FinalVecVTy && "Expected non-null value for Vector Type");
31223253

31233254
Intrinsic::ID ReducedOp = 0;
3124-
switch (CommonOp) {
3125-
case Intrinsic::umin:
3126-
ReducedOp = Intrinsic::vector_reduce_umin;
3127-
break;
3128-
case Intrinsic::umax:
3129-
ReducedOp = Intrinsic::vector_reduce_umax;
3130-
break;
3131-
case Intrinsic::smin:
3132-
ReducedOp = Intrinsic::vector_reduce_smin;
3133-
break;
3134-
case Intrinsic::smax:
3135-
ReducedOp = Intrinsic::vector_reduce_smax;
3136-
break;
3137-
default:
3138-
return false;
3139-
}
3140-
3141-
InstructionCost OrigCost = 0;
3142-
unsigned int NumLevels = Log2_64(VecSize);
3143-
3144-
for (unsigned int Level = 0; Level < NumLevels; ++Level) {
3145-
OrigCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteSingleSrc,
3146-
FinalVecVTy, FinalVecVTy);
3147-
OrigCost += TTI.getArithmeticInstrCost(Instruction::ICmp, FinalVecVTy);
3255+
if (CommonCallOp) {
3256+
switch (CommonCallOp) {
3257+
case Intrinsic::umin:
3258+
ReducedOp = Intrinsic::vector_reduce_umin;
3259+
break;
3260+
case Intrinsic::umax:
3261+
ReducedOp = Intrinsic::vector_reduce_umax;
3262+
break;
3263+
case Intrinsic::smin:
3264+
ReducedOp = Intrinsic::vector_reduce_smin;
3265+
break;
3266+
case Intrinsic::smax:
3267+
ReducedOp = Intrinsic::vector_reduce_smax;
3268+
break;
3269+
default:
3270+
return false;
3271+
}
3272+
} else if (CommonBinOp != Instruction::BinaryOpsEnd) {
3273+
ReducedOp = getReductionForBinop(CommonBinOp);
3274+
if (!ReducedOp)
3275+
return false;
31483276
}
3149-
OrigCost += TTI.getVectorInstrCost(Instruction::ExtractElement, FinalVecVTy,
3150-
CostKind, 0);
31513277

31523278
IntrinsicCostAttributes ICA(ReducedOp, FinalVecVTy, {FinalVecV});
31533279
InstructionCost NewCost = TTI.getIntrinsicInstrCost(ICA, CostKind);

0 commit comments

Comments
 (0)