@@ -90,6 +90,16 @@ std::vector<torch::jit::Node*> getDependencyNodes(
9090 return stk;
9191}
9292
93+ void find_nontensor_output_nodes (
94+ torch::jit::Block* block,
95+ std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
96+ for (auto i : block->outputs ()) {
97+ if (!isTensor (i)) {
98+ global_fallback_nodes.insert ({i->node (), FallbackNodeType::kNON_TENSOR });
99+ }
100+ }
101+ }
102+
93103void find_all_fallback_nodes (
94104 std::unordered_map<torch::jit::Node*, int >& initial_fallback_nodes,
95105 std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
@@ -430,6 +440,9 @@ PartitionedGraph Partition(
430440 const PartitionInfo& partition_info,
431441 std::unordered_map<torch::jit::Node*, int >& global_fallback_nodes) {
432442 LOG_DEBUG (partition_info);
443+ // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
444+ find_nontensor_output_nodes (block, global_fallback_nodes);
445+
433446 // segment lowering global graph into blocks
434447 LOG_DEBUG (" Parititioning source module into PyTorch and TensorRT sub blocks" );
435448 PartitionedGraph segmented_blocks = segment_graph (block, partition_info, global_fallback_nodes);
0 commit comments