Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 3 additions & 22 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,34 +181,15 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
seg_block.registerOutput(mini_graph_input);
}
}
// if no output, then register the last node's output as current graph's output
// if no output, then register this graph's input as its output
// We can ensure that TRT segmented block has Tensor inputs now
if (seg_block.raw_outputs().empty()) {
// for Torch segments, register input as output
if (seg_block.target() == SegmentedBlock::kTorch) {
seg_block.registerOutput(seg_block.raw_inputs()[0]);
} else {
// for TensorRT segments, register last nonInput Tensor outputs
for (int i = seg_block.raw_nodes().size() - 1; i >= 0; --i) {
for (auto node_output : seg_block.raw_nodes()[i]->outputs()) {
if (isTensor(node_output))
seg_block.registerOutput(node_output);
}
if (!seg_block.raw_outputs().empty())
break;
}
}
seg_block.registerOutput(seg_block.raw_inputs()[0]);
}
}
std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) {
torch::jit::EliminateDeadCode(seg_block.g());
});
// erase segments which still have no output
segmented_blocks.erase(
std::remove_if(
segmented_blocks.begin(),
segmented_blocks.end(),
[](SegmentedBlock& seg_block) { return seg_block.raw_outputs().empty(); }),
segmented_blocks.end());

return;
}
Expand Down