diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index ec40124c57a6a..c854d6ac9350a 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -226,6 +226,8 @@ static const int MinScheduleRegionSize = 16;
 /// Maximum allowed number of operands in the PHI nodes.
 static const unsigned MaxPHINumOperands = 128;
 
+static SmallDenseMap<Value *, Value *> IdentityInstrsMp;
+
 /// Predicate for the element types that the SLP vectorizer supports.
 ///
 /// The most important thing to filter here are types which are invalid in LLVM
@@ -2075,6 +2077,57 @@ class BoUpSLP {
 
   OptimizationRemarkEmitter *getORE() { return ORE; }
 
+  static SmallVector<Value *> setIdentityInstr(ArrayRef<Value *> VL) {
+    SmallVector<Value *> New_VL(VL.begin(), VL.end());
+    if (VL.size() <= 2)
+      return New_VL;
+    auto It = find_if(VL, IsaPred<Instruction>);
+    if (It == VL.end())
+      return New_VL;
+    // Work on a unique list of instructions only.
+    SmallDenseMap<StringRef, bool> SeenInstrs;
+    for (auto *V : VL)
+      if (auto *I = dyn_cast<Instruction>(V)) {
+        if (!SeenInstrs[I->getName()])
+          SeenInstrs[I->getName()] = true;
+        else {
+          return New_VL;
+        }
+      }
+    Instruction *MainOp = cast<Instruction>(*It);
+    auto ValidOperands = count_if(VL, IsaPred<Instruction>);
+    if (ValidOperands != (int)VL.size() - 1)
+      return New_VL;
+    auto DifferentOperand = find_if_not(VL, IsaPred<Instruction>);
+    if (DifferentOperand == VL.end())
+      return New_VL;
+    assert(!isa<Instruction>(*DifferentOperand) &&
+           !isa<Constant>(*DifferentOperand) &&
+           "Expected different operand not to be an instruction");
+    auto FoundIdentityInstrIt = IdentityInstrsMp.find(*DifferentOperand);
+    if (FoundIdentityInstrIt != IdentityInstrsMp.end()) {
+      auto OperandIndex = std::distance(VL.begin(), DifferentOperand);
+      New_VL[OperandIndex] = FoundIdentityInstrIt->second;
+      return New_VL;
+    }
+    auto *Identity = ConstantExpr::getIdentity(MainOp, MainOp->getType(),
+                                               true /*AllowRHSConstant*/);
+    if (!Identity)
+      return New_VL;
+    auto *NewInstr = MainOp->clone();
+    NewInstr->setOperand(0, *DifferentOperand);
+    NewInstr->setOperand(1, Identity);
+    NewInstr->insertAfter(cast<Instruction>(MainOp));
+    NewInstr->setName((*DifferentOperand)->getName() + ".identity");
+    auto OperandIndex = std::distance(VL.begin(), DifferentOperand);
+    New_VL[OperandIndex] = NewInstr;
+    assert(find_if_not(New_VL, IsaPred<Instruction>) ==
+               New_VL.end() &&
+           "Expected all operands to be instructions");
+    IdentityInstrsMp.try_emplace(*DifferentOperand, NewInstr);
+    return New_VL;
+  }
+
   /// This structure holds any data we need about the edges being traversed
   /// during buildTreeRec(). We keep track of:
   /// (i) the user TreeEntry index, and
@@ -3786,7 +3839,8 @@ class BoUpSLP {
       assert(OpVL.size() <= Scalars.size() &&
             "Number of operands is greater than the number of scalars.");
       Operands[OpIdx].resize(OpVL.size());
-      copy(OpVL, Operands[OpIdx].begin());
+      auto NewVL = BoUpSLP::setIdentityInstr(OpVL);
+      copy(NewVL, Operands[OpIdx].begin());
     }
 
   public:
@@ -4084,18 +4138,19 @@ class BoUpSLP {
            "Reshuffling scalars not yet supported for nodes with padding");
     Last->ReuseShuffleIndices.append(ReuseShuffleIndices.begin(),
                                      ReuseShuffleIndices.end());
+    SmallVector<Value *> NewVL = BoUpSLP::setIdentityInstr(VL);
     if (ReorderIndices.empty()) {
-      Last->Scalars.assign(VL.begin(), VL.end());
+      Last->Scalars.assign(NewVL.begin(), NewVL.end());
       if (S)
         Last->setOperations(S);
     } else {
       // Reorder scalars and build final mask.
-      Last->Scalars.assign(VL.size(), nullptr);
+      Last->Scalars.assign(NewVL.size(), nullptr);
       transform(ReorderIndices, Last->Scalars.begin(),
-                [VL](unsigned Idx) -> Value * {
-                  if (Idx >= VL.size())
-                    return UndefValue::get(VL.front()->getType());
-                  return VL[Idx];
+                [NewVL](unsigned Idx) -> Value * {
+                  if (Idx >= NewVL.size())
+                    return UndefValue::get(NewVL.front()->getType());
+                  return NewVL[Idx];
                 });
       InstructionsState S = getSameOpcode(Last->Scalars, *TLI);
       if (S)
@@ -4106,7 +4161,7 @@ class BoUpSLP {
       assert(S && "Split nodes must have operations.");
       Last->setOperations(S);
       SmallPtrSet<Value *, 4> Processed;
-      for (Value *V : VL) {
+      for (Value *V : NewVL) {
         auto *I = dyn_cast<Instruction>(V);
         if (!I)
           continue;
@@ -4121,10 +4176,10 @@ class BoUpSLP {
       }
     }
   } else if (!Last->isGather()) {
-    if (doesNotNeedToSchedule(VL))
+    if (doesNotNeedToSchedule(NewVL))
       Last->setDoesNotNeedToSchedule();
     SmallPtrSet<Value *, 4> Processed;
-    for (Value *V : VL) {
+    for (Value *V : NewVL) {
       if (isa<PoisonValue>(V))
         continue;
       auto It = ScalarToTreeEntries.find(V);
@@ -4146,7 +4201,7 @@ class BoUpSLP {
 #if !defined(NDEBUG) || defined(EXPENSIVE_CHECKS)
     auto *BundleMember = Bundle.getBundle().begin();
     SmallPtrSet<Value *, 4> Processed;
-    for (Value *V : VL) {
+    for (Value *V : NewVL) {
       if (doesNotNeedToBeScheduled(V) || !Processed.insert(V).second)
         continue;
       ++BundleMember;
@@ -4159,7 +4214,7 @@ class BoUpSLP {
   } else {
     // Build a map for gathered scalars to the nodes where they are used.
     bool AllConstsOrCasts = true;
-    for (Value *V : VL)
+    for (Value *V : NewVL)
       if (!isConstant(V)) {
         auto *I = dyn_cast<CastInst>(V);
         AllConstsOrCasts &= I && I->getType()->isIntegerTy();
@@ -4170,7 +4225,7 @@ class BoUpSLP {
     if (AllConstsOrCasts)
       CastMaxMinBWSizes =
           std::make_pair(std::numeric_limits<unsigned>::max(), 1);
-    MustGather.insert_range(VL);
+    MustGather.insert_range(NewVL);
   }
 
   if (UserTreeIdx.UserTE)
@@ -20844,6 +20899,11 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_,
     }
   }
 
+  for (auto &I : IdentityInstrsMp) {
+    if (I.second && cast<Instruction>(I.second)->getParent())
+      cast<Instruction>(I.second)->eraseFromParent();
+  }
+  IdentityInstrsMp.clear();
   if (Changed) {
     R.optimizeGatherSequence();
     LLVM_DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n");
diff --git a/llvm/test/Transforms/SLPVectorizer/X86/pr47642.ll b/llvm/test/Transforms/SLPVectorizer/X86/pr47642.ll
index 42a50384787c8..57a4e474850bb 100644
--- a/llvm/test/Transforms/SLPVectorizer/X86/pr47642.ll
+++ b/llvm/test/Transforms/SLPVectorizer/X86/pr47642.ll
@@ -7,13 +7,8 @@ target triple = "x86_64-unknown-linux-gnu"
 define <4 x i32> @foo(<4 x i32> %x, i32 %f) {
 ; CHECK-LABEL: @foo(
 ; CHECK-NEXT:    [[VECINIT:%.*]] = insertelement <4 x i32> poison, i32 [[F:%.*]], i64 0
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i32 [[F]], 1
-; CHECK-NEXT:    [[VECINIT1:%.*]] = insertelement <4 x i32> [[VECINIT]], i32 [[ADD]], i64 1
-; CHECK-NEXT:    [[TMP1:%.*]] = insertelement <2 x i32> poison, i32 [[F]], i64 0
-; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <2 x i32> [[TMP1]], <2 x i32> poison, <2 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP3:%.*]] = add nsw <2 x i32> [[TMP2]], <i32 2, i32 3>
-; CHECK-NEXT:    [[TMP4:%.*]] = shufflevector <2 x i32> [[TMP3]], <2 x i32> poison, <4 x i32> <i32 0, i32 1, i32 poison, i32 poison>
-; CHECK-NEXT:    [[VECINIT51:%.*]] = shufflevector <4 x i32> [[VECINIT1]], <4 x i32> [[TMP4]], <4 x i32> <i32 0, i32 1, i32 4, i32 5>
+; CHECK-NEXT:    [[TMP2:%.*]] = shufflevector <4 x i32> [[VECINIT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[VECINIT51:%.*]] = add nsw <4 x i32> [[TMP2]], <i32 0, i32 1, i32 2, i32 3>
 ; CHECK-NEXT:    ret <4 x i32> [[VECINIT51]]
 ;
   %vecinit = insertelement <4 x i32> undef, i32 %f, i32 0
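Note (reviewer sketch, not part of the patch): the core of setIdentityInstr can be read standalone as the helper below, assuming LLVM's C++ API; the name padWithIdentity is illustrative. ConstantExpr::getIdentity yields the opcode's identity element (0 for add/or/lshr, 1 for mul, -1 for and), so "V <op> identity" computes V unchanged while giving the odd scalar the same opcode as the rest of the bundle.

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instruction.h"

using namespace llvm;

// Clone MainOp into a no-op "V <op> identity" stand-in for the one scalar V
// that lacks the bundle's common opcode. Assumes V dominates MainOp, as it
// does for the operand lists SLP hands to setIdentityInstr.
static Instruction *padWithIdentity(Instruction *MainOp, Value *V) {
  // AllowRHSConstant=true also accepts RHS-only identities, e.g. the zero
  // shift amount of 'lshr'; getIdentity returns null if no identity exists.
  Constant *Identity = ConstantExpr::getIdentity(MainOp, MainOp->getType(),
                                                 /*AllowRHSConstant=*/true);
  if (!Identity)
    return nullptr;
  Instruction *Clone = MainOp->clone();
  Clone->setOperand(0, V);
  Clone->setOperand(1, Identity);
  Clone->insertAfter(MainOp);
  Clone->setName(V->getName() + ".identity");
  return Clone;
}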
diff --git a/llvm/test/Transforms/SLPVectorizer/infer-missing-instruction.ll b/llvm/test/Transforms/SLPVectorizer/infer-missing-instruction.ll
new file mode 100644
index 0000000000000..0e57113a38e27
--- /dev/null
+++ b/llvm/test/Transforms/SLPVectorizer/infer-missing-instruction.ll
@@ -0,0 +1,129 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=slp-vectorizer,instcombine -S < %s | FileCheck %s
+
+define dso_local noundef i32 @_Z4testiPs(i32 noundef %a, ptr noundef readonly captures(none) %b) local_unnamed_addr {
+; CHECK-LABEL: define dso_local noundef i32 @_Z4testiPs(
+; CHECK-SAME: i32 noundef [[A:%.*]], ptr noundef readonly captures(none) [[B:%.*]]) local_unnamed_addr {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[TMP9:%.*]] = insertelement <16 x i32> poison, i32 [[A]], i64 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <16 x i32> [[TMP9]], <16 x i32> poison, <16 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP16:%.*]] = lshr <16 x i32> [[TMP1]], <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
+; CHECK-NEXT:    [[TMP17:%.*]] = and <16 x i32> [[TMP16]], splat (i32 16)
+; CHECK-NEXT:    [[TMP18:%.*]] = load <16 x i16>, ptr [[B]], align 2
+; CHECK-NEXT:    [[TMP19:%.*]] = sext <16 x i16> [[TMP18]] to <16 x i32>
+; CHECK-NEXT:    [[TMP20:%.*]] = or <16 x i32> [[TMP17]], [[TMP19]]
+; CHECK-NEXT:    [[TMP21:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP20]])
+; CHECK-NEXT:    ret i32 [[TMP21]]
+;
+entry:
+  %conv = and i32 %a, 16
+  %0 = load i16, ptr %b, align 2
+  %conv2 = sext i16 %0 to i32
+  %or = or i32 %conv, %conv2
+  %shr.1 = lshr i32 %a, 1
+  %conv.1 = and i32 %shr.1, 16
+  %arrayidx.1 = getelementptr inbounds nuw i8, ptr %b, i64 2
+  %1 = load i16, ptr %arrayidx.1, align 2
+  %conv2.1 = sext i16 %1 to i32
+  %or.1 = or i32 %conv.1, %conv2.1
+  %add.1 = add nsw i32 %or.1, %or
+  %shr.2 = lshr i32 %a, 2
+  %conv.2 = and i32 %shr.2, 16
+  %arrayidx.2 = getelementptr inbounds nuw i8, ptr %b, i64 4
+  %2 = load i16, ptr %arrayidx.2, align 2
+  %conv2.2 = sext i16 %2 to i32
+  %or.2 = or i32 %conv.2, %conv2.2
+  %add.2 = add nsw i32 %or.2, %add.1
+  %shr.3 = lshr i32 %a, 3
+  %conv.3 = and i32 %shr.3, 16
+  %arrayidx.3 = getelementptr inbounds nuw i8, ptr %b, i64 6
+  %3 = load i16, ptr %arrayidx.3, align 2
+  %conv2.3 = sext i16 %3 to i32
+  %or.3 = or i32 %conv.3, %conv2.3
+  %add.3 = add nsw i32 %or.3, %add.2
+  %shr.4 = lshr i32 %a, 4
+  %conv.4 = and i32 %shr.4, 16
+  %arrayidx.4 = getelementptr inbounds nuw i8, ptr %b, i64 8
+  %4 = load i16, ptr %arrayidx.4, align 2
+  %conv2.4 = sext i16 %4 to i32
+  %or.4 = or i32 %conv.4, %conv2.4
+  %add.4 = add nsw i32 %or.4, %add.3
+  %shr.5 = lshr i32 %a, 5
+  %conv.5 = and i32 %shr.5, 16
+  %arrayidx.5 = getelementptr inbounds nuw i8, ptr %b, i64 10
+  %5 = load i16, ptr %arrayidx.5, align 2
+  %conv2.5 = sext i16 %5 to i32
+  %or.5 = or i32 %conv.5, %conv2.5
+  %add.5 = add nsw i32 %or.5, %add.4
+  %shr.6 = lshr i32 %a, 6
+  %conv.6 = and i32 %shr.6, 16
+  %arrayidx.6 = getelementptr inbounds nuw i8, ptr %b, i64 12
+  %6 = load i16, ptr %arrayidx.6, align 2
+  %conv2.6 = sext i16 %6 to i32
+  %or.6 = or i32 %conv.6, %conv2.6
+  %add.6 = add nsw i32 %or.6, %add.5
+  %shr.7 = lshr i32 %a, 7
+  %conv.7 = and i32 %shr.7, 16
+  %arrayidx.7 = getelementptr inbounds nuw i8, ptr %b, i64 14
+  %7 = load i16, ptr %arrayidx.7, align 2
+  %conv2.7 = sext i16 %7 to i32
+  %or.7 = or i32 %conv.7, %conv2.7
+  %add.7 = add nsw i32 %or.7, %add.6
+  %shr.8 = lshr i32 %a, 8
+  %conv.8 = and i32 %shr.8, 16
+  %arrayidx.8 = getelementptr inbounds nuw i8, ptr %b, i64 16
+  %8 = load i16, ptr %arrayidx.8, align 2
+  %conv2.8 = sext i16 %8 to i32
+  %or.8 = or i32 %conv.8, %conv2.8
+  %add.8 = add nsw i32 %or.8, %add.7
+  %shr.9 = lshr i32 %a, 9
+  %conv.9 = and i32 %shr.9, 16
+  %arrayidx.9 = getelementptr inbounds nuw i8, ptr %b, i64 18
+  %9 = load i16, ptr %arrayidx.9, align 2
+  %conv2.9 = sext i16 %9 to i32
+  %or.9 = or i32 %conv.9, %conv2.9
+  %add.9 = add nsw i32 %or.9, %add.8
+  %shr.10 = lshr i32 %a, 10
+  %conv.10 = and i32 %shr.10, 16
+  %arrayidx.10 = getelementptr inbounds nuw i8, ptr %b, i64 20
+  %10 = load i16, ptr %arrayidx.10, align 2
+  %conv2.10 = sext i16 %10 to i32
+  %or.10 = or i32 %conv.10, %conv2.10
+  %add.10 = add nsw i32 %or.10, %add.9
+  %shr.11 = lshr i32 %a, 11
+  %conv.11 = and i32 %shr.11, 16
+  %arrayidx.11 = getelementptr inbounds nuw i8, ptr %b, i64 22
+  %11 = load i16, ptr %arrayidx.11, align 2
+  %conv2.11 = sext i16 %11 to i32
+  %or.11 = or i32 %conv.11, %conv2.11
+  %add.11 = add nsw i32 %or.11, %add.10
+  %shr.12 = lshr i32 %a, 12
+  %conv.12 = and i32 %shr.12, 16
+  %arrayidx.12 = getelementptr inbounds nuw i8, ptr %b, i64 24
+  %12 = load i16, ptr %arrayidx.12, align 2
+  %conv2.12 = sext i16 %12 to i32
+  %or.12 = or i32 %conv.12, %conv2.12
+  %add.12 = add nsw i32 %or.12, %add.11
+  %shr.13 = lshr i32 %a, 13
+  %conv.13 = and i32 %shr.13, 16
+  %arrayidx.13 = getelementptr inbounds nuw i8, ptr %b, i64 26
+  %13 = load i16, ptr %arrayidx.13, align 2
+  %conv2.13 = sext i16 %13 to i32
+  %or.13 = or i32 %conv.13, %conv2.13
+  %add.13 = add nsw i32 %or.13, %add.12
+  %shr.14 = lshr i32 %a, 14
+  %conv.14 = and i32 %shr.14, 16
+  %arrayidx.14 = getelementptr inbounds nuw i8, ptr %b, i64 28
+  %14 = load i16, ptr %arrayidx.14, align 2
+  %conv2.14 = sext i16 %14 to i32
+  %or.14 = or i32 %conv.14, %conv2.14
+  %add.14 = add nsw i32 %or.14, %add.13
+  %shr.15 = lshr i32 %a, 15
+  %conv.15 = and i32 %shr.15, 16
+  %arrayidx.15 = getelementptr inbounds nuw i8, ptr %b, i64 30
+  %15 = load i16, ptr %arrayidx.15, align 2
+  %conv2.15 = sext i16 %15 to i32
+  %or.15 = or i32 %conv.15, %conv2.15
+  %add.15 = add nsw i32 %or.15, %add.14
+  ret i32 %add.15
+}
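For reference, a plausible C++ original for the new test (reconstructed from the IR and the mangled name _Z4testiPs, i.e. test(int, short*); hypothetical, not part of the patch): after full unrolling, lane 0 carries no lshr because a >> 0 folds away, which is exactly the missing instruction the patch reintroduces as an identity shift.

// Hypothetical source of the reduction in infer-missing-instruction.ll.
int test(int a, short *b) {
  int sum = 0;
  for (int i = 0; i < 16; ++i)
    sum += ((a >> i) & 16) | b[i]; // lane i of the 16-wide vectorized body
  return sum;
}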