1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -693,6 +693,7 @@ def compile(

# Move the weights in the state_dict to CPU
if offload_module_to_cpu:
deallocate_module(gm, delete_module=False)
deallocate_module(exported_program.module(), delete_module=False)
logger.info(
"The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
56 changes: 32 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -65,7 +65,7 @@ class UnsupportedOperatorException(RuntimeError):


class TRTInterpreterResult(NamedTuple):
serialized_engine: bytes
engine: trt.ICudaEngine | bytes
Collaborator Author:
@narendasan Should we set it to be the engine when it is the Python runtime and the serialized engine when it is the C++ runtime? This way we can do the serialization in the Interpreter.

Collaborator:
It can just be trt.ICudaEngine, with post-processing for the C++ runtime.

Collaborator:
I would have the C++ runtime determine whether the engine is live or serialized, and if it's live, serialize the engine.
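A minimal sketch of that suggestion, assuming a hypothetical helper (not part of this diff) on the C++ runtime path that serializes only when it receives a live engine:

import io

import tensorrt as trt


def ensure_serialized(engine: "trt.ICudaEngine | bytes") -> bytes:
    # Hypothetical helper: serialize only when the interpreter handed back a
    # live engine; pass already-serialized bytes through unchanged.
    if isinstance(engine, trt.ICudaEngine):
        with io.BytesIO() as buf:
            buf.write(engine.serialize())  # same pattern run() uses below
            return buf.getvalue()
    return engine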

input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
@@ -512,8 +511,7 @@ def _save_weight_mapping(self) -> None:
_LOGGER.info("Building weight name mapping...")
# Stage 1: Name mapping
torch_device = to_torch_device(self.compilation_settings.device)
self.module.to(torch_device)
sd = self.module.state_dict()
sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
weight_name_map: dict[str, Any] = {}
weight_refit_map = self.ctx.weight_refit_map
constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -592,13 +591,11 @@ def _save_weight_mapping(self) -> None:
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
def _insert_engine_to_cache(self, hash_val: str, engine: bytes) -> None:
serialized_engine = engine.serialize()
Collaborator:
If we are doing this, don't we end up paying the serialization cost again?
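One way to avoid paying that cost twice, sketched as a hypothetical restructuring of run() (the names follow this diff, but the control flow is not the PR's):

# Hypothetical tail of run(): serialize at most once, keep
# _insert_engine_to_cache on its old bytes-based signature, and reuse the
# same buffer for the result handed to the C++ runtime.
serialized_engine = None
if (
    self.compilation_settings.cache_built_engines
    and self.engine_cache is not None
):
    serialized_engine = cuda_engine.serialize()
    self._insert_engine_to_cache(hash_val, serialized_engine)

if self.compilation_settings.use_python_runtime:
    engine_for_result = cuda_engine
else:
    # Reuse the serialized buffer if the cache insert already produced it.
    engine_for_result = (
        serialized_engine
        if serialized_engine is not None
        else cuda_engine.serialize()
    )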

# TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
# if not self.compilation_settings.strip_engine_weights:
# # set EXCLUDE_WEIGHTS flag to strip weights
# runtime = trt.Runtime(TRT_LOGGER)
# engine = runtime.deserialize_cuda_engine(serialized_engine)

# serialization_config = engine.create_serialization_config()
# serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
# serialized_engine = engine.serialize_with_config(
@@ -750,16 +747,15 @@ def run(
self._create_timing_cache(
builder_config, self.compilation_settings.timing_cache_path
)
serialized_engine = self.builder.build_serialized_network(

cuda_engine = self.builder.build_engine_with_config(
self.ctx.net, builder_config
)
assert serialized_engine
assert cuda_engine

_LOGGER.info(
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
)
_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

self.ctx.clear_cpu_weights_reference_holder()

self._save_timing_cache(
@@ -772,19 +768,31 @@
and self.compilation_settings.cache_built_engines
and self.engine_cache is not None
):
self._insert_engine_to_cache(hash_val, serialized_engine)

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
self._insert_engine_to_cache(hash_val, cuda_engine)

if self.compilation_settings.use_python_runtime:
return TRTInterpreterResult(
cuda_engine,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
else:
serialized_engine = cuda_engine.serialize()
Collaborator:
Just have the TRTInterpreter return a live engine in all cases

_LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

return TRTInterpreterResult(
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
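If run() returned the live engine in all cases, as suggested above, its tail could collapse to a single return, with serialization deferred to the runtime-specific wrapper. A sketch, not the code in this diff:

# Hypothetical tail of run(): always hand back the live engine and let the
# runtime module decide whether (and when) to serialize it.
return TRTInterpreterResult(
    cuda_engine,
    self._input_names,
    self._output_names,
    self.weight_name_map,
    self.ctx.requires_output_allocator,
)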

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
16 changes: 11 additions & 5 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -89,12 +89,18 @@ def convert_module(
module, inputs, settings, engine_cache=engine_cache
)

rt_cls = PythonTorchTensorRTModule

if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

rt_cls = TorchTensorRTModule
return TorchTensorRTModule(
serialized_engine=interpreter_result.engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
name=name,
settings=settings,
weight_name_map=interpreter_result.weight_name_map,
requires_output_allocator=interpreter_result.requires_output_allocator,
)

elif (
not ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime
@@ -103,8 +109,8 @@
"Since Torch-TensorRT runtime is not available, using Python Runtime, some features may not be available"
)

return rt_cls(
serialized_engine=interpreter_result.serialized_engine,
return PythonTorchTensorRTModule(
Collaborator:
This should still be rt_cls, but with some preprocessing that modifies the arguments.

cuda_engine=interpreter_result.engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
name=name,
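A sketch of keeping a single rt_cls with argument preprocessing, as suggested above; the engine_kwargs handling and the explicit serialization are hypothetical, not what this diff does:

# Hypothetical shape of convert_module(): one rt_cls, with the engine
# argument adapted per runtime before construction.
rt_cls = PythonTorchTensorRTModule
engine_kwargs = {"cuda_engine": interpreter_result.engine}

if ENABLED_FEATURES.torch_tensorrt_runtime and not settings.use_python_runtime:
    from torch_tensorrt.dynamo.runtime import TorchTensorRTModule

    rt_cls = TorchTensorRTModule
    # The C++ runtime still consumes serialized bytes, so serialize the live
    # engine here instead of inside the interpreter.
    engine_kwargs = {
        "serialized_engine": bytes(interpreter_result.engine.serialize())
    }

return rt_cls(
    input_binding_names=list(interpreter_result.input_names),
    output_binding_names=list(interpreter_result.output_names),
    name=name,
    settings=settings,
    weight_name_map=interpreter_result.weight_name_map,
    requires_output_allocator=interpreter_result.requires_output_allocator,
    **engine_kwargs,
)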
4 changes: 3 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -37,7 +37,9 @@ def constant_fold(
# For TRT INetwork construction the constants are moved to CPU in get_attr call.
for node, constant in cf.node_replacements.items():
replace_node_with_constant(
gm, node, torch.nn.Parameter(constant, requires_grad=False)
gm,
node,
torch.nn.Parameter(constant.cpu().contiguous(), requires_grad=False),
)

erased_params = []
10 changes: 3 additions & 7 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -15,7 +15,6 @@
from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.logging import TRT_LOGGER
from torch_tensorrt.runtime._utils import (
_is_switch_required,
_select_rt_device,
@@ -123,7 +122,7 @@ class PythonTorchTensorRTModule(Module): # type: ignore[misc]

def __init__(
self,
serialized_engine: Optional[bytes] = None,
cuda_engine: trt.ICudaEngine = None,
Collaborator:
The API should support both, since users may use this runtime independently; it's just that now we either support cuda_engine or serialized_engine.

Collaborator:
Same thing for the C++ runtime.

input_binding_names: Optional[List[str]] = None,
output_binding_names: Optional[List[str]] = None,
*,
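A sketch of how the module could accept either input, as the comments above suggest; the helper names are hypothetical, and the deserialization mirrors the code this diff removes from setup_engine():

from typing import Optional

import tensorrt as trt

from torch_tensorrt.logging import TRT_LOGGER


def _store_engine_inputs(
    self,
    cuda_engine: Optional[trt.ICudaEngine] = None,
    serialized_engine: Optional[bytes] = None,
) -> None:
    # Hypothetical __init__ handling: keep both keyword arguments.
    if cuda_engine is not None and serialized_engine is not None:
        raise ValueError("Pass either cuda_engine or serialized_engine, not both")
    self.engine = cuda_engine
    self.serialized_engine = serialized_engine


def _materialize_engine(self) -> None:
    # Hypothetical setup_engine() step: deserialize only when a live engine
    # was not handed in.
    if self.engine is None and self.serialized_engine is not None:
        runtime = trt.Runtime(TRT_LOGGER)
        self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)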
@@ -182,7 +181,7 @@ def __init__(
# Unused currently - to be used by Dynamic Shape support implementation
self.memory_pool = None

self.serialized_engine = serialized_engine
self.engine = cuda_engine
self.input_names = (
input_binding_names if input_binding_names is not None else []
)
@@ -204,7 +203,7 @@ def __init__(
else False
)
self.settings = settings
self.engine = None
self.weight_name_map = weight_name_map
self.target_platform = Platform.current_platform()
self.runtime_states = TorchTRTRuntimeStates(
@@ -219,7 +217,7 @@ def __init__(
self.output_allocator: Optional[DynamicOutputAllocator] = None
self.use_output_allocator_outputs = False

if self.serialized_engine is not None and not self.settings.lazy_engine_init:
if self.engine is not None and not self.settings.lazy_engine_init:
self.setup_engine()

def get_streamable_device_memory_budget(self) -> Any:
@@ -265,8 +263,6 @@ def setup_engine(self) -> None:
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"

self.initialized = True
runtime = trt.Runtime(TRT_LOGGER)
self.engine = runtime.deserialize_cuda_engine(self.serialized_engine)
if self.settings.enable_weight_streaming:
self.set_default_device_memory_budget()
self.context = self.engine.create_execution_context()