Skip to content

Commit 34659de

Browse files
committed
[InstCombine][X86] simplifyX86immShift - convert variable in-range vector shift by scalar amounts to generic shifts (PR40391)
The sll/srl/sra scalar vector shifts can be replaced with generic shifts if the shift amount is known to be in range. This also required public DemandedElts variants of llvm::computeKnownBits to be exposed (PR36319).
1 parent 1adfa4c commit 34659de

File tree

4 files changed

+101
-25
lines changed

4 files changed

+101
-25
lines changed

llvm/include/llvm/Analysis/ValueTracking.h

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,22 @@ class Value;
5959
OptimizationRemarkEmitter *ORE = nullptr,
6060
bool UseInstrInfo = true);
6161

62+
/// Determine which bits of V are known to be either zero or one and return
63+
/// them in the KnownZero/KnownOne bit sets.
64+
///
65+
/// This function is defined on values with integer type, values with pointer
66+
/// type, and vectors of integers. In the case
67+
/// where V is a vector, the known zero and known one values are the
68+
/// same width as the vector element, and the bit is set only if it is true
69+
/// for all of the demanded elements in the vector.
70+
void computeKnownBits(const Value *V, const APInt &DemandedElts,
71+
KnownBits &Known, const DataLayout &DL,
72+
unsigned Depth = 0, AssumptionCache *AC = nullptr,
73+
const Instruction *CxtI = nullptr,
74+
const DominatorTree *DT = nullptr,
75+
OptimizationRemarkEmitter *ORE = nullptr,
76+
bool UseInstrInfo = true);
77+
6278
/// Returns the known bits rather than passing by reference.
6379
KnownBits computeKnownBits(const Value *V, const DataLayout &DL,
6480
unsigned Depth = 0, AssumptionCache *AC = nullptr,
@@ -67,6 +83,15 @@ class Value;
6783
OptimizationRemarkEmitter *ORE = nullptr,
6884
bool UseInstrInfo = true);
6985

86+
/// Returns the known bits rather than passing by reference.
87+
KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
88+
const DataLayout &DL, unsigned Depth = 0,
89+
AssumptionCache *AC = nullptr,
90+
const Instruction *CxtI = nullptr,
91+
const DominatorTree *DT = nullptr,
92+
OptimizationRemarkEmitter *ORE = nullptr,
93+
bool UseInstrInfo = true);
94+
7095
/// Compute known bits from the range metadata.
7196
/// \p KnownZero the set of bits that are known to be zero
7297
/// \p KnownOne the set of bits that are known to be one

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -215,6 +215,15 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
215215
Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
216216
}
217217

218+
void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
219+
KnownBits &Known, const DataLayout &DL,
220+
unsigned Depth, AssumptionCache *AC,
221+
const Instruction *CxtI, const DominatorTree *DT,
222+
OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
223+
::computeKnownBits(V, DemandedElts, Known, Depth,
224+
Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
225+
}
226+
218227
static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
219228
unsigned Depth, const Query &Q);
220229

@@ -231,6 +240,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
231240
V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
232241
}
233242

243+
KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
244+
const DataLayout &DL, unsigned Depth,
245+
AssumptionCache *AC, const Instruction *CxtI,
246+
const DominatorTree *DT,
247+
OptimizationRemarkEmitter *ORE,
248+
bool UseInstrInfo) {
249+
return ::computeKnownBits(
250+
V, DemandedElts, Depth,
251+
Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
252+
}
253+
234254
bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
235255
const DataLayout &DL, AssumptionCache *AC,
236256
const Instruction *CxtI, const DominatorTree *DT,

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -376,14 +376,15 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
376376
auto Amt = II.getArgOperand(1);
377377
auto VT = cast<VectorType>(Vec->getType());
378378
auto SVT = VT->getElementType();
379+
auto AmtVT = Amt->getType();
379380
unsigned VWidth = VT->getNumElements();
380381
unsigned BitWidth = SVT->getPrimitiveSizeInBits();
381382

382383
// If the shift amount is guaranteed to be in-range we can replace it with a
383384
// generic shift. If its guaranteed to be out of range, logical shifts combine to
384385
// zero and arithmetic shifts are clamped to (BitWidth - 1).
385386
if (IsImm) {
386-
assert(Amt->getType()->isIntegerTy(32) &&
387+
assert(AmtVT->isIntegerTy(32) &&
387388
"Unexpected shift-by-immediate type");
388389
KnownBits KnownAmtBits =
389390
llvm::computeKnownBits(Amt, II.getModule()->getDataLayout());
@@ -400,6 +401,27 @@ static Value *simplifyX86immShift(const IntrinsicInst &II,
400401
Amt = ConstantInt::get(SVT, BitWidth - 1);
401402
return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt));
402403
}
404+
} else {
405+
// Ensure the first element has an in-range value and the rest of the
406+
// elements in the bottom 64 bits are zero.
407+
assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 &&
408+
cast<VectorType>(AmtVT)->getElementType() == SVT &&
409+
"Unexpected shift-by-scalar type");
410+
unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements();
411+
APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0);
412+
APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2);
413+
KnownBits KnownLowerBits = llvm::computeKnownBits(
414+
Amt, DemandedLower, II.getModule()->getDataLayout());
415+
KnownBits KnownUpperBits = llvm::computeKnownBits(
416+
Amt, DemandedUpper, II.getModule()->getDataLayout());
417+
if (KnownLowerBits.getMaxValue().ult(BitWidth) &&
418+
(DemandedUpper.isNullValue() || KnownUpperBits.isZero())) {
419+
SmallVector<uint32_t, 16> ZeroSplat(VWidth, 0);
420+
Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat);
421+
return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt)
422+
: Builder.CreateLShr(Vec, Amt))
423+
: Builder.CreateAShr(Vec, Amt));
424+
}
403425
}
404426

