Commit b2df961

[IndVarSimplify][LoopUtils] Avoid TOCTOU/ordering issues (PR45835)
Summary:
Currently, `rewriteLoopExitValues()`'s logic is roughly as follows:
> Loop over each incoming value in each PHI node.
> Query whether the SCEV for that incoming value is high-cost.
> Expand the SCEV.
> Perform sanity check (`isValidRewrite()`, D51582)
> Record the info
> Afterwards, see if we can drop the loop given replacements.
> Maybe perform replacements.

The problem is that we interleave SCEV cost checking and expansion.
This is A Problem, because `isHighCostExpansion()` takes special care
to not bill for expansions that already exist and that we can reuse.

That makes sense in general: if we know that we will expand some SCEV,
all the other SCEV's costs should account for that, which might cause
some of them to become non-high-cost too, and cause a chain reaction.

But that isn't what we are doing here. We expand *all* SCEV's, unconditionally.
So every next SCEV's cost will be affected by the already-performed expansions
for previous SCEV's, even if we are not planning on keeping
some of the expansions we performed.

Worse yet, this "bonus" depends on the exact PHI node
incoming value processing order. This is completely wrong.

As an example of the issue, see @dmajor's `pr45835.ll`: if we happen to have
a PHI node with two(!) identical high-cost incoming values for the same basic
block, we would decide the first time around that it is high-cost, expand it,
and then immediately decide that it is not high-cost, because we have
an expansion that we could reuse (because we expanded it right before,
temporarily), and replace the second incoming value but not the first one,
thus resulting in a broken PHI.

What we should do instead, for now, is not perform any expansions
until after we've queried all the costs.

Later, in particular after `isValidRewrite()` is an assertion (D51582),
we could improve upon that, but in a more coherent fashion.

See [[ https://bugs.llvm.org/show_bug.cgi?id=45835 | PR45835 ]]

Reviewers: dmajor, reames, mkazantsev, fhahn, efriedma

Reviewed By: dmajor, mkazantsev

Subscribers: smeenai, nikic, hiraditya, javed.absar, llvm-commits, dmajor

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D79787
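
The ordering problem described in the summary can be sketched with a toy model. This is only an illustration, not the LLVM API: `ToyExpander`, `interleaved` and `twoPhase` are hypothetical names, and the cost rule is simplified to "an expression is high-cost unless an expansion of it already exists".

```cpp
#include <cassert>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for the expander: it only remembers which
// expressions have already been expanded.
struct ToyExpander {
  std::set<std::string> Expanded;

  // Simplified cost rule (an assumption of this sketch): an expression is
  // high-cost unless an existing expansion can be reused.
  bool isHighCost(const std::string &Expr) const {
    return Expanded.count(Expr) == 0;
  }

  void expand(const std::string &Expr) { Expanded.insert(Expr); }
};

// Interleaved scheme: query the cost, then expand right away.
// Each temporary expansion influences the cost of later queries.
std::vector<bool> interleaved(const std::vector<std::string> &Incoming) {
  ToyExpander E;
  std::vector<bool> HighCost;
  for (const std::string &Expr : Incoming) {
    HighCost.push_back(E.isHighCost(Expr));
    E.expand(Expr);
  }
  return HighCost;
}

// Two-phase scheme: query all costs first, expand only afterwards.
std::vector<bool> twoPhase(const std::vector<std::string> &Incoming) {
  ToyExpander E;
  std::vector<bool> HighCost;
  for (const std::string &Expr : Incoming)
    HighCost.push_back(E.isHighCost(Expr));
  for (const std::string &Expr : Incoming)
    E.expand(Expr);
  return HighCost;
}
```

For two identical incoming values, `interleaved` classifies the first as high-cost and the second as cheap, which is the inconsistent verdict behind the broken PHI, while `twoPhase` classifies both the same way.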
1 parent b061450 commit b2df961

2 files changed: +98 -28 lines changed
llvm/lib/Transforms/Utils/LoopUtils.cpp

Lines changed: 60 additions & 28 deletions
@@ -1216,13 +1216,19 @@ static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
 // Collect information about PHI nodes which can be transformed in
 // rewriteLoopExitValues.
 struct RewritePhi {
-  PHINode *PN;
-  unsigned Ith; // Ith incoming value.
-  Value *Val; // Exit value after expansion.
-  bool HighCost; // High Cost when expansion.
-
-  RewritePhi(PHINode *P, unsigned I, Value *V, bool H)
-      : PN(P), Ith(I), Val(V), HighCost(H) {}
+  PHINode *PN; // For which PHI node is this replacement?
+  unsigned Ith; // For which incoming value?
+  const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
+  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
+  bool HighCost; // Is this expansion a high-cost?
+
+  Value *Expansion = nullptr;
+  bool ValidRewrite = false;
+
+  RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
+             bool H)
+      : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
+        HighCost(H) {}
 };
 
 // Check whether it is possible to delete the loop after rewriting exit
@@ -1255,6 +1261,8 @@ static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet)
     // phase later. Skip it in the loop invariant check below.
     bool found = false;
     for (const RewritePhi &Phi : RewritePhiSet) {
+      if (!Phi.ValidRewrite)
+        continue;
       unsigned i = Phi.Ith;
       if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
         found = true;
@@ -1372,42 +1380,66 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
             !isa<SCEVUnknown>(ExitValue) && hasHardUserWithinLoop(L, Inst))
           continue;
 
+        // Check if expansions of this SCEV would count as being high cost.
         bool HighCost = Rewriter.isHighCostExpansion(
             ExitValue, L, SCEVCheapExpansionBudget, TTI, Inst);
-        Value *ExitVal = Rewriter.expandCodeFor(ExitValue, PN->getType(), Inst);
-
-        LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
-                          << *ExitVal << '\n' << " LoopVal = " << *Inst
-                          << "\n");
-
-        if (!isValidRewrite(SE, Inst, ExitVal)) {
-          DeadInsts.push_back(ExitVal);
-          continue;
-        }
 
-#ifndef NDEBUG
-        // If we reuse an instruction from a loop which is neither L nor one of
-        // its containing loops, we end up breaking LCSSA form for this loop by
-        // creating a new use of its instruction.
-        if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
-          if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
-            if (EVL != L)
-              assert(EVL->contains(L) && "LCSSA breach detected!");
-#endif
+        // Note that we must not perform expansions until after
+        // we query *all* the costs, because if we perform temporary expansion
+        // inbetween, one that we might not intend to keep, said expansion
+        // *may* affect cost calculation of the the next SCEV's we'll query,
+        // and next SCEV may errneously get smaller cost.
 
         // Collect all the candidate PHINodes to be rewritten.
-        RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost);
+        RewritePhiSet.emplace_back(PN, i, ExitValue, Inst, HighCost);
       }
     }
   }
 
+  // Now that we've done preliminary filtering and billed all the SCEV's,
+  // we can perform the last sanity check - the expansion must be valid.
+  for (RewritePhi &Phi : RewritePhiSet) {
+    Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(),
+                                           Phi.ExpansionPoint);
+
+    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
+                      << *(Phi.Expansion) << '\n'
+                      << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");
+
+    // FIXME: isValidRewrite() is a hack. it should be an assert, eventually.
+    Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion);
+    if (!Phi.ValidRewrite) {
+      DeadInsts.push_back(Phi.Expansion);
+      continue;
+    }
+
+#ifndef NDEBUG
+    // If we reuse an instruction from a loop which is neither L nor one of
+    // its containing loops, we end up breaking LCSSA form for this loop by
+    // creating a new use of its instruction.
+    if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion))
+      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
+        if (EVL != L)
+          assert(EVL->contains(L) && "LCSSA breach detected!");
+#endif
+  }
+
+  // TODO: after isValidRewrite() is an assertion, evaluate whether
+  // it is beneficial to change how we calculate high-cost:
+  // if we have SCEV 'A' which we know we will expand, should we calculate
+  // the cost of other SCEV's after expanding SCEV 'A',
+  // thus potentially giving cost bonus to those other SCEV's?
+
   bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
   int NumReplaced = 0;
 
   // Transformation.
   for (const RewritePhi &Phi : RewritePhiSet) {
+    if (!Phi.ValidRewrite)
+      continue;
+
     PHINode *PN = Phi.PN;
-    Value *ExitVal = Phi.Val;
+    Value *ExitVal = Phi.Expansion;
 
     // Only do the rewrite when the ExitValue can be expanded cheaply.
     // If LoopCanBeDel is true, rewrite exit value aggressively.
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+; RUN: opt < %s -indvars -replexitval=always -S | FileCheck %s --check-prefix=ALWAYS
+; RUN: opt < %s -indvars -replexitval=never -S | FileCheck %s --check-prefix=NEVER
+; RUN: opt < %s -indvars -replexitval=cheap -scev-cheap-expansion-budget=1 -S | FileCheck %s --check-prefix=CHEAP
+
+; rewriteLoopExitValues() must rewrite all or none of a PHI's values from a given block.
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+
+@a = common global i8 0, align 1
+
+define internal fastcc void @d(i8* %c) unnamed_addr #0 {
+entry:
+  %cmp = icmp ule i8* %c, getelementptr inbounds (i8, i8* @a, i64 65535)
+  %add.ptr = getelementptr inbounds i8, i8* %c, i64 -65535
+  br label %while.cond
+
+while.cond:
+  br i1 icmp ne (i8 0, i8 0), label %cont, label %while.end
+
+cont:
+  %a.mux = select i1 %cmp, i8* @a, i8* %add.ptr
+  switch i64 0, label %while.cond [
+    i64 -1, label %handler.pointer_overflow.i
+    i64 0, label %handler.pointer_overflow.i
+  ]
+
+handler.pointer_overflow.i:
+  %a.mux.lcssa4 = phi i8* [ %a.mux, %cont ], [ %a.mux, %cont ]
+; ALWAYS: [ %scevgep, %cont ], [ %scevgep, %cont ]
+; NEVER: [ %a.mux, %cont ], [ %a.mux, %cont ]
+; In cheap mode, use either one as long as it's consistent.
+; CHEAP: [ %[[VAL:.*]], %cont ], [ %[[VAL]], %cont ]
+  %x5 = ptrtoint i8* %a.mux.lcssa4 to i64
+  br label %while.end
+
+while.end:
+  ret void
+}
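
The invariant this test pins down (all entries of a PHI that come from the same predecessor block must carry the same value) can be expressed as a small standalone check. The following is a minimal sketch over plain strings, not LLVM's own verifier code; `Incoming` and `isConsistentPhi` are hypothetical names introduced only for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

// One PHI entry in this sketch: (incoming value name, predecessor block name).
using Incoming = std::pair<std::string, std::string>;

// Returns true if every predecessor block maps to exactly one incoming value.
bool isConsistentPhi(const std::vector<Incoming> &Entries) {
  std::map<std::string, std::string> ValueForBlock;
  for (const Incoming &E : Entries) {
    // emplace() leaves an existing mapping untouched and reports that
    // nothing was inserted; a differing value is then a mismatch.
    auto Res = ValueForBlock.emplace(E.second, E.first);
    if (!Res.second && Res.first->second != E.first)
      return false;
  }
  return true;
}
```

Under this check, the `ALWAYS` and `NEVER` expectations above are consistent, whereas a half-rewritten PHI such as `[ %a.mux, %cont ], [ %scevgep, %cont ]` is rejected.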
