Skip to content

Commit 3e22f22

Browse files
joker-ephtru
authored andcommitted
[MLIR] Fix Liveness analysis handling of unreachable code (#153973)
This patch is forcing all values to be initialized by the LivenessAnalysis, even in dead blocks. The dataflow framework will skip visiting values when its already knows that a block is dynamically unreachable, so this requires specific handling. Downstream code could consider that the absence of liveness is the same a "dead". However as the code is mutated, new value can be introduced, and a transformation like "RemoveDeadValue" must conservatively consider that the absence of liveness information meant that we weren't sure if a value was dead (it could be a newly introduced value. Fixes #153906
1 parent 1339866 commit 3e22f22

File tree

6 files changed

+272
-15
lines changed

6 files changed

+272
-15
lines changed

mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,45 @@ RunLivenessAnalysis::RunLivenessAnalysis(Operation *op) {
295295

296296
loadBaselineAnalyses(solver);
297297
solver.load<LivenessAnalysis>(symbolTable);
298-
LDBG("Initializing and running solver");
298+
LLVM_DEBUG({ llvm::dbgs() << "Initializing and running solver\n"; });
299299
(void)solver.initializeAndRun(op);
300-
LDBG("Dumping liveness state for op");
300+
LLVM_DEBUG({
301+
llvm::dbgs() << "RunLivenessAnalysis initialized for op: " << op->getName()
302+
<< " check on unreachable code now:"
303+
<< "\n";
304+
});
305+
// The framework doesn't visit operations in dead blocks, so we need to
306+
// explicitly mark them as dead.
307+
op->walk([&](Operation *op) {
308+
if (op->getNumResults() == 0)
309+
return;
310+
for (auto result : llvm::enumerate(op->getResults())) {
311+
if (getLiveness(result.value()))
312+
continue;
313+
LLVM_DEBUG({
314+
llvm::dbgs() << "Result: " << result.index() << " of "
315+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
316+
<< " has no liveness info (unreachable), mark dead"
317+
<< "\n";
318+
});
319+
solver.getOrCreateState<Liveness>(result.value());
320+
}
321+
for (auto &region : op->getRegions()) {
322+
for (auto &block : region) {
323+
for (auto blockArg : llvm::enumerate(block.getArguments())) {
324+
if (getLiveness(blockArg.value()))
325+
continue;
326+
LLVM_DEBUG({
327+
llvm::dbgs() << "Block argument: " << blockArg.index() << " of "
328+
<< OpWithFlags(op, OpPrintingFlags().skipRegions())
329+
<< " has no liveness info, mark dead"
330+
<< "\n";
331+
});
332+
solver.getOrCreateState<Liveness>(blockArg.value());
333+
}
334+
}
335+
}
336+
});
301337
}
302338

303339
const Liveness *RunLivenessAnalysis::getLiveness(Value val) {

mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp

Lines changed: 91 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2020
#include "mlir/Support/LLVM.h"
2121
#include "llvm/ADT/STLExtras.h"
22+
#include "llvm/Support/Debug.h"
2223
#include <cassert>
2324
#include <optional>
2425

2526
using namespace mlir;
2627
using namespace mlir::dataflow;
2728

29+
#define DEBUG_TYPE "dataflow"
30+
2831
//===----------------------------------------------------------------------===//
2932
// AbstractSparseLattice
3033
//===----------------------------------------------------------------------===//
@@ -64,22 +67,56 @@ AbstractSparseForwardDataFlowAnalysis::initialize(Operation *top) {
6467

6568
LogicalResult
6669
AbstractSparseForwardDataFlowAnalysis::initializeRecursively(Operation *op) {
70+
LLVM_DEBUG({
71+
llvm::dbgs() << "Initializing recursively for operation: " << op->getName()
72+
<< "\n";
73+
});
74+
6775
// Initialize the analysis by visiting every owner of an SSA value (all
6876
// operations and blocks).
69-
if (failed(visitOperation(op)))
77+
if (failed(visitOperation(op))) {
78+
LLVM_DEBUG({
79+
llvm::dbgs() << "Failed to visit operation: " << op->getName() << "\n";
80+
});
7081
return failure();
82+
}
7183

7284
for (Region &region : op->getRegions()) {
85+
LLVM_DEBUG({
86+
llvm::dbgs() << "Processing region with " << region.getBlocks().size()
87+
<< " blocks"
88+
<< "\n";
89+
});
7390
for (Block &block : region) {
91+
LLVM_DEBUG({
92+
llvm::dbgs() << "Processing block with " << block.getNumArguments()
93+
<< " arguments"
94+
<< "\n";
95+
});
7496
getOrCreate<Executable>(getProgramPointBefore(&block))
7597
->blockContentSubscribe(this);
7698
visitBlock(&block);
77-
for (Operation &op : block)
78-
if (failed(initializeRecursively(&op)))
99+
for (Operation &op : block) {
100+
LLVM_DEBUG({
101+
llvm::dbgs() << "Recursively initializing nested operation: "
102+
<< op.getName() << "\n";
103+
});
104+
if (failed(initializeRecursively(&op))) {
105+
LLVM_DEBUG({
106+
llvm::dbgs() << "Failed to initialize nested operation: "
107+
<< op.getName() << "\n";
108+
});
79109
return failure();
110+
}
111+
}
80112
}
81113
}
82114

115+
LLVM_DEBUG({
116+
llvm::dbgs()
117+
<< "Successfully completed recursive initialization for operation: "
118+
<< op->getName() << "\n";
119+
});
83120
return success();
84121
}
85122

@@ -409,11 +446,29 @@ static MutableArrayRef<OpOperand> operandsToOpOperands(OperandRange &operands) {
409446

410447
LogicalResult
411448
AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
449+
LLVM_DEBUG({
450+
llvm::dbgs() << "Visiting operation: " << op->getName() << " with "
451+
<< op->getNumOperands() << " operands and "
452+
<< op->getNumResults() << " results"
453+
<< "\n";
454+
});
455+
412456
// If we're in a dead block, bail out.
413457
if (op->getBlock() != nullptr &&
414-
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))->isLive())
458+
!getOrCreate<Executable>(getProgramPointBefore(op->getBlock()))
459+
->isLive()) {
460+
LLVM_DEBUG({
461+
llvm::dbgs() << "Operation is in dead block, bailing out"
462+
<< "\n";
463+
});
415464
return success();
465+
}
416466

467+
LLVM_DEBUG({
468+
llvm::dbgs() << "Creating lattice elements for " << op->getNumOperands()
469+
<< " operands and " << op->getNumResults() << " results"
470+
<< "\n";
471+
});
417472
SmallVector<AbstractSparseLattice *> operandLattices =
418473
getLatticeElements(op->getOperands());
419474
SmallVector<const AbstractSparseLattice *> resultLattices =
@@ -422,11 +477,21 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
422477
// Block arguments of region branch operations flow back into the operands
423478
// of the parent op
424479
if (auto branch = dyn_cast<RegionBranchOpInterface>(op)) {
480+
LLVM_DEBUG({
481+
llvm::dbgs() << "Processing RegionBranchOpInterface operation"
482+
<< "\n";
483+
});
425484
visitRegionSuccessors(branch, operandLattices);
426485
return success();
427486
}
428487

429488
if (auto branch = dyn_cast<BranchOpInterface>(op)) {
489+
LLVM_DEBUG({
490+
llvm::dbgs() << "Processing BranchOpInterface operation with "
491+
<< op->getNumSuccessors() << " successors"
492+
<< "\n";
493+
});
494+
430495
// Block arguments of successor blocks flow back into our operands.
431496

432497
// We remember all operands not forwarded to any block in a BitVector.
@@ -463,6 +528,10 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
463528
// For function calls, connect the arguments of the entry blocks to the
464529
// operands of the call op that are forwarded to these arguments.
465530
if (auto call = dyn_cast<CallOpInterface>(op)) {
531+
LLVM_DEBUG({
532+
llvm::dbgs() << "Processing CallOpInterface operation"
533+
<< "\n";
534+
});
466535
Operation *callableOp = call.resolveCallableInTable(&symbolTable);
467536
if (auto callable = dyn_cast_or_null<CallableOpInterface>(callableOp)) {
468537
// Not all operands of a call op forward to arguments. Such operands are
@@ -513,19 +582,36 @@ AbstractSparseBackwardDataFlowAnalysis::visitOperation(Operation *op) {
513582
// of this op itself and the operands of the terminators of the regions of
514583
// this op.
515584
if (auto terminator = dyn_cast<RegionBranchTerminatorOpInterface>(op)) {
585+
LLVM_DEBUG({
586+
llvm::dbgs() << "Processing RegionBranchTerminatorOpInterface operation"
587+
<< "\n";
588+
});
516589
if (auto branch = dyn_cast<RegionBranchOpInterface>(op->getParentOp())) {
517590
visitRegionSuccessorsFromTerminator(terminator, branch);
518591
return success();
519592
}
520593
}
521594

522595
if (op->hasTrait<OpTrait::ReturnLike>()) {
596+
LLVM_DEBUG({
597+
llvm::dbgs() << "Processing ReturnLike operation"
598+
<< "\n";
599+
});
523600
// Going backwards, the operands of the return are derived from the
524601
// results of all CallOps calling this CallableOp.
525-
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp()))
602+
if (auto callable = dyn_cast<CallableOpInterface>(op->getParentOp())) {
603+
LLVM_DEBUG({
604+
llvm::dbgs() << "Callable parent found, visiting callable operation"
605+
<< "\n";
606+
});
526607
return visitCallableOperation(op, callable, operandLattices);
608+
}
527609
}
528610

611+
LLVM_DEBUG({
612+
llvm::dbgs() << "Using default visitOperationImpl for operation: "
613+
<< op->getName() << "\n";
614+
});
529615
return visitOperationImpl(op, operandLattices, resultLattices);
530616
}
531617

0 commit comments

Comments
 (0)