import gc
- import io
import logging
import os
import warnings
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.observer import Observer
- from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+ from torch_tensorrt.dynamo.utils import (
+     DYNAMIC_DIM,
+     deallocate_module,
+     get_cpu_memory_usage,
+     to_torch_device,
+ )
from torch_tensorrt.logging import TRT_LOGGER

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +69,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
-     serialized_engine: bytes
+     engine: trt.ICudaEngine
    input_names: Sequence[str]
    output_names: Sequence[str]
    weight_name_map: Optional[dict[Any, Any]]
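With this change the interpreter result carries a live trt.ICudaEngine rather than a pre-serialized blob, so callers that still need plan-file bytes (for example to persist the engine) serialize on demand. A minimal sketch of that consumer-side pattern, assuming a helper name of our own choosing (the actual call sites are not part of this diff):

# Hypothetical helper sketch (not part of this diff): serialize the live engine
# only when plan-file bytes are actually needed, e.g. to persist it to disk.
import tensorrt as trt

def save_engine(engine: trt.ICudaEngine, path: str) -> None:
    serialized = engine.serialize()  # returns trt.IHostMemory
    with open(path, "wb") as f:
        f.write(serialized)  # IHostMemory supports the buffer protocol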
@@ -512,8 +516,7 @@ def _save_weight_mapping(self) -> None:
        _LOGGER.info("Building weight name mapping...")
        # Stage 1: Name mapping
        torch_device = to_torch_device(self.compilation_settings.device)
-         self.module.to(torch_device)
-         sd = self.module.state_dict()
+         sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
        weight_name_map: dict[str, Any] = {}
        weight_refit_map = self.ctx.weight_refit_map
        constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
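The replaced lines above avoid self.module.to(torch_device), which would move and mutate the whole graph module just to read its weights; instead each state dict tensor is copied to the target device individually, leaving the module's own storage where it was. A small toy sketch of the difference (illustrative only, assumes a CUDA device is available):

# Hypothetical toy example contrasting the two approaches; not from this repo.
import torch

module = torch.nn.Linear(4, 4)  # parameters start on CPU

# Old approach: module.to("cuda") moves (mutates) the entire module before
# taking its state_dict.

# New approach: per-tensor copies on the target device; `module` is untouched.
sd = {k: v.to("cuda") for k, v in module.state_dict().items()}
assert next(module.parameters()).device.type == "cpu"
assert all(v.device.type == "cuda" for v in sd.values())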
@@ -591,34 +594,6 @@ def _save_weight_mapping(self) -> None:
        gc.collect()
        torch.cuda.empty_cache()

-     @needs_refit  # type: ignore[misc]
-     def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
-         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
-         # if not self.compilation_settings.strip_engine_weights:
-         #     # set EXCLUDE_WEIGHTS flag to strip weights
-         #     runtime = trt.Runtime(TRT_LOGGER)
-         #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-         #     serialization_config = engine.create_serialization_config()
-         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
-         #     serialized_engine = engine.serialize_with_config(
-         #         serialization_config
-         #     )
-
-         # Cache weighted engine for now
-         self.engine_cache.insert(  # type: ignore[union-attr]
-             hash_val,
-             (
-                 serialized_engine,
-                 self._input_names,
-                 self._output_names,
-                 self.input_specs,
-                 self.compilation_settings,
-                 self.weight_name_map,
-                 self.ctx.requires_output_allocator,
-             ),
-         )
-
    @needs_refit  # type: ignore[misc]
    def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
        # query the cached TRT engine
@@ -671,7 +646,6 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                    settings=self.compilation_settings,
                    weight_name_map=self.weight_name_map,
                )
-                 serialized_engine = engine.serialize()

            # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
            # # EXCLUDE_WEIGHTS flag must be cleared
@@ -684,12 +658,8 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
            # )
            # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

-             with io.BytesIO() as engine_bytes:
-                 engine_bytes.write(serialized_engine)
-                 engine_str = engine_bytes.getvalue()
-
            return TRTInterpreterResult(
-                 engine_str,
+                 engine,
                self._input_names,
                self._output_names,
                self.weight_name_map,
@@ -733,6 +703,9 @@ def run(
            return interpreter_result  # type: ignore[no-any-return]

        self._construct_trt_network_def()
+         _LOGGER.debug(
+             f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+         )

        if not self.compilation_settings.immutable_weights:
            self._save_weight_mapping()
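get_cpu_memory_usage is imported from torch_tensorrt.dynamo.utils but its implementation is not shown in this diff; the log messages suggest it reports host memory in MB. A hedged sketch of what such a helper could look like, assuming a psutil-based RSS measurement (the real helper may be implemented differently):

# Hypothetical sketch of a CPU-memory helper; the actual implementation lives in
# torch_tensorrt.dynamo.utils and is not part of this diff.
import psutil

def get_cpu_memory_usage() -> float:
    # Resident set size of the current process, reported in MB.
    return psutil.Process().memory_info().rss / (1024 * 1024)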
@@ -750,36 +723,39 @@ def run(
        self._create_timing_cache(
            builder_config, self.compilation_settings.timing_cache_path
        )
-         serialized_engine = self.builder.build_serialized_network(
-             self.ctx.net, builder_config
+
+         if (
+             ENABLED_FEATURES.tensorrt_rtx
+             or self.compilation_settings.version_compatible
+         ):
+             # TODO: When TRT-RTX matures, change it to build_engine_with_config
+             serialized_engine = self.builder.build_serialized_network(
+                 self.ctx.net, builder_config
+             )
+             runtime = trt.Runtime(TRT_LOGGER)
+             cuda_engine = runtime.deserialize_cuda_engine(serialized_engine)
+         else:
+
+             cuda_engine = self.builder.build_engine_with_config(
+                 self.ctx.net, builder_config
+             )
+         assert cuda_engine
+
+         _LOGGER.debug(
+             f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
        )
-         assert serialized_engine

        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
-         _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
        self.ctx.clear_cpu_weights_reference_holder()

        self._save_timing_cache(
            builder_config, self.compilation_settings.timing_cache_path
        )

-         # Engine caching only for refittable engines
-         if (
-             not self.compilation_settings.immutable_weights
-             and self.compilation_settings.cache_built_engines
-             and self.engine_cache is not None
-         ):
-             self._insert_engine_to_cache(hash_val, serialized_engine)
-
-         with io.BytesIO() as engine_bytes:
-             engine_bytes.write(serialized_engine)
-             engine_str = engine_bytes.getvalue()
-
        return TRTInterpreterResult(
-             engine_str,
+             cuda_engine,
            self._input_names,
            self._output_names,
            self.weight_name_map,