Skip to content

Commit 8f974c9

Browse files
add deleteDeadFunction function.
1 parent 41f3438 commit 8f974c9

File tree

2 files changed

+68
-1
lines changed

2 files changed

+68
-1
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,25 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
235235
op->erase();
236236
}
237237

238+
// Remove the dead functions from moduleOp.
239+
static void deleteDeadFunction(Operation *module) {
240+
bool walkContinue = true;
241+
while (walkContinue) {
242+
walkContinue = false;
243+
module->walk([&](FunctionOpInterface funcOp) {
244+
if (funcOp.isPublic() || funcOp.isExternal())
245+
return;
246+
247+
SymbolTable::UseRange uses = *funcOp.getSymbolUses(module);
248+
auto callSites = funcOp.getFunctionBody().getOps<CallOpInterface>();
249+
if (uses.empty())
250+
funcOp.erase();
251+
if (uses.empty() && !callSites.empty())
252+
walkContinue = true;
253+
});
254+
}
255+
}
256+
238257
/// Convert a list of `Operand`s to a list of `OpOperand`s.
239258
static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
240259
OpOperand *values = operands.getBase();
@@ -881,6 +900,8 @@ void RemoveDeadValues::runOnOperation() {
881900
// end of this pass.
882901
RDVFinalCleanupList finalCleanupList;
883902

903+
// Remove the dead function in advance.
904+
deleteDeadFunction(module);
884905
module->walk([&](Operation *op) {
885906
if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
886907
processFuncOp(funcOp, module, la, deadVals, finalCleanupList);

mlir/test/Transforms/remove-dead-values.mlir

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ module @llvm_unreachable {
455455
func.func private @fn_with_llvm_unreachable(%arg0: tensor<4x4xf32>) -> tensor<4x4xi1> {
456456
llvm.unreachable
457457
}
458-
func.func private @main(%arg0: tensor<4x4xf32>) {
458+
func.func @main(%arg0: tensor<4x4xf32>) {
459459
%0 = call @fn_with_llvm_unreachable(%arg0) : (tensor<4x4xf32>) -> tensor<4x4xi1>
460460
llvm.return
461461
}
@@ -649,3 +649,49 @@ func.func @callee(%arg0: index, %arg1: index, %arg2: index) -> index {
649649
%res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index)
650650
return %res : index
651651
}
652+
653+
// -----
654+
655+
// Test the elimination of dead functions.
656+
657+
// CHECK-NOT: func private @single_private_func
658+
func.func private @single_private_func(%arg0: i64) -> (i64) {
659+
%c0_i64 = arith.constant 0 : i64
660+
%2 = arith.cmpi eq, %arg0, %c0_i64 : i64
661+
cf.cond_br %2, ^bb1, ^bb2
662+
^bb1: // pred: ^bb0
663+
%c1_i64 = arith.constant 1 : i64
664+
return %c1_i64 : i64
665+
^bb2: // pred: ^bb0
666+
%c3_i64 = arith.constant 3 : i64
667+
return %c3_i64 : i64
668+
}
669+
670+
// -----
671+
672+
// Test the elimination of dead functions.
673+
674+
// CHECK-NOT: @single_parameter
675+
func.func private @single_parameter(%arg0: index) {
676+
return
677+
}
678+
679+
// CHECK-NOT: @mutl_parameter
680+
func.func private @mutl_parameter(%arg0: index, %arg1: index, %arg2: index) -> index {
681+
return %arg1 : index
682+
}
683+
684+
// CHECK-NOT: @eliminate_parameter
685+
func.func private @eliminate_parameter(%arg0: index, %arg1: index) {
686+
call @single_parameter(%arg0) : (index) -> ()
687+
return
688+
}
689+
690+
// CHECK-NOT: @callee
691+
func.func private @callee(%arg0: index, %arg1: index, %arg2: index) -> index {
692+
// CHECK-NOT: call @eliminate_parameter
693+
call @eliminate_parameter(%arg0, %arg1) : (index, index) -> ()
694+
// CHECK-NOT: call @mutl_parameter
695+
%res = call @mutl_parameter(%arg0, %arg1, %arg2) : (index, index, index) -> (index)
696+
return %res : index
697+
}

0 commit comments

Comments
 (0)