Cpu memory optimization #3602
Base: moe-support
Changes from all commits: 2540824, c7f8b12, 711446c, 35d5861, 503f320
```diff
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):


 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine | bytes
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
```
```diff
@@ -512,8 +512,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
```
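This hunk is part of the memory optimization itself: the old code migrated the entire module to the GPU as a side effect just to read its state dict, while the new code copies each tensor to the device individually and leaves the module where it was. A toy illustration of the difference (the `nn.Linear` stand-in and variable names here are ours, not from the PR):

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4)  # stand-in for the FX module being compiled
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Old approach: module.to(torch_device) mutates the module in place, so all
# of its parameters migrate to the device and stay there.
# New approach: build a device-side copy of each state-dict tensor; the
# module itself is untouched.
sd = {k: v.to(torch_device) for k, v in module.state_dict().items()}

print(next(module.parameters()).device)  # still the original device
print(sd["weight"].device)               # torch_device
```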
```diff
@@ -592,13 +591,11 @@ def _save_weight_mapping(self) -> None:
         torch.cuda.empty_cache()

     @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
+    def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
+        serialized_engine = engine.serialize()
         # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
         # if not self.compilation_settings.strip_engine_weights:
         #     # set EXCLUDE_WEIGHTS flag to strip weights
         #     runtime = trt.Runtime(TRT_LOGGER)
         #     engine = runtime.deserialize_cuda_engine(serialized_engine)

         #     serialization_config = engine.create_serialization_config()
         #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
         #     serialized_engine = engine.serialize_with_config(
```

Review comment: If we are doing this, don't we end up paying the serialization cost again?
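The concern is that `run()` may serialize the same live engine here for the cache insert and then a second time on the C++-runtime return path. One way to pay the cost only once would be to memoize the serialized form; a sketch under that assumption (the `CachedEngine` wrapper and its attribute names are hypothetical):

```python
import tensorrt as trt

class CachedEngine:
    """Wraps a live engine and serializes it at most once."""

    def __init__(self, engine: trt.ICudaEngine):
        self.engine = engine
        self._serialized: bytes | None = None

    @property
    def serialized(self) -> bytes:
        if self._serialized is None:
            # IHostMemory supports the buffer protocol, so bytes() copies it out
            self._serialized = bytes(self.engine.serialize())
        return self._serialized
```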
```diff
@@ -750,16 +747,15 @@ def run(
             self._create_timing_cache(
                 builder_config, self.compilation_settings.timing_cache_path
             )
-            serialized_engine = self.builder.build_serialized_network(
+            cuda_engine = self.builder.build_engine_with_config(
                 self.ctx.net, builder_config
             )
-            assert serialized_engine
+            assert cuda_engine

             _LOGGER.info(
                 f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
             )
-            _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

             self.ctx.clear_cpu_weights_reference_holder()

             self._save_timing_cache(
```
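The builder entry point changes here: `build_serialized_network` materializes the whole engine plan as a host-memory blob up front, while `build_engine_with_config` returns a live `ICudaEngine` and defers serialization until it is actually needed. A minimal sketch of the two paths, assuming `builder`, `network`, and `config` already exist:

```python
import tensorrt as trt

def build(builder: trt.Builder, network: trt.INetworkDefinition,
          config: trt.IBuilderConfig, live: bool = True):
    if live:
        # Live engine: no serialized host-memory copy is created up front.
        return builder.build_engine_with_config(network, config)
    # Serialized plan (trt.IHostMemory): the full blob sits in host memory.
    return builder.build_serialized_network(network, config)
```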
```diff
@@ -772,19 +768,31 @@ def run(
                 and self.compilation_settings.cache_built_engines
                 and self.engine_cache is not None
             ):
-                self._insert_engine_to_cache(hash_val, serialized_engine)
-
-            with io.BytesIO() as engine_bytes:
-                engine_bytes.write(serialized_engine)
-                engine_str = engine_bytes.getvalue()
-
-            return TRTInterpreterResult(
-                engine_str,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            )
+                self._insert_engine_to_cache(hash_val, cuda_engine)
+
+            if self.compilation_settings.use_python_runtime:
+                return TRTInterpreterResult(
+                    cuda_engine,
+                    self._input_names,
+                    self._output_names,
+                    self.weight_name_map,
+                    self.ctx.requires_output_allocator,
+                )
+            else:
+                serialized_engine = cuda_engine.serialize()
+
+                _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
+
+                with io.BytesIO() as engine_bytes:
+                    engine_bytes.write(serialized_engine)
+                    engine_str = engine_bytes.getvalue()
+
+                return TRTInterpreterResult(
+                    engine_str,
+                    self._input_names,
+                    self._output_names,
+                    self.weight_name_map,
+                    self.ctx.requires_output_allocator,
+                )

     def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
         self._cur_node_name = get_node_name(n)
```

Review comment (on the `else` branch): Just have the TRTInterpreter return a live engine in all cases.
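One possible shape of that suggestion: drop the branch, always return the live engine, and let the runtime wrapper decide when bytes are needed. A fragment sketching the simplified tail of `run()` (this is our reading of the comment, not code from the PR):

```python
# Inside run(), after the engine is built and optionally cached: the
# interpreter no longer branches on the runtime; serialization becomes the
# responsibility of whichever consumer actually needs bytes.
return TRTInterpreterResult(
    cuda_engine,
    self._input_names,
    self._output_names,
    self.weight_name_map,
    self.ctx.requires_output_allocator,
)
```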
```diff
@@ -89,12 +89,18 @@ def convert_module(
         module, inputs, settings, engine_cache=engine_cache
     )

-    rt_cls = PythonTorchTensorRTModule
-
     if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
         from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

-        rt_cls = TorchTensorRTModule
+        return TorchTensorRTModule(
+            serialized_engine=interpreter_result.engine,
+            input_binding_names=list(interpreter_result.input_names),
+            output_binding_names=list(interpreter_result.output_names),
+            name=name,
+            settings=settings,
+            weight_name_map=interpreter_result.weight_name_map,
+            requires_output_allocator=interpreter_result.requires_output_allocator,
+        )

     elif (
         not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
```
```diff
@@ -103,8 +109,8 @@ def convert_module(
             "Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
         )

-    return rt_cls(
-        serialized_engine=interpreter_result.serialized_engine,
+    return PythonTorchTensorRTModule(
+        cuda_engine=interpreter_result.engine,
         input_binding_names=list(interpreter_result.input_names),
         output_binding_names=list(interpreter_result.output_names),
         name=name,
```

Review comment: This should still be `rt_cls`, with some preprocessing that modifies the arguments instead.
```diff
@@ -15,7 +15,6 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
-from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
     _select_rt_device,
```
```diff
@@ -123,7 +122,7 @@ class PythonTorchTensorRTModule(Module):  # type: ignore[misc]

     def __init__(
         self,
-        serialized_engine: Optional[bytes] = None,
+        cuda_engine: trt.ICudaEngine = None,
         input_binding_names: Optional[List[str]] = None,
         output_binding_names: Optional[List[str]] = None,
         *,
```

Review comment: The API should support both, since users may use this runtime independently; as written, we support one or the other.

Review comment: Same thing for the C++ runtime.
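A sketch of the dual-input API the reviewers ask for: accept either a live engine or serialized bytes, validate the combination, and normalize to a live engine during setup. Names mirror the PR where possible, but the class and branching logic are our illustration:

```python
from typing import Optional

import tensorrt as trt

class EngineModule:
    def __init__(
        self,
        cuda_engine: Optional[trt.ICudaEngine] = None,
        serialized_engine: Optional[bytes] = None,
    ):
        if cuda_engine is not None and serialized_engine is not None:
            raise ValueError("Provide either cuda_engine or serialized_engine, not both")
        self.engine = cuda_engine
        self.serialized_engine = serialized_engine

    def setup_engine(self) -> None:
        if self.engine is None and self.serialized_engine is not None:
            # Deserialize lazily, only when bytes were supplied.
            runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
            self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
        self.context = self.engine.create_execution_context()
```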
|
@@ -182,7 +181,7 @@ def __init__( | |
# Unused currently - to be used by Dynamic Shape support implementation | ||
self.memory_pool = None | ||
|
||
self.serialized_engine = serialized_engine | ||
self.engine = cuda_engine | ||
self.input_names = ( | ||
input_binding_names if input_binding_names is not None else [] | ||
) | ||
|
```diff
@@ -204,7 +203,6 @@ def __init__(
             else False
         )
         self.settings = settings
-        self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
         self.runtime_states = TorchTRTRuntimeStates(
```
```diff
@@ -219,7 +217,7 @@ def __init__(
         self.output_allocator: Optional[DynamicOutputAllocator] = None
         self.use_output_allocator_outputs = False

-        if self.serialized_engine is not None and not self.settings.lazy_engine_init:
+        if self.engine is not None and not self.settings.lazy_engine_init:
             self.setup_engine()

     def get_streamable_device_memory_budget(self) -> Any:
```
```diff
@@ -265,8 +263,6 @@ def setup_engine(self) -> None:
         ), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"

         self.initialized = True
-        runtime = trt.Runtime(TRT_LOGGER)
-        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
         if self.settings.enable_weight_streaming:
             self.set_default_device_memory_budget()
         self.context = self.engine.create_execution_context()
```
Review discussion:

@narendasan Should we set it to the live engine when it is the Python runtime and the serialized engine when it is the C++ runtime? That way we can do the serialization in the interpreter.

It can just be `trt.ICudaEngine`, with post-processing for the C++ runtime.

Like, I would have the C++ runtime determine whether the engine is live or serialized, and if it's live, serialize the engine.
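A minimal sketch of that post-processing on the C++-runtime side: detect whether the incoming engine is live or already serialized, and serialize only when it is live, so the cost is paid exactly once at the runtime boundary (the function name is hypothetical):

```python
import tensorrt as trt

def normalize_for_cpp_runtime(engine) -> bytes:
    """Return serialized engine bytes, serializing a live engine if needed."""
    if isinstance(engine, (bytes, bytearray)):
        return bytes(engine)  # already serialized: pass through
    if isinstance(engine, trt.ICudaEngine):
        return bytes(engine.serialize())  # live engine: serialize once, here
    raise TypeError(f"Unsupported engine type: {type(engine)!r}")
```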