|
16 | 16 | //===----------------------------------------------------------------------===// |
17 | 17 |
|
18 | 18 | #include "CoroInternal.h" |
| 19 | +#include "SuspendCrossingInfo.h" |
19 | 20 | #include "llvm/ADT/BitVector.h" |
20 | 21 | #include "llvm/ADT/PostOrderIterator.h" |
21 | 22 | #include "llvm/ADT/ScopeExit.h" |
@@ -51,315 +52,6 @@ extern cl::opt<bool> UseNewDbgInfoFormat; |
51 | 52 | // "coro-frame", which results in leaner debug spew. |
52 | 53 | #define DEBUG_TYPE "coro-suspend-crossing" |
53 | 54 |
|
54 | | -enum { SmallVectorThreshold = 32 }; |
55 | | - |
56 | | -// Provides two way mapping between the blocks and numbers. |
57 | | -namespace { |
58 | | -class BlockToIndexMapping { |
59 | | - SmallVector<BasicBlock *, SmallVectorThreshold> V; |
60 | | - |
61 | | -public: |
62 | | - size_t size() const { return V.size(); } |
63 | | - |
64 | | - BlockToIndexMapping(Function &F) { |
65 | | - for (BasicBlock &BB : F) |
66 | | - V.push_back(&BB); |
67 | | - llvm::sort(V); |
68 | | - } |
69 | | - |
70 | | - size_t blockToIndex(BasicBlock const *BB) const { |
71 | | - auto *I = llvm::lower_bound(V, BB); |
72 | | - assert(I != V.end() && *I == BB && "BasicBlockNumberng: Unknown block"); |
73 | | - return I - V.begin(); |
74 | | - } |
75 | | - |
76 | | - BasicBlock *indexToBlock(unsigned Index) const { return V[Index]; } |
77 | | -}; |
78 | | -} // end anonymous namespace |
79 | | - |
80 | | -// The SuspendCrossingInfo maintains data that allows to answer a question |
81 | | -// whether given two BasicBlocks A and B there is a path from A to B that |
82 | | -// passes through a suspend point. |
83 | | -// |
84 | | -// For every basic block 'i' it maintains a BlockData that consists of: |
85 | | -// Consumes: a bit vector which contains a set of indices of blocks that can |
86 | | -// reach block 'i'. A block can trivially reach itself. |
87 | | -// Kills: a bit vector which contains a set of indices of blocks that can |
88 | | -// reach block 'i' but there is a path crossing a suspend point |
89 | | -// not repeating 'i' (path to 'i' without cycles containing 'i'). |
90 | | -// Suspend: a boolean indicating whether block 'i' contains a suspend point. |
91 | | -// End: a boolean indicating whether block 'i' contains a coro.end intrinsic. |
92 | | -// KillLoop: There is a path from 'i' to 'i' not otherwise repeating 'i' that |
93 | | -// crosses a suspend point. |
94 | | -// |
95 | | -namespace { |
96 | | -class SuspendCrossingInfo { |
97 | | - BlockToIndexMapping Mapping; |
98 | | - |
99 | | - struct BlockData { |
100 | | - BitVector Consumes; |
101 | | - BitVector Kills; |
102 | | - bool Suspend = false; |
103 | | - bool End = false; |
104 | | - bool KillLoop = false; |
105 | | - bool Changed = false; |
106 | | - }; |
107 | | - SmallVector<BlockData, SmallVectorThreshold> Block; |
108 | | - |
109 | | - iterator_range<pred_iterator> predecessors(BlockData const &BD) const { |
110 | | - BasicBlock *BB = Mapping.indexToBlock(&BD - &Block[0]); |
111 | | - return llvm::predecessors(BB); |
112 | | - } |
113 | | - |
114 | | - BlockData &getBlockData(BasicBlock *BB) { |
115 | | - return Block[Mapping.blockToIndex(BB)]; |
116 | | - } |
117 | | - |
118 | | - /// Compute the BlockData for the current function in one iteration. |
119 | | - /// Initialize - Whether this is the first iteration, we can optimize |
120 | | - /// the initial case a little bit by manual loop switch. |
121 | | - /// Returns whether the BlockData changes in this iteration. |
122 | | - template <bool Initialize = false> |
123 | | - bool computeBlockData(const ReversePostOrderTraversal<Function *> &RPOT); |
124 | | - |
125 | | -public: |
126 | | -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
127 | | - void dump() const; |
128 | | - void dump(StringRef Label, BitVector const &BV, |
129 | | - const ReversePostOrderTraversal<Function *> &RPOT) const; |
130 | | -#endif |
131 | | - |
132 | | - SuspendCrossingInfo(Function &F, coro::Shape &Shape); |
133 | | - |
134 | | - /// Returns true if there is a path from \p From to \p To crossing a suspend |
135 | | - /// point without crossing \p From a 2nd time. |
136 | | - bool hasPathCrossingSuspendPoint(BasicBlock *From, BasicBlock *To) const { |
137 | | - size_t const FromIndex = Mapping.blockToIndex(From); |
138 | | - size_t const ToIndex = Mapping.blockToIndex(To); |
139 | | - bool const Result = Block[ToIndex].Kills[FromIndex]; |
140 | | - LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() |
141 | | - << " answer is " << Result << "\n"); |
142 | | - return Result; |
143 | | - } |
144 | | - |
145 | | - /// Returns true if there is a path from \p From to \p To crossing a suspend |
146 | | - /// point without crossing \p From a 2nd time. If \p From is the same as \p To |
147 | | - /// this will also check if there is a looping path crossing a suspend point. |
148 | | - bool hasPathOrLoopCrossingSuspendPoint(BasicBlock *From, |
149 | | - BasicBlock *To) const { |
150 | | - size_t const FromIndex = Mapping.blockToIndex(From); |
151 | | - size_t const ToIndex = Mapping.blockToIndex(To); |
152 | | - bool Result = Block[ToIndex].Kills[FromIndex] || |
153 | | - (From == To && Block[ToIndex].KillLoop); |
154 | | - LLVM_DEBUG(dbgs() << From->getName() << " => " << To->getName() |
155 | | - << " answer is " << Result << " (path or loop)\n"); |
156 | | - return Result; |
157 | | - } |
158 | | - |
159 | | - bool isDefinitionAcrossSuspend(BasicBlock *DefBB, User *U) const { |
160 | | - auto *I = cast<Instruction>(U); |
161 | | - |
162 | | - // We rewrote PHINodes, so that only the ones with exactly one incoming |
163 | | - // value need to be analyzed. |
164 | | - if (auto *PN = dyn_cast<PHINode>(I)) |
165 | | - if (PN->getNumIncomingValues() > 1) |
166 | | - return false; |
167 | | - |
168 | | - BasicBlock *UseBB = I->getParent(); |
169 | | - |
170 | | - // As a special case, treat uses by an llvm.coro.suspend.retcon or an |
171 | | - // llvm.coro.suspend.async as if they were uses in the suspend's single |
172 | | - // predecessor: the uses conceptually occur before the suspend. |
173 | | - if (isa<CoroSuspendRetconInst>(I) || isa<CoroSuspendAsyncInst>(I)) { |
174 | | - UseBB = UseBB->getSinglePredecessor(); |
175 | | - assert(UseBB && "should have split coro.suspend into its own block"); |
176 | | - } |
177 | | - |
178 | | - return hasPathCrossingSuspendPoint(DefBB, UseBB); |
179 | | - } |
180 | | - |
181 | | - bool isDefinitionAcrossSuspend(Argument &A, User *U) const { |
182 | | - return isDefinitionAcrossSuspend(&A.getParent()->getEntryBlock(), U); |
183 | | - } |
184 | | - |
185 | | - bool isDefinitionAcrossSuspend(Instruction &I, User *U) const { |
186 | | - auto *DefBB = I.getParent(); |
187 | | - |
188 | | - // As a special case, treat values produced by an llvm.coro.suspend.* |
189 | | - // as if they were defined in the single successor: the uses |
190 | | - // conceptually occur after the suspend. |
191 | | - if (isa<AnyCoroSuspendInst>(I)) { |
192 | | - DefBB = DefBB->getSingleSuccessor(); |
193 | | - assert(DefBB && "should have split coro.suspend into its own block"); |
194 | | - } |
195 | | - |
196 | | - return isDefinitionAcrossSuspend(DefBB, U); |
197 | | - } |
198 | | - |
199 | | - bool isDefinitionAcrossSuspend(Value &V, User *U) const { |
200 | | - if (auto *Arg = dyn_cast<Argument>(&V)) |
201 | | - return isDefinitionAcrossSuspend(*Arg, U); |
202 | | - if (auto *Inst = dyn_cast<Instruction>(&V)) |
203 | | - return isDefinitionAcrossSuspend(*Inst, U); |
204 | | - |
205 | | - llvm_unreachable( |
206 | | - "Coroutine could only collect Argument and Instruction now."); |
207 | | - } |
208 | | -}; |
209 | | -} // end anonymous namespace |
210 | | - |
211 | | -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
212 | | -static std::string getBasicBlockLabel(const BasicBlock *BB) { |
213 | | - if (BB->hasName()) |
214 | | - return BB->getName().str(); |
215 | | - |
216 | | - std::string S; |
217 | | - raw_string_ostream OS(S); |
218 | | - BB->printAsOperand(OS, false); |
219 | | - return OS.str().substr(1); |
220 | | -} |
221 | | - |
222 | | -LLVM_DUMP_METHOD void SuspendCrossingInfo::dump( |
223 | | - StringRef Label, BitVector const &BV, |
224 | | - const ReversePostOrderTraversal<Function *> &RPOT) const { |
225 | | - dbgs() << Label << ":"; |
226 | | - for (const BasicBlock *BB : RPOT) { |
227 | | - auto BBNo = Mapping.blockToIndex(BB); |
228 | | - if (BV[BBNo]) |
229 | | - dbgs() << " " << getBasicBlockLabel(BB); |
230 | | - } |
231 | | - dbgs() << "\n"; |
232 | | -} |
233 | | - |
234 | | -LLVM_DUMP_METHOD void SuspendCrossingInfo::dump() const { |
235 | | - if (Block.empty()) |
236 | | - return; |
237 | | - |
238 | | - BasicBlock *const B = Mapping.indexToBlock(0); |
239 | | - Function *F = B->getParent(); |
240 | | - |
241 | | - ReversePostOrderTraversal<Function *> RPOT(F); |
242 | | - for (const BasicBlock *BB : RPOT) { |
243 | | - auto BBNo = Mapping.blockToIndex(BB); |
244 | | - dbgs() << getBasicBlockLabel(BB) << ":\n"; |
245 | | - dump(" Consumes", Block[BBNo].Consumes, RPOT); |
246 | | - dump(" Kills", Block[BBNo].Kills, RPOT); |
247 | | - } |
248 | | - dbgs() << "\n"; |
249 | | -} |
250 | | -#endif |
251 | | - |
252 | | -template <bool Initialize> |
253 | | -bool SuspendCrossingInfo::computeBlockData( |
254 | | - const ReversePostOrderTraversal<Function *> &RPOT) { |
255 | | - bool Changed = false; |
256 | | - |
257 | | - for (const BasicBlock *BB : RPOT) { |
258 | | - auto BBNo = Mapping.blockToIndex(BB); |
259 | | - auto &B = Block[BBNo]; |
260 | | - |
261 | | - // We don't need to count the predecessors when initialization. |
262 | | - if constexpr (!Initialize) |
263 | | - // If all the predecessors of the current Block don't change, |
264 | | - // the BlockData for the current block must not change too. |
265 | | - if (all_of(predecessors(B), [this](BasicBlock *BB) { |
266 | | - return !Block[Mapping.blockToIndex(BB)].Changed; |
267 | | - })) { |
268 | | - B.Changed = false; |
269 | | - continue; |
270 | | - } |
271 | | - |
272 | | - // Saved Consumes and Kills bitsets so that it is easy to see |
273 | | - // if anything changed after propagation. |
274 | | - auto SavedConsumes = B.Consumes; |
275 | | - auto SavedKills = B.Kills; |
276 | | - |
277 | | - for (BasicBlock *PI : predecessors(B)) { |
278 | | - auto PrevNo = Mapping.blockToIndex(PI); |
279 | | - auto &P = Block[PrevNo]; |
280 | | - |
281 | | - // Propagate Kills and Consumes from predecessors into B. |
282 | | - B.Consumes |= P.Consumes; |
283 | | - B.Kills |= P.Kills; |
284 | | - |
285 | | - // If block P is a suspend block, it should propagate kills into block |
286 | | - // B for every block P consumes. |
287 | | - if (P.Suspend) |
288 | | - B.Kills |= P.Consumes; |
289 | | - } |
290 | | - |
291 | | - if (B.Suspend) { |
292 | | - // If block B is a suspend block, it should kill all of the blocks it |
293 | | - // consumes. |
294 | | - B.Kills |= B.Consumes; |
295 | | - } else if (B.End) { |
296 | | - // If block B is an end block, it should not propagate kills as the |
297 | | - // blocks following coro.end() are reached during initial invocation |
298 | | - // of the coroutine while all the data are still available on the |
299 | | - // stack or in the registers. |
300 | | - B.Kills.reset(); |
301 | | - } else { |
302 | | - // This is reached when B block it not Suspend nor coro.end and it |
303 | | - // need to make sure that it is not in the kill set. |
304 | | - B.KillLoop |= B.Kills[BBNo]; |
305 | | - B.Kills.reset(BBNo); |
306 | | - } |
307 | | - |
308 | | - if constexpr (!Initialize) { |
309 | | - B.Changed = (B.Kills != SavedKills) || (B.Consumes != SavedConsumes); |
310 | | - Changed |= B.Changed; |
311 | | - } |
312 | | - } |
313 | | - |
314 | | - return Changed; |
315 | | -} |
316 | | - |
317 | | -SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) |
318 | | - : Mapping(F) { |
319 | | - const size_t N = Mapping.size(); |
320 | | - Block.resize(N); |
321 | | - |
322 | | - // Initialize every block so that it consumes itself |
323 | | - for (size_t I = 0; I < N; ++I) { |
324 | | - auto &B = Block[I]; |
325 | | - B.Consumes.resize(N); |
326 | | - B.Kills.resize(N); |
327 | | - B.Consumes.set(I); |
328 | | - B.Changed = true; |
329 | | - } |
330 | | - |
331 | | - // Mark all CoroEnd Blocks. We do not propagate Kills beyond coro.ends as |
332 | | - // the code beyond coro.end is reachable during initial invocation of the |
333 | | - // coroutine. |
334 | | - for (auto *CE : Shape.CoroEnds) |
335 | | - getBlockData(CE->getParent()).End = true; |
336 | | - |
337 | | - // Mark all suspend blocks and indicate that they kill everything they |
338 | | - // consume. Note, that crossing coro.save also requires a spill, as any code |
339 | | - // between coro.save and coro.suspend may resume the coroutine and all of the |
340 | | - // state needs to be saved by that time. |
341 | | - auto markSuspendBlock = [&](IntrinsicInst *BarrierInst) { |
342 | | - BasicBlock *SuspendBlock = BarrierInst->getParent(); |
343 | | - auto &B = getBlockData(SuspendBlock); |
344 | | - B.Suspend = true; |
345 | | - B.Kills |= B.Consumes; |
346 | | - }; |
347 | | - for (auto *CSI : Shape.CoroSuspends) { |
348 | | - markSuspendBlock(CSI); |
349 | | - if (auto *Save = CSI->getCoroSave()) |
350 | | - markSuspendBlock(Save); |
351 | | - } |
352 | | - |
353 | | - // It is considered to be faster to use RPO traversal for forward-edges |
354 | | - // dataflow analysis. |
355 | | - ReversePostOrderTraversal<Function *> RPOT(&F); |
356 | | - computeBlockData</*Initialize=*/true>(RPOT); |
357 | | - while (computeBlockData</*Initialize*/ false>(RPOT)) |
358 | | - ; |
359 | | - |
360 | | - LLVM_DEBUG(dump()); |
361 | | -} |
362 | | - |
363 | 55 | namespace { |
364 | 56 |
|
365 | 57 | // RematGraph is used to construct a DAG for rematerializable instructions |
@@ -438,6 +130,16 @@ struct RematGraph { |
438 | 130 | } |
439 | 131 |
|
440 | 132 | #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) |
| 133 | + static std::string getBasicBlockLabel(const BasicBlock *BB) { |
| 134 | + if (BB->hasName()) |
| 135 | + return BB->getName().str(); |
| 136 | + |
| 137 | + std::string S; |
| 138 | + raw_string_ostream OS(S); |
| 139 | + BB->printAsOperand(OS, false); |
| 140 | + return OS.str().substr(1); |
| 141 | + } |
| 142 | + |
441 | 143 | void dump() const { |
442 | 144 | dbgs() << "Entry ("; |
443 | 145 | dbgs() << getBasicBlockLabel(EntryNode->Node->getParent()); |
@@ -3159,7 +2861,7 @@ void coro::buildCoroutineFrame( |
3159 | 2861 | rewritePHIs(F); |
3160 | 2862 |
|
3161 | 2863 | // Build suspend crossing info. |
3162 | | - SuspendCrossingInfo Checker(F, Shape); |
| 2864 | + SuspendCrossingInfo Checker(F, Shape.CoroSuspends, Shape.CoroEnds); |
3163 | 2865 |
|
3164 | 2866 | doRematerializations(F, Checker, MaterializableCallback); |
3165 | 2867 |
|
|
0 commit comments