@@ -70,7 +70,7 @@ void setExplicitFallbackNodes(PartitioningCtx* ctx, torch::jit::Block* block) {
7070 const auto to_compile_sym = c10::Symbol::attr (" to_compile" );
7171
7272 for (const auto n : nodes) {
73- if (n-> kind () == torch::jit::prim::Constant ) {
73+ if (isConstantOrUninitialized (n) ) {
7474 continue ;
7575 }
7676
@@ -107,7 +107,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
107107 q.pop ();
108108 // for every node that produces this fallback node's NonTensor input, they should fallback too
109109 for (auto input : cur_node->inputs ()) {
110- if (!isTensor (input) && input->node ()-> kind () != torch::jit::prim::Constant &&
110+ if (!isTensor (input) && ! isConstantOrUninitialized ( input->node ()) &&
111111 ctx->shouldNodeRunInTensorRT (input->node ())) {
112112 ctx->setNodeExecutorDecision (input->node (), NodeExecutorDecision::kNON_TENSOR );
113113 q.push (input->node ());
@@ -118,7 +118,7 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
118118 if (!isTensor (output)) {
119119 for (auto use : output->uses ()) {
120120 auto node = use.user ;
121- if (node-> kind () != torch::jit::prim::Constant && ctx->shouldNodeRunInTensorRT (node)) {
121+ if (isConstantOrUninitialized (node) && ctx->shouldNodeRunInTensorRT (node)) {
122122 ctx->setNodeExecutorDecision (node, NodeExecutorDecision::kNON_TENSOR );
123123 q.push (node);
124124 }
@@ -128,11 +128,13 @@ void setNonTensorConnectedNodes(PartitioningCtx* ctx, std::vector<torch::jit::No
128128 }
129129}
130130
131- std::set<torch::jit::Node*> getDependentNodes (torch::jit::Node* n) {
132- std::set<torch::jit::Node*> dependent_nodes ;
131+ std::set<torch::jit::Node*> getUserNodes (torch::jit::Node* n) {
132+ std::set<torch::jit::Node*> user_nodes ;
133133 for (auto val : n->outputs ()) {
134134 for (auto use : val->uses ()) {
135- dependent_nodes.insert (use.user );
135+ if (use.user ->owningBlock ()->owningNode ())
136+ user_nodes.insert (use.user ->owningBlock ()->owningNode ());
137+ user_nodes.insert (use.user );
136138 }
137139 }
138140 if (const auto * schema = n->maybeSchema ()) {
@@ -142,13 +144,13 @@ std::set<torch::jit::Node*> getDependentNodes(torch::jit::Node* n) {
142144 for (auto use : n->inputs ()[i]->uses ()) {
143145 torch::jit::Node* use_node = use.user ;
144146 if (use_node->isAfter (n)) {
145- dependent_nodes .insert (use_node);
147+ user_nodes .insert (use_node);
146148 }
147149 }
148150 }
149151 }
150152 }
151- return dependent_nodes ;
153+ return user_nodes ;
152154}
153155
154156// Sub-function that traverses the entire block and check if TensorRT node sequence satisfy min_block_size
@@ -158,14 +160,14 @@ std::vector<torch::jit::Node*> traverseNodesForMinBlockSize(PartitioningCtx* ctx
158160 std::unordered_set<torch::jit::Node*> cur_trt_nodes_uses;
159161 std::vector<torch::jit::Node*> min_block_fallback_nodes;
160162 for (const auto n : nodes) {
161- if (n-> kind () == torch::jit::prim::Constant ) {
163+ if (isConstantOrUninitialized (n) ) {
162164 continue ;
163165 }
164166
165167 // check if current node fallback or not
166168 if (!ctx->shouldNodeRunInTorch (n)) {
167169 cur_trt_nodes.push_back (n);
168- auto dependent_nodes = getDependentNodes (n);
170+ auto dependent_nodes = getUserNodes (n);
169171 cur_trt_nodes_uses.insert (dependent_nodes.begin (), dependent_nodes.end ());
170172 } else {
171173 if (cur_trt_nodes_uses.count (n)) {
@@ -250,7 +252,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
250252 auto cur_val = q.front ();
251253 q.pop ();
252254 auto node = cur_val->node ();
253- if (node-> kind () != torch::jit::prim::Constant && !visited.count (node)) {
255+ if (! isConstantOrUninitialized (node) && !visited.count (node)) {
254256 visited.insert (node);
255257 auto modifying_nodes = findModifyingNodes (cur_val, seg_block_nodes);
256258 stk.insert (stk.end (), modifying_nodes.rbegin (), modifying_nodes.rend ());
@@ -454,10 +456,10 @@ void segmentGraph(PartitioningCtx* ctx, torch::jit::Block* block) {
454456 std::unordered_set<torch::jit::Node*> cur_pyt_nodes_uses;
455457 for (const auto n : nodes) {
456458 // Skip constant nodes as they are resources for both kinds of modules
457- if (n-> kind () == torch::jit::prim::Constant ) {
459+ if (isConstantOrUninitialized (n) ) {
458460 continue ;
459461 }
460- auto dependent_nodes = getDependentNodes (n);
462+ auto dependent_nodes = getUserNodes (n);
461463 // the outputs of trt subgraph shouldn't be collections
462464 if (ctx->shouldNodeRunInTensorRT (n)) {
463465 in_prog_trt_blk_nodes.push_back (n);
0 commit comments