@@ -2989,21 +2989,72 @@ bool VectorCombine::foldShuffleFromReductions(Instruction &I) {
2989
2989
return foldSelectShuffle (*Shuffle, true );
2990
2990
}
2991
2991
2992
+ // / For a given chain of patterns of the following form:
2993
+ // /
2994
+ // / ```
2995
+ // / %1 = shufflevector <n x ty1> %0, <n x ty1> poison <n x ty2> mask
2996
+ // /
2997
+ // / %2 = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %0, <n x
2998
+ // / ty1> %1)
2999
+ // / OR
3000
+ // / %2 = add/mul/or/and/xor <n x ty1> %0, %1
3001
+ // /
3002
+ // / %3 = shufflevector <n x ty1> %2, <n x ty1> poison <n x ty2> mask
3003
+ // / ...
3004
+ // / ...
3005
+ // / %(i - 1) = tail call <n x ty1> llvm.<umin/umax/smin/smax>(<n x ty1> %(i -
3006
+ // / 3), <n x ty1> %(i - 2)
3007
+ // / OR
3008
+ // / %(i - 1) = add/mul/or/and/xor <n x ty1> %(i - 3), %(i - 2)
3009
+ // /
3010
+ // / %(i) = extractelement <n x ty1> %(i - 1), 0
3011
+ // / ```
3012
+ // /
3013
+ // / Where:
3014
+ // / `mask` follows a partition pattern:
3015
+ // /
3016
+ // / Ex:
3017
+ // / [n = 8, p = poison]
3018
+ // /
3019
+ // / 4 5 6 7 | p p p p
3020
+ // / 2 3 | p p p p p p
3021
+ // / 1 | p p p p p p p
3022
+ // /
3023
+ // / For powers of 2, there's a consistent pattern, but for other cases
3024
+ // / the parity of the current half value at each step decides the
3025
+ // / next partition half (see `ExpectedParityMask` for more logical details
3026
+ // / in generalising this).
3027
+ // /
3028
+ // / Ex:
3029
+ // / [n = 6]
3030
+ // /
3031
+ // / 3 4 5 | p p p
3032
+ // / 1 2 | p p p p
3033
+ // / 1 | p p p p p
2992
3034
bool VectorCombine::foldShuffleChainsToReduce (Instruction &I) {
3035
+ // Going bottom-up for the pattern.
2993
3036
auto *EEI = dyn_cast<ExtractElementInst>(&I);
2994
3037
if (!EEI)
2995
3038
return false ;
2996
3039
2997
3040
std::queue<Value *> InstWorklist;
3041
+ InstructionCost OrigCost = 0 ;
3042
+
2998
3043
Value *InitEEV = nullptr ;
2999
- Intrinsic::ID CommonOp = 0 ;
3000
3044
3001
- bool IsFirstCallInst = true ;
3002
- bool ShouldBeCallInst = true ;
3045
+ // Common instruction operation after each shuffle op.
3046
+ unsigned int CommonCallOp = 0 ;
3047
+ Instruction::BinaryOps CommonBinOp = Instruction::BinaryOpsEnd;
3003
3048
3049
+ bool IsFirstCallOrBinInst = true ;
3050
+ bool ShouldBeCallOrBinInst = true ;
3051
+
3052
+ // This stores the last used instructions for shuffle/common op.
3053
+ //
3054
+ // PrevVecV[2] stores the first vector from extract element instruction,
3055
+ // while PrevVecV[0] / PrevVecV[1] store the last two simultaneous
3056
+ // instructions from either shuffle/common op.
3004
3057
SmallVector<Value *, 3 > PrevVecV (3 , nullptr );
3005
- int64_t ShuffleMaskHalf = -1 , ExpectedShuffleMaskHalf = 1 ;
3006
- int64_t VecSize = -1 ;
3007
3058
3008
3059
Value *VecOp;
3009
3060
if (!match (&I, m_ExtractElt (m_Value (VecOp), m_Zero ())))
@@ -3013,11 +3064,29 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3013
3064
if (!FVT)
3014
3065
return false ;
3015
3066
3016
- VecSize = FVT->getNumElements ();
3017
- if (VecSize < 2 || (VecSize % 2 ) != 0 )
3067
+ int64_t VecSize = FVT->getNumElements ();
3068
+ if (VecSize < 2 )
3018
3069
return false ;
3019
3070
3020
- ShuffleMaskHalf = 1 ;
3071
+ // Number of levels would be ~log2(n), considering we always partition
3072
+ // by half for this fold pattern.
3073
+ unsigned int NumLevels = Log2_64_Ceil (VecSize), VisitedCnt = 0 ;
3074
+ int64_t ShuffleMaskHalf = 1 , ExpectedParityMask = 0 ;
3075
+
3076
+ // This is how we generalise for all element sizes.
3077
+ // At each step, if vector size is odd, we need non-poison
3078
+ // values to cover the dominant half so we don't miss out on any element.
3079
+ //
3080
+ // This mask will help us retrieve this as we go from bottom to top:
3081
+ //
3082
+ // Mask Set -> N = N * 2 - 1
3083
+ // Mask Unset -> N = N * 2
3084
+ for (int Cur = VecSize, Mask = NumLevels - 1 ; Cur > 1 ;
3085
+ Cur = (Cur + 1 ) / 2 , --Mask) {
3086
+ if (Cur & 1 )
3087
+ ExpectedParityMask |= (1ll << Mask);
3088
+ }
3089
+
3021
3090
PrevVecV[2 ] = VecOp;
3022
3091
InitEEV = EEI;
3023
3092
@@ -3031,49 +3100,100 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3031
3100
if (!CI)
3032
3101
return false ;
3033
3102
3034
- if (auto *CallI = dyn_cast<CallInst >(CI)) {
3035
- if (!ShouldBeCallInst || !PrevVecV[2 ])
3103
+ if (auto *II = dyn_cast<IntrinsicInst >(CI)) {
3104
+ if (!ShouldBeCallOrBinInst || !PrevVecV[2 ])
3036
3105
return false ;
3037
3106
3038
- if (!IsFirstCallInst &&
3107
+ if (!IsFirstCallOrBinInst &&
3039
3108
any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
3040
3109
return false ;
3041
3110
3042
- if (CallI != (IsFirstCallInst ? PrevVecV[2 ] : PrevVecV[0 ]))
3043
- return false ;
3044
- IsFirstCallInst = false ;
3045
-
3046
- auto *II = dyn_cast<IntrinsicInst>(CallI);
3047
- if (!II)
3111
+ // For the first found call/bin op, the vector has to come from the
3112
+ // extract element op.
3113
+ if (II != (IsFirstCallOrBinInst ? PrevVecV[2 ] : PrevVecV[0 ]))
3048
3114
return false ;
3115
+ IsFirstCallOrBinInst = false ;
3049
3116
3050
- if (!CommonOp )
3051
- CommonOp = II->getIntrinsicID ();
3052
- if (II->getIntrinsicID () != CommonOp )
3117
+ if (!CommonCallOp )
3118
+ CommonCallOp = II->getIntrinsicID ();
3119
+ if (II->getIntrinsicID () != CommonCallOp )
3053
3120
return false ;
3054
3121
3055
3122
switch (II->getIntrinsicID ()) {
3056
3123
case Intrinsic::umin:
3057
3124
case Intrinsic::umax:
3058
3125
case Intrinsic::smin:
3059
3126
case Intrinsic::smax: {
3060
- auto *Op0 = CallI->getOperand (0 );
3061
- auto *Op1 = CallI->getOperand (1 );
3127
+ auto *Op0 = II->getOperand (0 );
3128
+ auto *Op1 = II->getOperand (1 );
3129
+ PrevVecV[0 ] = Op0;
3130
+ PrevVecV[1 ] = Op1;
3131
+ break ;
3132
+ }
3133
+ default :
3134
+ return false ;
3135
+ }
3136
+ ShouldBeCallOrBinInst ^= 1 ;
3137
+
3138
+ IntrinsicCostAttributes ICA (
3139
+ CommonCallOp, II->getType (),
3140
+ {PrevVecV[0 ]->getType (), PrevVecV[1 ]->getType ()});
3141
+ OrigCost += TTI.getIntrinsicInstrCost (ICA, CostKind);
3142
+
3143
+ // We may need a swap here since it can be (a, b) or (b, a)
3144
+ // and accordinly change as we go up.
3145
+ if (!isa<ShuffleVectorInst>(PrevVecV[1 ]))
3146
+ std::swap (PrevVecV[0 ], PrevVecV[1 ]);
3147
+ InstWorklist.push (PrevVecV[1 ]);
3148
+ InstWorklist.push (PrevVecV[0 ]);
3149
+ } else if (auto *BinOp = dyn_cast<BinaryOperator>(CI)) {
3150
+ // Similar logic for bin ops.
3151
+
3152
+ if (!ShouldBeCallOrBinInst || !PrevVecV[2 ])
3153
+ return false ;
3154
+
3155
+ if (!IsFirstCallOrBinInst &&
3156
+ any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
3157
+ return false ;
3158
+
3159
+ if (BinOp != (IsFirstCallOrBinInst ? PrevVecV[2 ] : PrevVecV[0 ]))
3160
+ return false ;
3161
+ IsFirstCallOrBinInst = false ;
3162
+
3163
+ if (CommonBinOp == Instruction::BinaryOpsEnd)
3164
+ CommonBinOp = BinOp->getOpcode ();
3165
+
3166
+ if (BinOp->getOpcode () != CommonBinOp)
3167
+ return false ;
3168
+
3169
+ switch (CommonBinOp) {
3170
+ case BinaryOperator::Add:
3171
+ case BinaryOperator::Mul:
3172
+ case BinaryOperator::Or:
3173
+ case BinaryOperator::And:
3174
+ case BinaryOperator::Xor: {
3175
+ auto *Op0 = BinOp->getOperand (0 );
3176
+ auto *Op1 = BinOp->getOperand (1 );
3062
3177
PrevVecV[0 ] = Op0;
3063
3178
PrevVecV[1 ] = Op1;
3064
3179
break ;
3065
3180
}
3066
3181
default :
3067
3182
return false ;
3068
3183
}
3069
- ShouldBeCallInst ^= 1 ;
3184
+ ShouldBeCallOrBinInst ^= 1 ;
3185
+
3186
+ OrigCost +=
3187
+ TTI.getArithmeticInstrCost (CommonBinOp, BinOp->getType (), CostKind);
3070
3188
3071
3189
if (!isa<ShuffleVectorInst>(PrevVecV[1 ]))
3072
3190
std::swap (PrevVecV[0 ], PrevVecV[1 ]);
3073
3191
InstWorklist.push (PrevVecV[1 ]);
3074
3192
InstWorklist.push (PrevVecV[0 ]);
3075
3193
} else if (auto *SVInst = dyn_cast<ShuffleVectorInst>(CI)) {
3076
- if (ShouldBeCallInst ||
3194
+ // We shouldn't have any null values in the previous vectors,
3195
+ // is so, there was a mismatch in pattern.
3196
+ if (ShouldBeCallOrBinInst ||
3077
3197
any_of (PrevVecV, [](Value *VecV) { return VecV == nullptr ; }))
3078
3198
return false ;
3079
3199
@@ -3084,70 +3204,76 @@ bool VectorCombine::foldShuffleChainsToReduce(Instruction &I) {
3084
3204
if (!ShuffleVec || ShuffleVec != PrevVecV[0 ])
3085
3205
return false ;
3086
3206
3087
- SmallVector<int > CurMask;
3088
- SVInst->getShuffleMask (CurMask);
3089
-
3090
- if (ShuffleMaskHalf != ExpectedShuffleMaskHalf)
3207
+ if (!isa<PoisonValue>(SVInst->getOperand (1 )))
3091
3208
return false ;
3092
- ExpectedShuffleMaskHalf *= 2 ;
3093
3209
3210
+ ArrayRef<int > CurMask = SVInst->getShuffleMask ();
3211
+
3212
+ // Subtract the parity mask when checking the condition.
3094
3213
for (int Mask = 0 , MaskSize = CurMask.size (); Mask != MaskSize; ++Mask) {
3095
- if (Mask < ShuffleMaskHalf && CurMask[Mask] != ShuffleMaskHalf + Mask)
3214
+ if (Mask < ShuffleMaskHalf &&
3215
+ CurMask[Mask] != ShuffleMaskHalf + Mask - (ExpectedParityMask & 1 ))
3096
3216
return false ;
3097
3217
if (Mask >= ShuffleMaskHalf && CurMask[Mask] != -1 )
3098
3218
return false ;
3099
3219
}
3220
+
3221
+ // Update mask values.
3100
3222
ShuffleMaskHalf *= 2 ;
3101
- if (ExpectedShuffleMaskHalf == VecSize)
3223
+ ShuffleMaskHalf -= (ExpectedParityMask & 1 );
3224
+ ExpectedParityMask >>= 1 ;
3225
+
3226
+ OrigCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc,
3227
+ SVInst->getType (), SVInst->getType (),
3228
+ CurMask, CostKind);
3229
+
3230
+ VisitedCnt += 1 ;
3231
+ if (!ExpectedParityMask && VisitedCnt == NumLevels)
3102
3232
break ;
3103
- ShouldBeCallInst ^= 1 ;
3233
+
3234
+ ShouldBeCallOrBinInst ^= 1 ;
3104
3235
} else {
3105
3236
return false ;
3106
3237
}
3107
3238
}
3108
3239
3109
- if (ShouldBeCallInst)
3240
+ // Pattern should end with a shuffle op.
3241
+ if (ShouldBeCallOrBinInst)
3110
3242
return false ;
3111
3243
3112
- assert (VecSize != -1 && ExpectedShuffleMaskHalf == VecSize &&
3113
- " Expected Match for Vector Size and Mask Half" );
3244
+ assert (VecSize != -1 && " Expected Match for Vector Size" );
3114
3245
3115
3246
Value *FinalVecV = PrevVecV[0 ];
3116
- auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType ());
3117
-
3118
3247
if (!InitEEV || !FinalVecV)
3119
3248
return false ;
3120
3249
3250
+ auto *FinalVecVTy = dyn_cast<FixedVectorType>(FinalVecV->getType ());
3251
+
3121
3252
assert (FinalVecVTy && " Expected non-null value for Vector Type" );
3122
3253
3123
3254
Intrinsic::ID ReducedOp = 0 ;
3124
- switch (CommonOp) {
3125
- case Intrinsic::umin:
3126
- ReducedOp = Intrinsic::vector_reduce_umin;
3127
- break ;
3128
- case Intrinsic::umax:
3129
- ReducedOp = Intrinsic::vector_reduce_umax;
3130
- break ;
3131
- case Intrinsic::smin:
3132
- ReducedOp = Intrinsic::vector_reduce_smin;
3133
- break ;
3134
- case Intrinsic::smax:
3135
- ReducedOp = Intrinsic::vector_reduce_smax;
3136
- break ;
3137
- default :
3138
- return false ;
3139
- }
3140
-
3141
- InstructionCost OrigCost = 0 ;
3142
- unsigned int NumLevels = Log2_64 (VecSize);
3143
-
3144
- for (unsigned int Level = 0 ; Level < NumLevels; ++Level) {
3145
- OrigCost += TTI.getShuffleCost (TargetTransformInfo::SK_PermuteSingleSrc,
3146
- FinalVecVTy, FinalVecVTy);
3147
- OrigCost += TTI.getArithmeticInstrCost (Instruction::ICmp, FinalVecVTy);
3255
+ if (CommonCallOp) {
3256
+ switch (CommonCallOp) {
3257
+ case Intrinsic::umin:
3258
+ ReducedOp = Intrinsic::vector_reduce_umin;
3259
+ break ;
3260
+ case Intrinsic::umax:
3261
+ ReducedOp = Intrinsic::vector_reduce_umax;
3262
+ break ;
3263
+ case Intrinsic::smin:
3264
+ ReducedOp = Intrinsic::vector_reduce_smin;
3265
+ break ;
3266
+ case Intrinsic::smax:
3267
+ ReducedOp = Intrinsic::vector_reduce_smax;
3268
+ break ;
3269
+ default :
3270
+ return false ;
3271
+ }
3272
+ } else if (CommonBinOp != Instruction::BinaryOpsEnd) {
3273
+ ReducedOp = getReductionForBinop (CommonBinOp);
3274
+ if (!ReducedOp)
3275
+ return false ;
3148
3276
}
3149
- OrigCost += TTI.getVectorInstrCost (Instruction::ExtractElement, FinalVecVTy,
3150
- CostKind, 0 );
3151
3277
3152
3278
IntrinsicCostAttributes ICA (ReducedOp, FinalVecVTy, {FinalVecV});
3153
3279
InstructionCost NewCost = TTI.getIntrinsicInstrCost (ICA, CostKind);
0 commit comments