Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/include/llvm/Analysis/ValueTracking.h
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,11 @@ LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
Value *&Start, Value *&Step);

/// Analogous to the above, but starting from the binary operator
LLVM_ABI bool matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
LLVM_ABI bool matchSimpleRecurrence(const Instruction *I, PHINode *&P,
Value *&Start, Value *&Step);

/// Analogous to the above, but also supporting non-binary operators.
LLVM_ABI bool matchSimpleRecurrence(const PHINode *P, Instruction *&BO,
Value *&Start, Value *&Step);

/// Attempt to match a simple value-accumulating recurrence of the form:
Expand Down
47 changes: 37 additions & 10 deletions llvm/lib/Analysis/ValueTracking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1577,7 +1577,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
}
case Instruction::PHI: {
const PHINode *P = cast<PHINode>(I);
BinaryOperator *BO = nullptr;
Instruction *BO = nullptr;
Copy link
Contributor

@goldsteinn goldsteinn Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think BO should be renamed. That's probably not worth it given the large amount of unrelated diffs it will create.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I could rename it as NFC after the change lands?

Value *R = nullptr, *L = nullptr;
if (matchSimpleRecurrence(P, BO, R, L)) {
// Handle the case of a simple two-predecessor recurrence PHI.
Expand Down Expand Up @@ -1641,6 +1641,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
case Instruction::Sub:
case Instruction::And:
case Instruction::Or:
case Instruction::GetElementPtr:
case Instruction::Mul: {
// Change the context instruction to the "edge" that flows into the
// phi. This is important because that is where the value is actually
Expand All @@ -1659,6 +1660,11 @@ static void computeKnownBitsFromOperator(const Operator *I,

// We need to take the minimum number of known bits
KnownBits Known3(BitWidth);
if (BitWidth != getBitWidth(L->getType(), Q.DL)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the pointer width is different from the index width, the optimization will be disabled. Is there a real target satisfying the condition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this happens in a lot of workloads in practice, e.g. with a 64-bit index width and GEPs that use i32 indices.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is ok to fall through as the result is guarded by std::min(Idx.countMinTrailingZeros(), Ptr.countMinTrailingZeros()).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, but unfortunately computeKnownBits has some assertions that the bit width of the operation matches the passed-in KnownBits.

We could operate on a suitable KnownBits object for the getelementptr and extend it as needed in a follow-up, if there are any cases where this would help.

assert(isa<GetElementPtrInst>(BO) &&
"Bitwidth should only be different for GEPs.");
break;
}
RecQ.CxtI = LInst;
computeKnownBits(L, DemandedElts, Known3, RecQ, Depth + 1);

Expand Down Expand Up @@ -1821,6 +1827,7 @@ static void computeKnownBitsFromOperator(const Operator *I,
Known.resetAll();
}
}

if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
switch (II->getIntrinsicID()) {
default:
Expand Down Expand Up @@ -2351,7 +2358,7 @@ void computeKnownBits(const Value *V, const APInt &DemandedElts,
/// always a power of two (or zero).
static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
SimplifyQuery &Q, unsigned Depth) {
BinaryOperator *BO = nullptr;
Instruction *BO = nullptr;
Value *Start = nullptr, *Step = nullptr;
if (!matchSimpleRecurrence(PN, BO, Start, Step))
return false;
Expand Down Expand Up @@ -2389,7 +2396,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
// Divisor must be a power of two.
// If OrZero is false, cannot guarantee induction variable is non-zero after
// division, same for Shr, unless it is exact division.
return (OrZero || Q.IIQ.isExact(BO)) &&
return (OrZero || Q.IIQ.isExact(cast<BinaryOperator>(BO))) &&
isKnownToBeAPowerOfTwo(Step, false, Q, Depth);
case Instruction::Shl:
return OrZero || Q.IIQ.hasNoUnsignedWrap(BO) || Q.IIQ.hasNoSignedWrap(BO);
Expand All @@ -2398,7 +2405,7 @@ static bool isPowerOfTwoRecurrence(const PHINode *PN, bool OrZero,
return false;
[[fallthrough]];
case Instruction::LShr:
return OrZero || Q.IIQ.isExact(BO);
return OrZero || Q.IIQ.isExact(cast<BinaryOperator>(BO));
default:
return false;
}
Expand Down Expand Up @@ -2810,7 +2817,7 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value)
/// Try to detect a recurrence that monotonically increases/decreases from a
/// non-zero starting value. These are common as induction variables.
static bool isNonZeroRecurrence(const PHINode *PN) {
BinaryOperator *BO = nullptr;
Instruction *BO = nullptr;
Value *Start = nullptr, *Step = nullptr;
const APInt *StartC, *StepC;
if (!matchSimpleRecurrence(PN, BO, Start, Step) ||
Expand Down Expand Up @@ -3648,9 +3655,9 @@ getInvertibleOperands(const Operator *Op1,
// If PN1 and PN2 are both recurrences, can we prove the entire recurrences
// are a single invertible function of the start values? Note that repeated
// application of an invertible function is also invertible
BinaryOperator *BO1 = nullptr;
Instruction *BO1 = nullptr;
Value *Start1 = nullptr, *Step1 = nullptr;
BinaryOperator *BO2 = nullptr;
Instruction *BO2 = nullptr;
Value *Start2 = nullptr, *Step2 = nullptr;
if (PN1->getParent() != PN2->getParent() ||
!matchSimpleRecurrence(PN1, BO1, Start1, Step1) ||
Expand Down Expand Up @@ -9125,6 +9132,14 @@ static bool matchTwoInputRecurrence(const PHINode *PN, InstTy *&Inst,
for (unsigned I = 0; I != 2; ++I) {
if (auto *Operation = dyn_cast<InstTy>(PN->getIncomingValue(I));
Operation && Operation->getNumOperands() >= 2) {
Value *LR;
if (match(Operation, m_PtrAdd(m_Specific(PN), m_Value(LR)))) {
Inst = Operation;
Init = PN->getIncomingValue(!I);
OtherOp = LR;
return true;
}

Value *LHS = Operation->getOperand(0);
Value *RHS = Operation->getOperand(1);
if (LHS != PN && RHS != PN)
Expand All @@ -9147,12 +9162,24 @@ bool llvm::matchSimpleRecurrence(const PHINode *P, BinaryOperator *&BO,
// Or:
// %iv = [Start, %entry], [%iv.next, %backedge]
// %iv.next = binop Step, %iv
return matchTwoInputRecurrence(P, BO, Start, Step);
return matchTwoInputRecurrence(P, BO, Start, Step) && isa<BinaryOperator>(BO);
}

bool llvm::matchSimpleRecurrence(const PHINode *P, Instruction *&BO,
                                 Value *&Start, Value *&Step) {
  // Recognise a two-input recurrence rooted at the PHI, i.e. either
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = op %iv, Step
  // or the same shape with the operands of the update swapped:
  //   %iv = [Start, %entry], [%iv.next, %backedge]
  //   %iv.next = op Step, %iv
  if (!matchTwoInputRecurrence(P, BO, Start, Step))
    return false;

  // Only binary operators and getelementptr instructions are accepted as the
  // recurrence update; any other matched instruction is rejected.
  return isa<BinaryOperator, GetElementPtrInst>(BO);
}

bool llvm::matchSimpleRecurrence(const BinaryOperator *I, PHINode *&P,
bool llvm::matchSimpleRecurrence(const Instruction *I, PHINode *&P,
Value *&Start, Value *&Step) {
BinaryOperator *BO = nullptr;
Instruction *BO = nullptr;
P = dyn_cast<PHINode>(I->getOperand(0));
if (!P)
P = dyn_cast<PHINode>(I->getOperand(1));
Expand Down
26 changes: 13 additions & 13 deletions llvm/test/Transforms/InferAlignment/gep-recurrence.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ define void @recur_i8_128(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 128
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 128
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -40,7 +40,7 @@ define void @recur_i8_128_no_nusw(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 128
; CHECK-NEXT: [[IV_NEXT]] = getelementptr i8, ptr [[IV]], i64 128
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -68,7 +68,7 @@ define void @recur_i8_64(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 64
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 64
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -124,7 +124,7 @@ define void @recur_i8_32(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 32
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 32
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -152,7 +152,7 @@ define void @recur_i8_16(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 16
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 16
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -180,7 +180,7 @@ define void @recur_i8_8(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 8
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 8
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -208,7 +208,7 @@ define void @recur_i8_4(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 4
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 4
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -236,7 +236,7 @@ define void @recur_i8_2(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 2
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 2
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -412,7 +412,7 @@ define void @recur_i32_4(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 4
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i32, ptr [[IV]], i64 4
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -440,7 +440,7 @@ define void @recur_i32_3(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 4
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i32, ptr [[IV]], i64 4
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -468,7 +468,7 @@ define void @recur_i8_neg_128(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 128
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 -128
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -496,7 +496,7 @@ define void @recur_i8_neg64(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 64
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 -64
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down Expand Up @@ -552,7 +552,7 @@ define void @recur_i8_neg_32(ptr align 128 %dst) {
; CHECK-NEXT: br label %[[LOOP:.*]]
; CHECK: [[LOOP]]:
; CHECK-NEXT: [[IV:%.*]] = phi ptr [ [[DST]], %[[ENTRY]] ], [ [[IV_NEXT:%.*]], %[[LOOP]] ]
; CHECK-NEXT: store i64 0, ptr [[IV]], align 1
; CHECK-NEXT: store i64 0, ptr [[IV]], align 32
; CHECK-NEXT: [[IV_NEXT]] = getelementptr nusw i8, ptr [[IV]], i64 -32
; CHECK-NEXT: [[C:%.*]] = call i1 @cond()
; CHECK-NEXT: br i1 [[C]], label %[[LOOP]], label %[[EXIT:.*]]
Expand Down
Loading