Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ static const int MinScheduleRegionSize = 16;
/// Maximum allowed number of operands in the PHI nodes.
static const unsigned MaxPHINumOperands = 128;

// Cache of cloned "identity" instructions, keyed by the original scalar value
// they stand in for.  Filled by BoUpSLP::setIdentityInstr() and erased/cleared
// at the end of SLPVectorizerPass::runImpl().
// NOTE(review): this is file-scope mutable state.  It is not thread-safe if
// the pass runs concurrently on multiple functions, and can hold dangling
// Value* if a cached key or clone is deleted mid-run — consider making it a
// member of BoUpSLP/SLPVectorizerPass instead.  TODO confirm lifetime.
static SmallDenseMap<Value *, Value *> IdentityInstrsMp;

/// Predicate for the element types that the SLP vectorizer supports.
///
/// The most important thing to filter here are types which are invalid in LLVM
Expand Down Expand Up @@ -2075,6 +2077,57 @@ class BoUpSLP {

OptimizationRemarkEmitter *getORE() { return ORE; }

/// Try to make \p VL single-opcode by replacing its one non-instruction,
/// non-poison element with a freshly cloned "identity" instruction, e.g.
/// rewriting scalar `x` as `add x, 0` when the other lanes are `add`s.
/// Returns a copy of \p VL with that element substituted when all of the
/// following hold:
///  - VL has more than two elements and contains at least one instruction,
///  - no instruction appears in VL twice,
///  - exactly one element is neither an Instruction nor a PoisonValue,
///  - the main opcode is a binary operator with a right-hand identity
///    constant (0 for add/or/xor, 1 for mul, ...).
/// Otherwise \p VL is returned unchanged.  Created clones are cached in
/// IdentityInstrsMp (and erased at the end of runImpl) so repeated queries
/// for the same value reuse one instruction.
static SmallVector<Value *, 8> setIdentityInstr(ArrayRef<Value *> VL) {
  SmallVector<Value *, 8> NewVL(VL.begin(), VL.end());
  if (VL.size() <= 2)
    return NewVL;
  auto It = find_if(VL, IsaPred<Instruction>);
  if (It == VL.end())
    return NewVL;
  // Bail out if any instruction occurs more than once.  Compare by pointer,
  // not by name: unnamed instructions all report the empty name and would
  // spuriously look like duplicates of each other.
  SmallPtrSet<const Instruction *, 8> SeenInstrs;
  for (Value *V : VL)
    if (auto *I = dyn_cast<Instruction>(V))
      if (!SeenInstrs.insert(I).second)
        return NewVL;
  auto *MainOp = cast<Instruction>(*It);
  // The clone below overwrites operands 0 and 1, which is only meaningful
  // for binary operators (conservative: intrinsics with identities are not
  // handled here).
  if (!isa<BinaryOperator>(MainOp))
    return NewVL;
  // Exactly one element may be the odd one out.
  int ValidOperands = count_if(VL, IsaPred<Instruction, PoisonValue>);
  if (ValidOperands != (int)VL.size() - 1)
    return NewVL;
  auto DifferentOperand = find_if_not(VL, IsaPred<Instruction, PoisonValue>);
  if (DifferentOperand == VL.end())
    return NewVL;
  assert(!isa<Instruction>(*DifferentOperand) &&
         !isa<PoisonValue>(*DifferentOperand) &&
         "Expected the different operand to be neither an instruction nor "
         "poison");
  unsigned OperandIndex = std::distance(VL.begin(), DifferentOperand);
  // Reuse a previously created identity clone for this value, if any.
  auto FoundIdentityInstrIt = IdentityInstrsMp.find(*DifferentOperand);
  if (FoundIdentityInstrIt != IdentityInstrsMp.end()) {
    NewVL[OperandIndex] = FoundIdentityInstrIt->second;
    return NewVL;
  }
  // Not every opcode has a right-hand identity constant.
  auto *Identity = ConstantExpr::getIdentity(MainOp, MainOp->getType(),
                                             /*AllowRHSConstant=*/true);
  if (!Identity)
    return NewVL;
  // Materialize `DifferentOperand <op> Identity`, which computes exactly
  // DifferentOperand but carries the common opcode of the bundle.
  Instruction *NewInstr = MainOp->clone();
  NewInstr->setOperand(0, *DifferentOperand);
  NewInstr->setOperand(1, Identity);
  // DifferentOperand is not an instruction (argument/constant/global), so it
  // dominates any insertion point; insert next to MainOp.
  NewInstr->insertAfter(MainOp);
  NewInstr->setName((*DifferentOperand)->getName() + ".identity");
  NewVL[OperandIndex] = NewInstr;
  assert(find_if_not(NewVL, IsaPred<Instruction, PoisonValue>) ==
             NewVL.end() &&
         "Expected all operands to be instructions");
  IdentityInstrsMp.try_emplace(*DifferentOperand, NewInstr);
  return NewVL;
}

/// This structure holds any data we need about the edges being traversed
/// during buildTreeRec(). We keep track of:
/// (i) the user TreeEntry index, and
Expand Down Expand Up @@ -3786,7 +3839,8 @@ class BoUpSLP {
assert(OpVL.size() <= Scalars.size() &&
"Number of operands is greater than the number of scalars.");
Operands[OpIdx].resize(OpVL.size());
copy(OpVL, Operands[OpIdx].begin());
auto NewVL = BoUpSLP::setIdentityInstr(OpVL);
copy(NewVL, Operands[OpIdx].begin());
}

public:
Expand Down Expand Up @@ -4084,18 +4138,19 @@ class BoUpSLP {
"Reshuffling scalars not yet supported for nodes with padding");
Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
ReuseShuffleIndices.end());
SmallVector<Value *, 8> NewVL = BoUpSLP::setIdentityInstr(VL);
if (ReorderIndices.empty()) {
Last->Scalars.assign(VL.begin(), VL.end());
Last->Scalars.assign(NewVL.begin(), NewVL.end());
if (S)
Last->setOperations(S);
} else {
// Reorder scalars and build final mask.
Last->Scalars.assign(VL.size(), nullptr);
Last->Scalars.assign(NewVL.size(), nullptr);
transform(ReorderIndices, Last->Scalars.begin(),
[VL](unsigned Idx) -> Value * {
if (Idx >= VL.size())
return UndefValue::get(VL.front()->getType());
return VL[Idx];
[NewVL](unsigned Idx) -> Value * {
if (Idx >= NewVL.size())
return UndefValue::get(NewVL.front()->getType());
return NewVL[Idx];
});
InstructionsState S = getSameOpcode(Last->Scalars, *TLI);
if (S)
Expand All @@ -4106,7 +4161,7 @@ class BoUpSLP {
assert(S && "Split nodes must have operations.");
Last->setOperations(S);
SmallPtrSet<Value *, 4> Processed;
for (Value *V : VL) {
for (Value *V : NewVL) {
auto *I = dyn_cast<Instruction>(V);
if (!I)
continue;
Expand All @@ -4121,10 +4176,10 @@ class BoUpSLP {
}
}
} else if (!Last->isGather()) {
if (doesNotNeedToSchedule(VL))
if (doesNotNeedToSchedule(NewVL))
Last->setDoesNotNeedToSchedule();
SmallPtrSet<Value *, 4> Processed;
for (Value *V : VL) {
for (Value *V : NewVL) {
if (isa<PoisonValue>(V))
continue;
auto It = ScalarToTreeEntries.find(V);
Expand All @@ -4146,7 +4201,7 @@ class BoUpSLP {
#if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
auto *BundleMember = Bundle.getBundle().begin();
SmallPtrSet<Value *, 4> Processed;
for (Value *V : VL) {
for (Value *V : NewVL) {
if (doesNotNeedToBeScheduled(V) || !Processed.insert(V).second)
continue;
++BundleMember;
Expand All @@ -4159,7 +4214,7 @@ class BoUpSLP {
} else {
// Build a map for gathered scalars to the nodes where they are used.
bool AllConstsOrCasts = true;
for (Value *V : VL)
for (Value *V : NewVL)
if (!isConstant(V)) {
auto *I = dyn_cast<CastInst>(V);
AllConstsOrCasts &= I && I->getType()->isIntegerTy();
Expand All @@ -4170,7 +4225,7 @@ class BoUpSLP {
if (AllConstsOrCasts)
CastMaxMinBWSizes =
std::make_pair(std::numeric_limits<unsigned>::max(), 1);
MustGather.insert_range(VL);
MustGather.insert_range(NewVL);
}

if (UserTreeIdx.UserTE)
Expand Down Expand Up @@ -20844,6 +20899,11 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
}
}

for (auto &I : IdentityInstrsMp) {
if (I.second && cast<Instruction>(I.second)->getParent())
cast<Instruction>(I.second)->eraseFromParent();
}
IdentityInstrsMp.clear();
if (Changed) {
R.optimizeGatherSequence();
LLVM_DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
Expand Down
9 changes: 2 additions & 7 deletions llvm/test/Transforms/SLPVectorizer/X86/pr47642.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,8 @@ target triple = "x86_64-unknown-linux-gnu"
define <4 x i32> @foo(<4 x i32> %x, i32 %f) {
; CHECK-LABEL: @foo(
; CHECK-NEXT: [[VECINIT:%.*]] = insertelement <4 x i32> poison, i32 [[F:%.*]], i64 0
; CHECK-NEXT: [[ADD:%.*]] = add nsw i32 [[F]], 1
; CHECK-NEXT: [[VECINIT1:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[ADD]], i64 1
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[F]], i64 0
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = add nsw <2 x i32> [[TMP2]], <i32 2, i32 3>
; CHECK-NEXT: [[TMP4:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
; CHECK-NEXT: [[VECINIT51:%.*]] = shufflevector <4 x i32> [[VECINIT1]], <4 x i32> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
; CHECK-NEXT: [[TMP2:%.*]] = shufflevector <4 x i32> [[VECINIT]], <4 x i32> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[VECINIT51:%.*]] = add nsw <4 x i32> [[TMP2]], <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: ret <4 x i32> [[VECINIT51]]
;
%vecinit = insertelement <4 x i32> undef, i32 %f, i32 0
Expand Down
129 changes: 129 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/infer-missing-instruction.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=slp-vectorizer,instcombine -S < %s | FileCheck %s

; 16-iteration manually-unrolled reduction: lane i computes
; ((%a >> i) & 16) | sext(%b[i]) and all lanes are summed into %add.15.
; Lane 0 has no `lshr %a, 0` (the shift by zero is already folded away), so
; the vectorizer must infer the missing identity instruction to form a full
; <16 x i32> bundle, as the CHECK lines below expect.
define dso_local noundef i32 @_Z4testiPs(i32 noundef %a, ptr noundef readonly captures(none) %b) local_unnamed_addr #0 {
; CHECK-LABEL: define dso_local noundef i32 @_Z4testiPs(
; CHECK-SAME: i32 noundef [[A:%.*]], ptr noundef readonly captures(none) [[B:%.*]]) local_unnamed_addr {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[TMP9:%.*]] = insertelement <16 x i32> poison, i32 [[A]], i64 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <16 x i32> [[TMP9]], <16 x i32> poison, <16 x i32> zeroinitializer
; CHECK-NEXT: [[TMP16:%.*]] = lshr <16 x i32> [[TMP1]], <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP17:%.*]] = and <16 x i32> [[TMP16]], splat (i32 16)
; CHECK-NEXT: [[TMP18:%.*]] = load <16 x i16>, ptr [[B]], align 2
; CHECK-NEXT: [[TMP19:%.*]] = sext <16 x i16> [[TMP18]] to <16 x i32>
; CHECK-NEXT: [[TMP20:%.*]] = or <16 x i32> [[TMP17]], [[TMP19]]
; CHECK-NEXT: [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP20]])
; CHECK-NEXT: ret i32 [[TMP21]]
;
entry:
; Lane 0: no lshr here -- this is the lane that needs the inferred identity.
%conv = and i32 %a, 16
%0 = load i16, ptr %b, align 2
%conv2 = sext i16 %0 to i32
%or = or i32 %conv, %conv2
%shr.1 = lshr i32 %a, 1
%conv.1 = and i32 %shr.1, 16
%arrayidx.1 = getelementptr inbounds nuw i8, ptr %b, i64 2
%1 = load i16, ptr %arrayidx.1, align 2
%conv2.1 = sext i16 %1 to i32
%or.1 = or i32 %conv.1, %conv2.1
%add.1 = add nsw i32 %or.1, %or
%shr.2 = lshr i32 %a, 2
%conv.2 = and i32 %shr.2, 16
%arrayidx.2 = getelementptr inbounds nuw i8, ptr %b, i64 4
%2 = load i16, ptr %arrayidx.2, align 2
%conv2.2 = sext i16 %2 to i32
%or.2 = or i32 %conv.2, %conv2.2
%add.2 = add nsw i32 %or.2, %add.1
%shr.3 = lshr i32 %a, 3
%conv.3 = and i32 %shr.3, 16
%arrayidx.3 = getelementptr inbounds nuw i8, ptr %b, i64 6
%3 = load i16, ptr %arrayidx.3, align 2
%conv2.3 = sext i16 %3 to i32
%or.3 = or i32 %conv.3, %conv2.3
%add.3 = add nsw i32 %or.3, %add.2
%shr.4 = lshr i32 %a, 4
%conv.4 = and i32 %shr.4, 16
%arrayidx.4 = getelementptr inbounds nuw i8, ptr %b, i64 8
%4 = load i16, ptr %arrayidx.4, align 2
%conv2.4 = sext i16 %4 to i32
%or.4 = or i32 %conv.4, %conv2.4
%add.4 = add nsw i32 %or.4, %add.3
%shr.5 = lshr i32 %a, 5
%conv.5 = and i32 %shr.5, 16
%arrayidx.5 = getelementptr inbounds nuw i8, ptr %b, i64 10
%5 = load i16, ptr %arrayidx.5, align 2
%conv2.5 = sext i16 %5 to i32
%or.5 = or i32 %conv.5, %conv2.5
%add.5 = add nsw i32 %or.5, %add.4
%shr.6 = lshr i32 %a, 6
%conv.6 = and i32 %shr.6, 16
%arrayidx.6 = getelementptr inbounds nuw i8, ptr %b, i64 12
%6 = load i16, ptr %arrayidx.6, align 2
%conv2.6 = sext i16 %6 to i32
%or.6 = or i32 %conv.6, %conv2.6
%add.6 = add nsw i32 %or.6, %add.5
%shr.7 = lshr i32 %a, 7
%conv.7 = and i32 %shr.7, 16
%arrayidx.7 = getelementptr inbounds nuw i8, ptr %b, i64 14
%7 = load i16, ptr %arrayidx.7, align 2
%conv2.7 = sext i16 %7 to i32
%or.7 = or i32 %conv.7, %conv2.7
%add.7 = add nsw i32 %or.7, %add.6
%shr.8 = lshr i32 %a, 8
%conv.8 = and i32 %shr.8, 16
%arrayidx.8 = getelementptr inbounds nuw i8, ptr %b, i64 16
%8 = load i16, ptr %arrayidx.8, align 2
%conv2.8 = sext i16 %8 to i32
%or.8 = or i32 %conv.8, %conv2.8
%add.8 = add nsw i32 %or.8, %add.7
%shr.9 = lshr i32 %a, 9
%conv.9 = and i32 %shr.9, 16
%arrayidx.9 = getelementptr inbounds nuw i8, ptr %b, i64 18
%9 = load i16, ptr %arrayidx.9, align 2
%conv2.9 = sext i16 %9 to i32
%or.9 = or i32 %conv.9, %conv2.9
%add.9 = add nsw i32 %or.9, %add.8
%shr.10 = lshr i32 %a, 10
%conv.10 = and i32 %shr.10, 16
%arrayidx.10 = getelementptr inbounds nuw i8, ptr %b, i64 20
%10 = load i16, ptr %arrayidx.10, align 2
%conv2.10 = sext i16 %10 to i32
%or.10 = or i32 %conv.10, %conv2.10
%add.10 = add nsw i32 %or.10, %add.9
%shr.11 = lshr i32 %a, 11
%conv.11 = and i32 %shr.11, 16
%arrayidx.11 = getelementptr inbounds nuw i8, ptr %b, i64 22
%11 = load i16, ptr %arrayidx.11, align 2
%conv2.11 = sext i16 %11 to i32
%or.11 = or i32 %conv.11, %conv2.11
%add.11 = add nsw i32 %or.11, %add.10
%shr.12 = lshr i32 %a, 12
%conv.12 = and i32 %shr.12, 16
%arrayidx.12 = getelementptr inbounds nuw i8, ptr %b, i64 24
%12 = load i16, ptr %arrayidx.12, align 2
%conv2.12 = sext i16 %12 to i32
%or.12 = or i32 %conv.12, %conv2.12
%add.12 = add nsw i32 %or.12, %add.11
%shr.13 = lshr i32 %a, 13
%conv.13 = and i32 %shr.13, 16
%arrayidx.13 = getelementptr inbounds nuw i8, ptr %b, i64 26
%13 = load i16, ptr %arrayidx.13, align 2
%conv2.13 = sext i16 %13 to i32
%or.13 = or i32 %conv.13, %conv2.13
%add.13 = add nsw i32 %or.13, %add.12
%shr.14 = lshr i32 %a, 14
%conv.14 = and i32 %shr.14, 16
%arrayidx.14 = getelementptr inbounds nuw i8, ptr %b, i64 28
%14 = load i16, ptr %arrayidx.14, align 2
%conv2.14 = sext i16 %14 to i32
%or.14 = or i32 %conv.14, %conv2.14
%add.14 = add nsw i32 %or.14, %add.13
%shr.15 = lshr i32 %a, 15
%conv.15 = and i32 %shr.15, 16
%arrayidx.15 = getelementptr inbounds nuw i8, ptr %b, i64 30
%15 = load i16, ptr %arrayidx.15, align 2
%conv2.15 = sext i16 %15 to i32
%or.15 = or i32 %conv.15, %conv2.15
%add.15 = add nsw i32 %or.15, %add.14
ret i32 %add.15
}
Loading