diff --git a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp index d705d8d4c7819..20be50c8e8a5b 100644 --- a/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/LivenessAnalysis.cpp @@ -137,7 +137,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // Populating such blocks in `blocks`. bool mayLive = false; SmallVector blocks; - if (isa(op)) { + SmallVector argumentNotOperand; + if (auto regionBranchOp = dyn_cast(op)) { if (op->getNumResults() != 0) { // This mark value of type 1.c liveness as may live, because the region // branch operation has a return value, and the non-forwarded operand can @@ -165,6 +166,25 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { blocks.push_back(&block); } } + + // In the block of the successor block argument of RegionBranchOpInterface, + // there may be arguments of RegionBranchOpInterface, such as the IV of + // scf.forOp. Explicitly set this argument to live. + for (Region ®ion : op->getRegions()) { + SmallVector successors; + regionBranchOp.getSuccessorRegions(region, successors); + for (RegionSuccessor successor : successors) { + if (successor.isParent()) + continue; + auto arguments = successor.getSuccessor()->getArguments(); + ValueRange regionInputs = successor.getSuccessorInputs(); + for (auto argument : arguments) { + if (llvm::find(regionInputs, argument) == regionInputs.end()) { + argumentNotOperand.push_back(argument); + } + } + } + } } else if (isa(op)) { // We cannot track all successor blocks of the branch operation(More // specifically, it's the successor's successor). Additionally, different @@ -224,6 +244,15 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { Liveness *operandLiveness = getLatticeElement(operand.get()); LDBG() << "Marking branch operand live: " << operand.get(); propagateIfChanged(operandLiveness, operandLiveness->markLive()); + for (BlockArgument argument : argumentNotOperand) { + Liveness *argumentLiveness = getLatticeElement(argument); + LDBG() << "Marking RegionBranchOp's argument live: " << argument; + // TODO: this is overly conservative: we should be able to eliminate + // unused values in a RegionBranchOpInterface operation but that may + // requires removing operation results which is beyond current + // capabilities of this pass right now. + propagateIfChanged(argumentLiveness, argumentLiveness->markLive()); + } } // Now that we have checked for memory-effecting ops in the blocks of concern, @@ -231,6 +260,8 @@ void LivenessAnalysis::visitBranchOperand(OpOperand &operand) { // mark it "live" due to type (1.a/3) liveness. SmallVector operandLiveness; operandLiveness.push_back(getLatticeElement(operand.get())); + for (BlockArgument argument : argumentNotOperand) + operandLiveness.push_back(getLatticeElement(argument)); SmallVector resultsLiveness; for (const Value result : op->getResults()) resultsLiveness.push_back(getLatticeElement(result)); diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 56449469dc29f..e7304505c809e 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -649,3 +649,28 @@ func.func @callee(%arg0: index, %arg1: index, %arg2: index) -> index { %res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index) return %res : index } + +// ----- + +// This test verifies that the induction variables in loops are not deleted, the loop has results. + +// CHECK-LABEL: func @dead_value_loop_ivs +func.func @dead_value_loop_ivs_has_result(%lb: index, %ub: index, %step: index, %b: i1) -> i1 { + %loop_ret = scf.for %iv = %lb to %ub step %step iter_args(%iter = %b) -> (i1) { + cf.assert %b, "loop not dead" + scf.yield %b : i1 + } + return %loop_ret : i1 +} + +// ----- + +// This test verifies that the induction variables in loops are not deleted, the loop has no results. + +// CHECK-LABEL: func @dead_value_loop_ivs_no_result +func.func @dead_value_loop_ivs_no_result(%lb: index, %ub: index, %step: index, %input: memref, %value: f32, %pos: index) { + scf.for %iv = %lb to %ub step %step { + memref.store %value, %input[%pos] : memref + } + return +}