Skip to content

[llvm][SLPVectorizer] Fix a bad cast assertion #97621

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,12 +361,14 @@ static bool isCommutative(Instruction *I) {
return I->isCommutative();
}

/// \returns inserting index of InsertElement or InsertValue instruction,
/// using Offset as base offset for index.
static std::optional<unsigned> getInsertIndex(const Value *InsertInst,
unsigned Offset = 0) {
template <typename T>
static std::optional<unsigned> getInsertExtractIndex(const Value *Inst,
unsigned Offset) {
static_assert(std::is_same_v<T, InsertElementInst> ||
std::is_same_v<T, ExtractElementInst>,
"unsupported T");
int Index = Offset;
if (const auto *IE = dyn_cast<InsertElementInst>(InsertInst)) {
if (const auto *IE = dyn_cast<T>(Inst)) {
const auto *VT = dyn_cast<FixedVectorType>(IE->getType());
if (!VT)
return std::nullopt;
Expand All @@ -379,8 +381,25 @@ static std::optional<unsigned> getInsertIndex(const Value *InsertInst,
Index += CI->getZExtValue();
return Index;
}
return std::nullopt;
}

/// \returns inserting or extracting index of InsertElement, ExtractElement or
/// InsertValue instruction, using Offset as base offset for index.
/// \returns std::nullopt if the index is not an immediate.
static std::optional<unsigned> getElementIndex(const Value *Inst,
unsigned Offset = 0) {
if (auto Index = getInsertExtractIndex<InsertElementInst>(Inst, Offset))
return Index;
if (auto Index = getInsertExtractIndex<ExtractElementInst>(Inst, Offset))
return Index;

int Index = Offset;

const auto *IV = dyn_cast<InsertValueInst>(Inst);
if (!IV)
return std::nullopt;

const auto *IV = cast<InsertValueInst>(InsertInst);
Type *CurrentType = IV->getType();
for (unsigned I : IV->indices()) {
if (const auto *ST = dyn_cast<StructType>(CurrentType)) {
Expand Down Expand Up @@ -454,7 +473,7 @@ static SmallBitVector isUndefVector(const Value *V,
Base = II->getOperand(0);
if (isa<T>(II->getOperand(1)))
continue;
std::optional<unsigned> Idx = getInsertIndex(II);
std::optional<unsigned> Idx = getElementIndex(II);
if (!Idx) {
Res.reset();
return Res;
Expand Down Expand Up @@ -4718,8 +4737,8 @@ static bool areTwoInsertFromSameBuildVector(
return false;
auto *IE1 = VU;
auto *IE2 = V;
std::optional<unsigned> Idx1 = getInsertIndex(IE1);
std::optional<unsigned> Idx2 = getInsertIndex(IE2);
std::optional<unsigned> Idx1 = getElementIndex(IE1);
std::optional<unsigned> Idx2 = getElementIndex(IE2);
if (Idx1 == std::nullopt || Idx2 == std::nullopt)
return false;
// Go through the vector operand of insertelement instructions trying to find
Expand All @@ -4734,7 +4753,7 @@ static bool areTwoInsertFromSameBuildVector(
if (IE1 == V && !IE2)
return V->hasOneUse();
if (IE1 && IE1 != V) {
unsigned Idx1 = getInsertIndex(IE1).value_or(*Idx2);
unsigned Idx1 = getElementIndex(IE1).value_or(*Idx2);
IsReusedIdx |= ReusedIdx.test(Idx1);
ReusedIdx.set(Idx1);
if ((IE1 != VU && !IE1->hasOneUse()) || IsReusedIdx)
Expand All @@ -4743,7 +4762,7 @@ static bool areTwoInsertFromSameBuildVector(
IE1 = dyn_cast_or_null<InsertElementInst>(GetBaseOperand(IE1));
}
if (IE2 && IE2 != VU) {
unsigned Idx2 = getInsertIndex(IE2).value_or(*Idx1);
unsigned Idx2 = getElementIndex(IE2).value_or(*Idx1);
IsReusedIdx |= ReusedIdx.test(Idx2);
ReusedIdx.set(Idx2);
if ((IE2 != V && !IE2->hasOneUse()) || IsReusedIdx)
Expand Down Expand Up @@ -4902,13 +4921,13 @@ BoUpSLP::getReorderingData(const TreeEntry &TE, bool TopToBottom) {
IE1, IE2,
[](InsertElementInst *II) { return II->getOperand(0); }))
return I1 < I2;
return getInsertIndex(IE1) < getInsertIndex(IE2);
return getElementIndex(IE1) < getElementIndex(IE2);
}
if (auto *EE1 = dyn_cast<ExtractElementInst>(FirstUserOfPhi1))
if (auto *EE2 = dyn_cast<ExtractElementInst>(FirstUserOfPhi2)) {
if (EE1->getOperand(0) != EE2->getOperand(0))
return I1 < I2;
return getInsertIndex(EE1) < getInsertIndex(EE2);
return getElementIndex(EE1) < getElementIndex(EE2);
}
return I1 < I2;
};
Expand Down Expand Up @@ -6162,7 +6181,7 @@ BoUpSLP::TreeEntry::EntryState BoUpSLP::getScalarsVectorizationState(
ValueSet SourceVectors;
for (Value *V : VL) {
SourceVectors.insert(cast<Instruction>(V)->getOperand(0));
assert(getInsertIndex(V) != std::nullopt &&
assert(getElementIndex(V) != std::nullopt &&
"Non-constant or undef index?");
}

Expand Down Expand Up @@ -6929,7 +6948,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth,
decltype(OrdCompare)>
Indices(OrdCompare);
for (int I = 0, E = VL.size(); I < E; ++I) {
unsigned Idx = *getInsertIndex(VL[I]);
unsigned Idx = *getElementIndex(VL[I]);
Indices.emplace(Idx, I);
}
OrdersType CurrentOrder(VL.size(), VL.size());
Expand Down Expand Up @@ -9308,11 +9327,11 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
unsigned NumOfParts = TTI->getNumberOfParts(SrcVecTy);

SmallVector<int> InsertMask(NumElts, PoisonMaskElem);
unsigned OffsetBeg = *getInsertIndex(VL.front());
unsigned OffsetBeg = *getElementIndex(VL.front());
unsigned OffsetEnd = OffsetBeg;
InsertMask[OffsetBeg] = 0;
for (auto [I, V] : enumerate(VL.drop_front())) {
unsigned Idx = *getInsertIndex(V);
unsigned Idx = *getElementIndex(V);
if (OffsetBeg > Idx)
OffsetBeg = Idx;
else if (OffsetEnd < Idx)
Expand Down Expand Up @@ -9353,7 +9372,7 @@ BoUpSLP::getEntryCost(const TreeEntry *E, ArrayRef<Value *> VectorizedVals,
SmallVector<int> PrevMask(InsertVecSz, PoisonMaskElem);
Mask.swap(PrevMask);
for (unsigned I = 0; I < NumScalars; ++I) {
unsigned InsertIdx = *getInsertIndex(VL[PrevMask[I]]);
unsigned InsertIdx = *getElementIndex(VL[PrevMask[I]]);
DemandedElts.setBit(InsertIdx);
IsIdentity &= InsertIdx - OffsetBeg == I;
Mask[InsertIdx - OffsetBeg] = I;
Expand Down Expand Up @@ -10103,8 +10122,8 @@ static bool isFirstInsertElement(const InsertElementInst *IE1,
const auto *I2 = IE2;
const InsertElementInst *PrevI1;
const InsertElementInst *PrevI2;
unsigned Idx1 = *getInsertIndex(IE1);
unsigned Idx2 = *getInsertIndex(IE2);
unsigned Idx1 = *getElementIndex(IE1);
unsigned Idx2 = *getElementIndex(IE2);
do {
if (I2 == IE1)
return true;
Expand All @@ -10113,10 +10132,10 @@ static bool isFirstInsertElement(const InsertElementInst *IE1,
PrevI1 = I1;
PrevI2 = I2;
if (I1 && (I1 == IE1 || I1->hasOneUse()) &&
getInsertIndex(I1).value_or(Idx2) != Idx2)
getElementIndex(I1).value_or(Idx2) != Idx2)
I1 = dyn_cast<InsertElementInst>(I1->getOperand(0));
if (I2 && ((I2 == IE2 || I2->hasOneUse())) &&
getInsertIndex(I2).value_or(Idx1) != Idx1)
getElementIndex(I2).value_or(Idx1) != Idx1)
I2 = dyn_cast<InsertElementInst>(I2->getOperand(0));
} while ((I1 && PrevI1 != I1) || (I2 && PrevI2 != I2));
llvm_unreachable("Two different buildvectors not expected.");
Expand Down Expand Up @@ -10308,7 +10327,7 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
if (auto *FTy = dyn_cast<FixedVectorType>(VU->getType())) {
if (!UsedInserts.insert(VU).second)
continue;
std::optional<unsigned> InsertIdx = getInsertIndex(VU);
std::optional<unsigned> InsertIdx = getElementIndex(VU);
if (InsertIdx) {
const TreeEntry *ScalarTE = getTreeEntry(EU.Scalar);
auto *It = find_if(
Expand All @@ -10334,14 +10353,14 @@ InstructionCost BoUpSLP::getTreeCost(ArrayRef<Value *> VectorizedVals) {
while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) {
if (IEBase != EU.User &&
(!IEBase->hasOneUse() ||
getInsertIndex(IEBase).value_or(*InsertIdx) == *InsertIdx))
getElementIndex(IEBase).value_or(*InsertIdx) == *InsertIdx))
break;
// Build the mask for the vectorized insertelement instructions.
if (const TreeEntry *E = getTreeEntry(IEBase)) {
VU = IEBase;
do {
IEBase = cast<InsertElementInst>(Base);
int Idx = *getInsertIndex(IEBase);
int Idx = *getElementIndex(IEBase);
assert(Mask[Idx] == PoisonMaskElem &&
"InsertElementInstruction used already.");
Mask[Idx] = Idx;
Expand Down Expand Up @@ -12755,7 +12774,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
cast<FixedVectorType>(FirstInsert->getType())->getNumElements();
const unsigned NumScalars = E->Scalars.size();

unsigned Offset = *getInsertIndex(VL0);
unsigned Offset = *getElementIndex(VL0);
assert(Offset < NumElts && "Failed to find vector index offset");

// Create shuffle to resize vector
Expand All @@ -12773,7 +12792,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
Mask.swap(PrevMask);
for (unsigned I = 0; I < NumScalars; ++I) {
Value *Scalar = E->Scalars[PrevMask[I]];
unsigned InsertIdx = *getInsertIndex(Scalar);
unsigned InsertIdx = *getElementIndex(Scalar);
IsIdentity &= InsertIdx - Offset == I;
Mask[InsertIdx - Offset] = I;
}
Expand All @@ -12786,7 +12805,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E, bool PostponedPHIs) {
// sequence.
InsertElementInst *Ins = cast<InsertElementInst>(VL0);
do {
std::optional<unsigned> InsertIdx = getInsertIndex(Ins);
std::optional<unsigned> InsertIdx = getElementIndex(Ins);
if (!InsertIdx)
break;
if (InsertMask[*InsertIdx] == PoisonMaskElem)
Expand Down Expand Up @@ -13835,7 +13854,7 @@ Value *BoUpSLP::vectorizeTree(
}
}

std::optional<unsigned> InsertIdx = getInsertIndex(VU);
std::optional<unsigned> InsertIdx = getElementIndex(VU);
if (InsertIdx) {
auto *It =
find_if(ShuffledInserts, [VU](const ShuffledInsertData &Data) {
Expand All @@ -13858,13 +13877,13 @@ Value *BoUpSLP::vectorizeTree(
while (auto *IEBase = dyn_cast<InsertElementInst>(Base)) {
if (IEBase != User &&
(!IEBase->hasOneUse() ||
getInsertIndex(IEBase).value_or(Idx) == Idx))
getElementIndex(IEBase).value_or(Idx) == Idx))
break;
// Build the mask for the vectorized insertelement instructions.
if (const TreeEntry *E = getTreeEntry(IEBase)) {
do {
IEBase = cast<InsertElementInst>(Base);
int IEIdx = *getInsertIndex(IEBase);
int IEIdx = *getElementIndex(IEBase);
assert(Mask[IEIdx] == PoisonMaskElem &&
"InsertElementInstruction used already.");
Mask[IEIdx] = IEIdx;
Expand Down Expand Up @@ -17822,7 +17841,7 @@ static void findBuildAggregate_rec(Instruction *LastInsertInst,
do {
Value *InsertedOperand = LastInsertInst->getOperand(1);
std::optional<unsigned> OperandIndex =
getInsertIndex(LastInsertInst, OperandOffset);
getElementIndex(LastInsertInst, OperandOffset);
if (!OperandIndex)
return;
if (isa<InsertElementInst, InsertValueInst>(InsertedOperand)) {
Expand Down
39 changes: 39 additions & 0 deletions llvm/test/Transforms/SLPVectorizer/rdar128092379.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -passes=slp-vectorizer < %s -o - -S | FileCheck %s

target datalayout = "e-m:o-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128"
target triple = "x86_64-apple-macosx15.4.0"

define fastcc i32 @rdar128092379(i8 %index) {
; CHECK-LABEL: define fastcc i32 @rdar128092379(
; CHECK-SAME: i8 [[INDEX:%.*]]) {
; CHECK-NEXT: [[BLOCK:.*]]:
; CHECK-NEXT: [[ZEXT:%.*]] = zext i8 [[INDEX]] to i64
; CHECK-NEXT: [[ZEXT1:%.*]] = zext i8 [[INDEX]] to i64
; CHECK-NEXT: br label %[[BLOCK3:.*]]
; CHECK: [[BLOCK2:.*]]:
; CHECK-NEXT: br label %[[BLOCK3]]
; CHECK: [[BLOCK3]]:
; CHECK-NEXT: [[PHI:%.*]] = phi i64 [ 0, %[[BLOCK2]] ], [ [[ZEXT1]], %[[BLOCK]] ]
; CHECK-NEXT: [[PHI4:%.*]] = phi i64 [ 0, %[[BLOCK2]] ], [ [[ZEXT]], %[[BLOCK]] ]
; CHECK-NEXT: [[EXTRACTELEMENT:%.*]] = extractelement <16 x i32> zeroinitializer, i64 [[PHI4]]
; CHECK-NEXT: [[EXTRACTELEMENT5:%.*]] = extractelement <16 x i32> zeroinitializer, i64 [[PHI]]
; CHECK-NEXT: [[SUM:%.*]] = add i32 [[EXTRACTELEMENT]], [[EXTRACTELEMENT5]]
; CHECK-NEXT: ret i32 [[SUM]]
;
block:
%zext = zext i8 %index to i64
%zext1 = zext i8 %index to i64
br label %block3

block2:
br label %block3

block3:
%phi = phi i64 [ 0, %block2 ], [ %zext1, %block ]
%phi4 = phi i64 [ 0, %block2 ], [ %zext, %block ]
%extractelement = extractelement <16 x i32> zeroinitializer, i64 %phi4
%extractelement5 = extractelement <16 x i32> zeroinitializer, i64 %phi
%sum = add i32 %extractelement, %extractelement5
ret i32 %sum
}
Loading