@@ -256,6 +256,7 @@ GraphAndMapping ConstructFallbackGraph(
       // update the input ranges for each segment
       convert_cfg.inputs = ir::associate_specs_with_inputs(seg_block.g(), inputs, static_params);
 
+      // TODO: map the input IValues to their flattened counterparts here
       auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, static_params);
       auto temp_g = std::make_shared<torch::jit::Graph>();
       auto device_spec = convert_cfg.engine_settings.device;
@@ -306,57 +307,72 @@ void MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
-    ir::TypeMap& first_use_type_map) {
-  // Associate input specs with inputs
-  cfg.convert_info.inputs = std::move(ir::associate_specs_with_inputs(g, cfg.inputs, static_params));
-
-  for (auto& in : g->inputs()) {
-    if (static_params.find(in) == static_params.end()) {
-      ir::Input& spec = cfg.convert_info.inputs.find(in)->second;
-      auto est_type_opt = first_use_type_map.find(in)->second;
-      if (est_type_opt && !spec.dtype_is_user_defined) {
-        // If we can calculate the type from the graph and the type was not defined by the user then use the calculated
-        // type
-        LOG_INFO(
-            "Since input type is not explicitly defined, infering using first tensor calculation\n  Found input "
-            << in->debugName() << " has type " << est_type_opt.value()
-            << ". If this is incorrect explicitly set dtype for input and file a bug");
-        spec.dtype = util::ScalarTypeToTRTDataType(est_type_opt.value());
-      } else if (!est_type_opt && !spec.dtype_is_user_defined) {
-        // If we cannot calculate the type and the user did not define the type, then default to FP32
-        LOG_WARNING(
-            "Cannot infer input type from calcuations in graph for input "
-            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec.dtype = nvinfer1::DataType::kFLOAT;
-      } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
-        if (!est_type_opt) {
-          LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
-          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
+    ir::CollectionTypeMap& first_use_type_map) {
+  cfg.convert_info.collection_input_spec_map =
+      std::move(ir::associate_specs_with_collection_inputs(g, cfg.graph_inputs, static_params));
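+  // Note: collection_input_spec_map keeps a vector of ir::Input specs per graph input: one entry per
+  // element when the input is a Tuple/List, and a single entry for a plain Tensor input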
+
+  auto collection_inputs = ir::get_collection_inputs(g, static_params);
+  LOG_DEBUG(
+      "In MapInputsAndDetermineDTypes, the g->inputs() size is "
+      << g->inputs().size() << ", CollectionInputSpecMap size is " << collection_inputs.size());
+
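+  // For each graph input (which may be a collection of tensors), resolve one dtype per element below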
+  for (auto in : collection_inputs) {
+    std::vector<ir::Input>& spec = cfg.convert_info.collection_input_spec_map.find(in)->second;
+    std::vector<c10::optional<at::ScalarType>> est_type_opt;
+
+    auto est_it = first_use_type_map.find(in);
+    if (est_it != first_use_type_map.end()) {
+      est_type_opt = est_it->second;
+    }
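+    // If no first use of this input was found in the graph, est_type_opt stays empty and the loop
+    // below is skipped, leaving the user-provided or default spec dtypes untouched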
+    // traverse the elements of est_type_opt and spec in lockstep
+    for (size_t i = 0; i < est_type_opt.size(); i++) {
+      if (est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we can calculate the type from the graph and the type was not defined by the user then use the
+        // calculated type
+        LOG_INFO(
+            "Since input type is not explicitly defined, inferring using first tensor calculation\n  Inferred input "
+            << in->debugName() << " has type " << est_type_opt[i].value());
+        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+      } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
+        // If we cannot calculate the type and the user did not define the type, then default to FP32
+        LOG_WARNING(
+            "Cannot infer input type from calculations in graph for input "
+            << in->debugName() << ". Assuming it is Float32. If not, specify input type explicitly");
+        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+      } else if (spec[i].dtype_is_user_defined && cfg.partition_info.enabled) {
+        if (!est_type_opt[i]) {
+          LOG_INFO("Cannot infer input tensor dtype in graph; the compiler is going to use the user setting");
           std::stringstream ss;
           ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-          ss << est_type_opt.value() << std::endl;
-          ss << "The compiler is going to use the user setting " << cfg.convert_info.inputs.find(in)->second.dtype;
-          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-          ss << "compatibility with PyTorch's data type convention is required.\n";
-          ss << "If you do indeed see errors at runtime either:\n";
-          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ". The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {
+              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+        } else {
+          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+              est_type_opt[i].value()) {
+            std::stringstream ss;
+            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+            ss << est_type_opt[i].value() << std::endl;
+            ss << "The compiler is going to use the user setting "
+               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+            ss << "compatibility with PyTorch's data type convention is required.\n";
+            ss << "If you do indeed see errors at runtime either:\n";
+            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+            ss << "- Disable partial compilation by setting require_full_compilation to True";
+            auto warn_str = ss.str();
+            LOG_WARNING(warn_str);
+            // Overwrite type map with user settings
+            first_use_type_map[in][i] = {
+                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
+          }
         }
-        // Overwrite type map with user settings
-        // We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
-        first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
+      } else {
+        // The user defined the type, so no changes are necessary
       }
-    } else {
-      // The user defined the type so no changes are necessary
     }
   }
-  }
}
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -370,7 +386,8 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
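+  // Collection-aware variant: estimates one dtype per element of each (possibly Tuple/List) graph input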
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   // // GPU default WS size : 1 GB
   // // Set WS = 256 Mb for Jetson nano/TX1 like platforms whose compute capability is 5.X.
@@ -410,23 +427,25 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto params = graph_and_parameters.second;
   auto static_params = ir::get_static_params(g->inputs(), params);
   // Infer the type of an input from the weights of the calculation
-  auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
+  auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
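+  // Collection outputs (e.g. Tuples/Lists) force the partitioning path below even when every op is
+  // convertible, since the returned collection has to be reconstructed in TorchScript around the TRT engine(s)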
+  auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partition_info.enabled &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
     LOG_INFO("Skipping partitioning since model is fully supported");
   }
 
   if (cfg.partition_info.enabled &&
-      !(cfg.lower_info.forced_fallback_modules.size() == 0 &&
-        cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
-    auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
+      (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
+         cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
+       outputIsCollection)) {
+
     std::unordered_map<torch::jit::Node*, int> fallback_nodes;
-    auto graph_and_mapping =
-        ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
+    auto collection_input_ivalues_map =
+        partitioning::generateRandomInputs(cfg.convert_info.collection_input_spec_map, first_use_types);
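+    // The random example IValues above drive partitioning's shape/type analysis for each segment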
+    auto graph_and_mapping =
+        ConstructFallbackGraph(new_mod, g->block(), collection_input_ivalues_map, cfg, static_params, fallback_nodes);
     new_g = graph_and_mapping.first;
     LOG_INFO("Segmented Graph: " << *new_g);
 
@@ -440,6 +459,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
     TORCHTRT_CHECK(
         conversion::VerifyConverterSupportForBlock(g->block()),
         "Not all operations in graph are supported by the compiler");
+    // TODO find the right
     auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
     AddEngineToGraph(new_mod, new_g, engine, cuda_device);
   }