@@ -257,3 +257,147 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
   int count = count_trt_engines(fallback_g);
   ASSERT_TRUE(count == 2);
 }
+
+TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
+  /* parseIR does not support "= aten::_set_item" so we will build this graph manually
+    const auto graph = R"IR(
+    graph(%x : Tensor,
+      %y : Tensor):
+    %2 : str = prim::Constant[value="INS"]()
+    %3 : str = prim::Constant[value="OUTS"]()
+    %4 : bool = prim::Constant[value=0]()
+    %5 : int = prim::Constant[value=-1]()
+    %6 : Dict(str, Tensor) = prim::DictConstruct()
+     = aten::_set_item(%6, %2, %x)
+    %7 : Tensor = aten::__getitem__(%6, %2)
+    %8 : Tensor = aten::lt(%7, %y)
+    %9 : Tensor?[] = prim::ListConstruct(%8)
+    %10 : int = prim::dtype(%7)
+    %11 : Device = prim::device(%7)
+    %12 : Tensor = aten::tensor(%5, %10, %11, %4)
+    %13 : Tensor = aten::index_put_(%7, %9, %12, %4)
+     = aten::_set_item(%6, %3, %7)
+    %14 : Tensor = aten::__getitem__(%6, %2)
+    %15 : Tensor = aten::__getitem__(%6, %3)
+    return (%14, %15))IR";
+  */
+  auto g = std::make_shared<torch::jit::Graph>();
+  auto x = g->insertInput(0, "x");
+  auto y = g->insertInput(1, "y");
+  torch::jit::IValue ins_key("INS");
+  auto ins_key_val = g->insertConstant(ins_key);
+  torch::jit::IValue outs_key("OUTS");
+  auto outs_key_val = g->insertConstant(outs_key);
+  torch::jit::IValue zero(0);
+  auto false_const_val = g->insertConstant(zero);
+  false_const_val->setType(c10::BoolType::get());
+  torch::jit::IValue neg_one(-1);
+  auto neg_one_const_val = g->insertConstant(neg_one);
+  auto dict_node = g->createDict(
+      ins_key_val->type(),
+      x->type(),
+      torch::jit::ArrayRef<torch::jit::Value*>(),
+      torch::jit::ArrayRef<torch::jit::Value*>());
+  g->insertNode(dict_node);
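+  // dict["INS"] = x; aten::_set_item mutates the dict and is created with zero outputs.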
+  auto set_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::_set_item"),
+      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x},
+      0);
+  g->insertNode(set_node);
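+  // %7 = dict["INS"], then %8 = %7 < %y produces the boolean mask used below.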
+  auto get_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::__getitem__"),
+      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
+      1);
+  g->insertNode(get_node);
+  auto lt_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::lt"),
+      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y},
+      1);
+  g->insertNode(lt_node);
+  auto list_node = g->createList(
+      at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
+  g->insertNode(list_node);
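+  // Build a scalar -1 fill tensor on the same dtype and device as the dict entry.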
+  auto dtype_node = g->create(
+      torch::jit::Symbol::fromQualString("prim::dtype"),
+      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
+      1);
+  dtype_node->output()->setType(neg_one_const_val->type());
+  g->insertNode(dtype_node);
+  auto device_node = g->create(
+      torch::jit::Symbol::fromQualString("prim::device"),
+      torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()},
+      1);
+  device_node->output()->setType(c10::DeviceObjType::get());
+  g->insertNode(device_node);
+  auto tensor_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::tensor"),
+      torch::jit::ArrayRef<torch::jit::Value*>{
+          neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val},
+      1);
+  g->insertNode(tensor_node);
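+  // Masked in-place write via index_put_, then store the result back into the dict under "OUTS".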
+  auto index_put_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::index_put_"),
+      torch::jit::ArrayRef<torch::jit::Value*>{
+          get_node->output(), list_node->output(), tensor_node->output(), false_const_val},
+      1);
+  g->insertNode(index_put_node);
+  auto out_set_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::_set_item"),
+      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()},
+      0);
+  g->insertNode(out_set_node);
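+  // Read both entries back out of the dict and register them as the graph outputs.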
+  auto get_ins_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::__getitem__"),
+      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val},
+      1);
+  g->insertNode(get_ins_node);
+  auto get_outs_node = g->create(
+      torch::jit::Symbol::fromQualString("aten::__getitem__"),
+      torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val},
+      1);
+  g->insertNode(get_outs_node);
+  g->registerOutput(get_ins_node->output());
+  g->registerOutput(get_outs_node->output());
+
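+  // Run partitioning with fallback enabled; both graph inputs are 4x4 float tensors.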
+  torch_tensorrt::core::partitioning::PartitionInfo partition_info;
+  partition_info.enabled = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
+
+  std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
+  std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
+  for (size_t i = 0; i < g->inputs().size(); ++i) {
+    inputs_map.insert({g->inputs()[i], inputs[i]});
+    input_types.insert({g->inputs()[i], {at::kFloat}});
+  }
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+  auto segmented_blocks = torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);
+
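+  // Every TensorRT block must take only Tensor inputs; each Torch block should
+  // either produce the dict or consume it, but never both (XOR check below).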
+  int torch_block_cnt = 0, trt_block_cnt = 0;
+  for (const auto& segmented_block : segmented_blocks) {
+    if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
+      ++trt_block_cnt;
+      ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
+        return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
+      }));
+    } else {
+      ++torch_block_cnt;
+      bool output_dict = false;
+      bool input_dict = false;
+      auto dict_type = dict_node->output()->type();
+      for (auto in : segmented_block.raw_inputs()) {
+        if (in->type()->isSubtypeOf(dict_type)) {
+          input_dict = true;
+        }
+      }
+      for (auto out : segmented_block.raw_outputs()) {
+        if (out->type()->isSubtypeOf(dict_type)) {
+          output_dict = true;
+        }
+      }
+      EXPECT_TRUE(output_dict ^ input_dict);
+    }
+  }
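+  // Expect one TensorRT block for the tensor ops and two Torch blocks for the dict handling.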
+  ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
+}