
Commit 98f54fe

cpu memory optimization rebased to main (#3868)
1 parent a834c02 commit 98f54fe

File tree

11 files changed: +302 −82 lines changed
docsrc/contributors/resource_management.rst

Lines changed: 77 additions & 0 deletions

@@ -0,0 +1,77 @@
.. _resource_management:

Resource Management
===================

Overview
--------

Efficient control of CPU and GPU memory is essential for successful model compilation,
especially when working with large models such as LLMs or diffusion models.
Uncontrolled memory growth can cause compilation failures or process termination.
This guide describes the symptoms of excessive memory usage and provides methods
to reduce both CPU and GPU memory consumption.
Memory Usage Control
--------------------

CPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **5x** the model size in CPU memory.
This can exceed system limits when compiling large models: for a 7B-parameter
model stored in FP16 (roughly 14 GB of weights), peak CPU usage could reach
about 70 GB.
**Common symptoms of high CPU memory usage:**

- Program freezes
- Process terminated by the operating system

**Ways to lower CPU memory usage:**

1. **Enable memory trimming**

   Set the following environment variable:

   .. code-block:: bash

      export TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM=1

   This releases roughly **2x** the model size held in redundant model copies,
   limiting total CPU memory usage to about **3x** the model size.
2. **Disable CPU offloading**

   In the compilation settings, set:

   .. code-block:: python

      offload_module_to_cpu = False

   This removes another model copy (**1x**), reducing peak CPU memory usage
   to about **2x** the model size. Both settings are combined in the sketch
   below.
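Putting the two CPU-side knobs together, a minimal end-to-end sketch (the
model, input shape, and ``ir`` choice here are illustrative placeholders,
not prescribed by this commit):

.. code-block:: python

   import os

   # Optional: enable builder malloc trimming before compilation begins.
   os.environ["TORCHTRT_ENABLE_BUILDER_MALLOC_TRIM"] = "1"

   import torch
   import torch_tensorrt

   model = MyLargeModel().eval().cuda()  # hypothetical model
   inputs = [torch.randn(1, 3, 224, 224).cuda()]

   trt_model = torch_tensorrt.compile(
       model,
       ir="dynamo",
       inputs=inputs,
       offload_module_to_cpu=False,  # keep weights on the GPU, avoid the extra CPU copy
   )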
GPU Memory
^^^^^^^^^^

By default, Torch-TensorRT may consume up to **2x** the model size in GPU memory.

**Common symptoms of high GPU memory usage:**

- CUDA out-of-memory errors
- TensorRT compilation errors

**Ways to lower GPU memory usage:**

1. **Enable offloading to CPU**

   In the compilation settings, set:

   .. code-block:: python

      offload_module_to_cpu = True

   This shifts one model copy from GPU to CPU memory. Peak GPU memory usage
   then drops to about **1x** the model size, while CPU memory usage rises by
   roughly **1x**, since the extra copy now resides in host memory. A quick
   way to check whether you are near the GPU memory limit is shown below.
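For instance, you can query free GPU memory up front with the same API the
compiler uses when offloading is disabled (``torch.cuda.mem_get_info``; see
the ``_compiler.py`` diff below). The half-of-total threshold here mirrors
the compiler's own check:

.. code-block:: python

   import torch

   free, total = torch.cuda.mem_get_info()
   if free < total // 2:
       print("Less than half of GPU memory is free; "
             "consider setting offload_module_to_cpu=True")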
docsrc/index.rst

Lines changed: 1 addition & 0 deletions
@@ -233,6 +233,7 @@ Contributor Documentation
    contributors/writing_dynamo_aten_lowering_passes
    contributors/ts_converters
    contributors/useful_links
+   contributors/resource_management

 Indices
 ----------------

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 12 additions & 3 deletions
@@ -42,6 +42,7 @@
 )
 from torch_tensorrt.dynamo.utils import (
     deallocate_module,
+    get_cpu_memory_usage,
     get_flat_args_with_check,
     get_output_metadata,
     parse_graph_io,
@@ -681,7 +682,7 @@ def compile(
         "offload_module_to_cpu": offload_module_to_cpu,
         "use_distributed_mode_trace": use_distributed_mode_trace,
     }
-
+    logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
     exported_program = pre_export_lowering(exported_program, settings)
@@ -695,14 +696,17 @@ def compile(

     # Apply lowering on the graph module
     gm = post_lowering(gm, settings)
+    logger.debug(f"CPU memory usage after post_lowering: {get_cpu_memory_usage()} MB")
     logger.debug("Lowered Input graph: " + str(gm.graph))

     # Move the weights in the state_dict to CPU
     if offload_module_to_cpu:
+        deallocate_module(gm, delete_module=False)
         deallocate_module(exported_program.module(), delete_module=False)
         logger.info(
             "The PyTorch model was moved to the CPU to allocate all GPU memory to TensorRT. To retain the model on the GPU, set offload_module_to_cpu=False"
         )
+        logger.debug(f"CPU memory usage after CPU offload: {get_cpu_memory_usage()} MB")
     else:
         remaining_memory, total_memory = torch.cuda.mem_get_info()
         if remaining_memory < total_memory // 2:
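Note: the `logger.debug` messages added in this commit are only visible when
debug logging is enabled. One way to surface them with standard Python logging
(a sketch; torch_tensorrt also ships its own logging helpers):

    import logging

    logging.basicConfig()
    # The dynamo compiler logs through module-level loggers under this namespace.
    logging.getLogger("torch_tensorrt.dynamo").setLevel(logging.DEBUG)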
@@ -868,6 +872,11 @@ def preserve_module_specs(
     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

+    # Delete the frozen parameters from the graph module to release CPU memory.
+    # Note that this does not affect the submodules; their frozen parameters
+    # are deleted later, in the convert_module function.
+    for attr in dir(gm):
+        if attr.startswith("_frozen_param"):
+            delattr(gm, attr)
     for name, _ in partitioned_module.named_children():
         submodule = getattr(partitioned_module, name)
         # filter on the GraphModule
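For context, `_frozen_param*` attributes are constant tensors that the lowering
passes fold onto the GraphModule. A toy sketch of the same cleanup pattern (the
traced module and the frozen attribute are fabricated for illustration):

    import torch
    import torch.fx

    gm = torch.fx.symbolic_trace(torch.nn.Identity())
    gm._frozen_param0 = torch.zeros(1024)  # stand-in for a constant-folded tensor

    for attr in dir(gm):
        if attr.startswith("_frozen_param"):
            delattr(gm, attr)  # the tensor's memory is freed once unreferenced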
@@ -1243,7 +1252,7 @@ def convert_exported_program_to_serialized_trt_engine(

     # Prepare torch_trt inputs
     trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
-    trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
+    trt_kwarg_inputs: Optional[dict[str, Any]] = prepare_inputs(kwarg_inputs)
     device = to_torch_tensorrt_device(device)
     enabled_precisions = {dtype._from(p) for p in enabled_precisions}
@@ -1330,7 +1339,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )

     flattened_input_list = get_flat_args_with_check(
-        exported_program, list(trt_arg_inputs), trt_kwarg_inputs
+        exported_program, list(trt_arg_inputs), trt_kwarg_inputs  # type: ignore
     )[0]

     try:
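The new `get_cpu_memory_usage` helper comes from `torch_tensorrt.dynamo.utils`;
its body is not shown in this diff. A minimal sketch of such a helper, assuming
it reports the current process's resident set size via psutil:

    import psutil

    def get_cpu_memory_usage() -> float:
        # Resident set size (RSS) of the current process, in MB.
        return psutil.Process().memory_info().rss / (1024 * 1024)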

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 33 additions & 57 deletions
@@ -1,5 +1,4 @@
 import gc
-import io
 import logging
 import os
 import warnings
@@ -50,7 +49,12 @@
 from torch_tensorrt.dynamo.debug._DebuggerConfig import DebuggerConfig
 from torch_tensorrt.dynamo.debug._supports_debugger import cls_supports_debugger
 from torch_tensorrt.dynamo.observer import Observer
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM, deallocate_module, to_torch_device
+from torch_tensorrt.dynamo.utils import (
+    DYNAMIC_DIM,
+    deallocate_module,
+    get_cpu_memory_usage,
+    to_torch_device,
+)
 from torch_tensorrt.logging import TRT_LOGGER

 _LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -65,7 +69,7 @@ class UnsupportedOperatorException(RuntimeError):


 class TRTInterpreterResult(NamedTuple):
-    serialized_engine: bytes
+    engine: trt.ICudaEngine
     input_names: Sequence[str]
     output_names: Sequence[str]
     weight_name_map: Optional[dict[Any, Any]]
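With this change, TRTInterpreterResult carries a live `trt.ICudaEngine` rather
than serialized bytes, so the interpreter no longer holds both the engine and
its serialized form in host memory at once. A hypothetical consumer-side helper
(not part of this commit) that produces bytes only on demand:

    import tensorrt as trt

    def engine_to_bytes(cuda_engine: trt.ICudaEngine) -> bytes:
        # ICudaEngine.serialize() returns an IHostMemory buffer, which
        # supports the Python buffer protocol.
        return bytes(cuda_engine.serialize())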
@@ -512,8 +516,7 @@ def _save_weight_mapping(self) -> None:
         _LOGGER.info("Building weight name mapping...")
         # Stage 1: Name mapping
         torch_device = to_torch_device(self.compilation_settings.device)
-        self.module.to(torch_device)
-        sd = self.module.state_dict()
+        sd = {k: v.to(torch_device) for k, v in self.module.state_dict().items()}
         weight_name_map: dict[str, Any] = {}
         weight_refit_map = self.ctx.weight_refit_map
         constant_mapping = {k: v for k, v in weight_refit_map.items() if v.size == 1}
@@ -591,34 +594,6 @@ def _save_weight_mapping(self) -> None:
         gc.collect()
         torch.cuda.empty_cache()

-    @needs_refit  # type: ignore[misc]
-    def _insert_engine_to_cache(self, hash_val: str, serialized_engine: bytes) -> None:
-        # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine
-        # if not self.compilation_settings.strip_engine_weights:
-        #     # set EXCLUDE_WEIGHTS flag to strip weights
-        #     runtime = trt.Runtime(TRT_LOGGER)
-        #     engine = runtime.deserialize_cuda_engine(serialized_engine)
-
-        #     serialization_config = engine.create_serialization_config()
-        #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
-        #     serialized_engine = engine.serialize_with_config(
-        #         serialization_config
-        #     )
-
-        # Cache weighted engine for now
-        self.engine_cache.insert(  # type: ignore[union-attr]
-            hash_val,
-            (
-                serialized_engine,
-                self._input_names,
-                self._output_names,
-                self.input_specs,
-                self.compilation_settings,
-                self.weight_name_map,
-                self.ctx.requires_output_allocator,
-            ),
-        )
-
     @needs_refit  # type: ignore[misc]
     def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
         # query the cached TRT engine
@@ -671,7 +646,6 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
                 settings=self.compilation_settings,
                 weight_name_map=self.weight_name_map,
             )
-            serialized_engine = engine.serialize()

            # TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
            # # EXCLUDE_WEIGHTS flag must be cleared
@@ -684,12 +658,8 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
            #     )
            # # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

-            with io.BytesIO() as engine_bytes:
-                engine_bytes.write(serialized_engine)
-                engine_str = engine_bytes.getvalue()
-
             return TRTInterpreterResult(
-                engine_str,
+                engine,
                 self._input_names,
                 self._output_names,
                 self.weight_name_map,
@@ -733,6 +703,9 @@ def run(
                 return interpreter_result  # type: ignore[no-any-return]

         self._construct_trt_network_def()
+        _LOGGER.debug(
+            f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
+        )

         if not self.compilation_settings.immutable_weights:
             self._save_weight_mapping()
@@ -750,36 +723,39 @@ def run(
         self._create_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )
-        serialized_engine = self.builder.build_serialized_network(
-            self.ctx.net, builder_config
-        )
-        assert serialized_engine
+
+        if (
+            ENABLED_FEATURES.tensorrt_rtx
+            or self.compilation_settings.version_compatible
+        ):
+            # TODO: When TRT-RTX matures, change it to build_engine_with_config
+            serialized_engine = self.builder.build_serialized_network(
+                self.ctx.net, builder_config
+            )
+            runtime = trt.Runtime(TRT_LOGGER)
+            cuda_engine = runtime.deserialize_cuda_engine(serialized_engine)
+        else:
+            cuda_engine = self.builder.build_engine_with_config(
+                self.ctx.net, builder_config
+            )
+        assert cuda_engine
+
+        _LOGGER.debug(
+            f"CPU memory usage after engine building: {get_cpu_memory_usage()} MB"
+        )

         _LOGGER.info(
             f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
         )
-        _LOGGER.info(f"TRT Engine uses: {serialized_engine.nbytes} bytes of Memory")
-
         self.ctx.clear_cpu_weights_reference_holder()

         self._save_timing_cache(
             builder_config, self.compilation_settings.timing_cache_path
         )

-        # Engine caching only for refittable engines
-        if (
-            not self.compilation_settings.immutable_weights
-            and self.compilation_settings.cache_built_engines
-            and self.engine_cache is not None
-        ):
-            self._insert_engine_to_cache(hash_val, serialized_engine)
-
-        with io.BytesIO() as engine_bytes:
-            engine_bytes.write(serialized_engine)
-            engine_str = engine_bytes.getvalue()
-
         return TRTInterpreterResult(
-            engine_str,
+            cuda_engine,
             self._input_names,
             self._output_names,
             self.weight_name_map,