405427
// Simplify if count is constant vector.

llvm/test/Transforms/InstCombine/X86/x86-vector-shifts.ll

Lines changed: 33 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -2680,9 +2680,10 @@ define <32 x i16> @avx512_psllv_w_512_undef(<32 x i16> %v) {
26802680

26812681
define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
26822682
; CHECK-LABEL: @sse2_psra_w_128_masked(
2683-
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
2684-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
2685-
; CHECK-NEXT: ret <8 x i16> [[TMP2]]
2683+
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
2684+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <8 x i32> zeroinitializer
2685+
; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i16> [[V:%.*]], [[TMP2]]
2686+
; CHECK-NEXT: ret <8 x i16> [[TMP3]]
26862687
;
26872688
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
26882689
%2 = tail call <8 x i16> @llvm.x86.sse2.psra.w(<8 x i16> %v, <8 x i16> %1)
@@ -2691,9 +2692,10 @@ define <8 x i16> @sse2_psra_w_128_masked(<8 x i16> %v, <8 x i16> %a) {
26912692

26922693
define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
26932694
; CHECK-LABEL: @avx2_psra_d_256_masked(
2694-
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
2695-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
2696-
; CHECK-NEXT: ret <8 x i32> [[TMP2]]
2695+
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
2696+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <8 x i32> zeroinitializer
2697+
; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i32> [[V:%.*]], [[TMP2]]
2698+
; CHECK-NEXT: ret <8 x i32> [[TMP3]]
26972699
;
26982700
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
26992701
%2 = tail call <8 x i32> @llvm.x86.avx2.psra.d(<8 x i32> %v, <4 x i32> %1)
@@ -2703,8 +2705,9 @@ define <8 x i32> @avx2_psra_d_256_masked(<8 x i32> %v, <4 x i32> %a) {
27032705
define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
27042706
; CHECK-LABEL: @avx512_psra_q_512_masked(
27052707
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
2706-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
2707-
; CHECK-NEXT: ret <8 x i64> [[TMP2]]
2708+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <8 x i32> zeroinitializer
2709+
; CHECK-NEXT: [[TMP3:%.*]] = ashr <8 x i64> [[V:%.*]], [[TMP2]]
2710+
; CHECK-NEXT: ret <8 x i64> [[TMP3]]
27082711
;
27092712
%1 = and <2 x i64> %a, <i64 63, i64 undef>
27102713
%2 = tail call <8 x i64> @llvm.x86.avx512.psra.q.512(<8 x i64> %v, <2 x i64> %1)
@@ -2713,9 +2716,10 @@ define <8 x i64> @avx512_psra_q_512_masked(<8 x i64> %v, <2 x i64> %a) {
27132716

27142717
define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
27152718
; CHECK-LABEL: @sse2_psrl_d_128_masked(
2716-
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
2717-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
2718-
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
2719+
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
2720+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <4 x i32> zeroinitializer
2721+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i32> [[V:%.*]], [[TMP2]]
2722+
; CHECK-NEXT: ret <4 x i32> [[TMP3]]
27192723
;
27202724
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
27212725
%2 = tail call <4 x i32> @llvm.x86.sse2.psrl.d(<4 x i32> %v, <4 x i32> %1)
@@ -2725,8 +2729,9 @@ define <4 x i32> @sse2_psrl_d_128_masked(<4 x i32> %v, <4 x i32> %a) {
27252729
define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
27262730
; CHECK-LABEL: @avx2_psrl_q_256_masked(
27272731
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
2728-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
2729-
; CHECK-NEXT: ret <4 x i64> [[TMP2]]
2732+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <4 x i32> zeroinitializer
2733+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <4 x i64> [[V:%.*]], [[TMP2]]
2734+
; CHECK-NEXT: ret <4 x i64> [[TMP3]]
27302735
;
27312736
%1 = and <2 x i64> %a, <i64 63, i64 undef>
27322737
%2 = tail call <4 x i64> @llvm.x86.avx2.psrl.q(<4 x i64> %v, <2 x i64> %1)
@@ -2735,9 +2740,10 @@ define <4 x i64> @avx2_psrl_q_256_masked(<4 x i64> %v, <2 x i64> %a) {
27352740

27362741
define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
27372742
; CHECK-LABEL: @avx512_psrl_w_512_masked(
2738-
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
2739-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
2740-
; CHECK-NEXT: ret <32 x i16> [[TMP2]]
2743+
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
2744+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <32 x i32> zeroinitializer
2745+
; CHECK-NEXT: [[TMP3:%.*]] = lshr <32 x i16> [[V:%.*]], [[TMP2]]
2746+
; CHECK-NEXT: ret <32 x i16> [[TMP3]]
27412747
;
27422748
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
27432749
%2 = tail call <32 x i16> @llvm.x86.avx512.psrl.w.512(<32 x i16> %v, <8 x i16> %1)
@@ -2747,8 +2753,9 @@ define <32 x i16> @avx512_psrl_w_512_masked(<32 x i16> %v, <8 x i16> %a) {
27472753
define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
27482754
; CHECK-LABEL: @sse2_psll_q_128_masked(
27492755
; CHECK-NEXT: [[TMP1:%.*]] = and <2 x i64> [[A:%.*]], <i64 63, i64 undef>
2750-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> [[V:%.*]], <2 x i64> [[TMP1]])
2751-
; CHECK-NEXT: ret <2 x i64> [[TMP2]]
2756+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i64> [[TMP1]], <2 x i64> undef, <2 x i32> zeroinitializer
2757+
; CHECK-NEXT: [[TMP3:%.*]] = shl <2 x i64> [[V:%.*]], [[TMP2]]
2758+
; CHECK-NEXT: ret <2 x i64> [[TMP3]]
27522759
;
27532760
%1 = and <2 x i64> %a, <i64 63, i64 undef>
27542761
%2 = tail call <2 x i64> @llvm.x86.sse2.psll.q(<2 x i64> %v, <2 x i64> %1)
@@ -2757,9 +2764,10 @@ define <2 x i64> @sse2_psll_q_128_masked(<2 x i64> %v, <2 x i64> %a) {
27572764

27582765
define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
27592766
; CHECK-LABEL: @avx2_psll_w_256_masked(
2760-
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
2761-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> [[V:%.*]], <8 x i16> [[TMP1]])
2762-
; CHECK-NEXT: ret <16 x i16> [[TMP2]]
2767+
; CHECK-NEXT: [[TMP1:%.*]] = and <8 x i16> [[A:%.*]], <i16 15, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef, i16 undef>
2768+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <8 x i16> [[TMP1]], <8 x i16> undef, <16 x i32> zeroinitializer
2769+
; CHECK-NEXT: [[TMP3:%.*]] = shl <16 x i16> [[V:%.*]], [[TMP2]]
2770+
; CHECK-NEXT: ret <16 x i16> [[TMP3]]
27632771
;
27642772
%1 = and <8 x i16> %a, <i16 15, i16 0, i16 0, i16 0, i16 undef, i16 undef, i16 undef, i16 undef>
27652773
%2 = tail call <16 x i16> @llvm.x86.avx2.psll.w(<16 x i16> %v, <8 x i16> %1)
@@ -2768,9 +2776,10 @@ define <16 x i16> @avx2_psll_w_256_masked(<16 x i16> %v, <8 x i16> %a) {
27682776

27692777
define <16 x i32> @avx512_psll_d_512_masked(<16 x i32> %v, <4 x i32> %a) {
27702778
; CHECK-LABEL: @avx512_psll_d_512_masked(
2771-
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 0, i32 undef, i32 undef>
2772-
; CHECK-NEXT: [[TMP2:%.*]] = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> [[V:%.*]], <4 x i32> [[TMP1]])
2773-
; CHECK-NEXT: ret <16 x i32> [[TMP2]]
2779+
; CHECK-NEXT: [[TMP1:%.*]] = and <4 x i32> [[A:%.*]], <i32 31, i32 undef, i32 undef, i32 undef>
2780+
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[TMP1]], <4 x i32> undef, <16 x i32> zeroinitializer
2781+
; CHECK-NEXT: [[TMP3:%.*]] = shl <16 x i32> [[V:%.*]], [[TMP2]]
2782+
; CHECK-NEXT: ret <16 x i32> [[TMP3]]
27742783
;
27752784
%1 = and <4 x i32> %a, <i32 31, i32 0, i32 undef, i32 undef>
27762785
%2 = tail call <16 x i32> @llvm.x86.avx512.psll.d.512(<16 x i32> %v, <4 x i32> %1)

0 commit comments

Comments (0)