From 0d212afc5fa0fe60254546ec925246cb2b3242c7 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Mon, 24 Feb 2025 13:51:43 +0000
Subject: [PATCH 1/3] fix: structured inputs for CudaGraphsTorchTensorRTModule

---
 .../runtime/_CudaGraphsTorchTensorRTModule.py | 69 +++++++++++++++----
 1 file changed, 55 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
index b3ac25bc3a..5af9b11a4b 100644
--- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -1,16 +1,48 @@
 from __future__ import annotations
 
 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning
 
 logger = logging.getLogger(__name__)
 
 
+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Reconstruct the original (args, kwargs) inputs using tree_unflatten and tree_map
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Convert each leaf of the reconstructed structure to a CUDA tensor
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
 
@@ -43,14 +75,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)
 
     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -77,7 +110,10 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable
 
-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
@@ -85,7 +121,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)
                 self.is_weight_streaming_set = False
 
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -98,10 +134,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expected {len(inputs)}, got {len(contiguous_inputs)}."
 
-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -116,8 +152,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                     )
 
                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expected {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
                 if need_cudagraphs_record:
                     # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -126,6 +162,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                 else:
                     self._input_buffers[i].copy_(contiguous_inputs[i])
 
+            if need_cudagraphs_record:
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
+
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -139,9 +182,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             if need_cudagraphs_record:
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                    self._output_buffers = self.compiled_module(
-                        *self._input_buffers
-                    )
+                    self._output_buffers = self.compiled_module(*args, **kwargs)
 
             self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -158,4 +199,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)

From abc915a96f7042876a896ec306aa232579dca7b7 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Tue, 25 Feb 2025 05:39:03 +0000
Subject: [PATCH 2/3] chore: Update test case

---
 .../runtime/test_004_weight_streaming.py      | 82 +++++++++++--------
 1 file changed, 50 insertions(+), 32 deletions(-)

diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py
index 78522388d1..67d69df381 100644
--- a/tests/py/dynamo/runtime/test_004_weight_streaming.py
+++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py
@@ -6,6 +6,7 @@
 import torch_tensorrt as torchtrt
 from parameterized import parameterized
 from torch.testing._internal.common_utils import TestCase, run_tests
+from torch_tensorrt.dynamo.utils import prepare_inputs
 
 INPUT_SIZE = (64, 100)
 
@@ -302,45 +303,62 @@ def __init__(self):
                 self.layer2 = torch.nn.Linear(128, 64)
                 self.relu = torch.nn.ReLU()
 
-            def forward(self, x):
+            def forward(self, x, b=None, c=None, d=None, e=[]):
                 out = self.layer1(x)
+                out = out + b
+                if c is not None:
+                    out = out * c
                 out = self.relu((out + 2.0) * 0.05)
+                if d is not None:
+                    out = out - d["value"] + d["value2"]
                 out = self.layer2(out)
+                for n in e:
+                    out += n
                 return out
 
-        inputs = torchtrt.Input(
-            min_shape=(1, 100),
-            opt_shape=(64, 100),
-            max_shape=(128, 100),
-            dtype=torch.float,
-            name="x",
-        )
         model = SampleModel().eval().cuda()
         input_list = []
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((12, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())
-        input_list.append(torch.randn((8, 100)).cuda())
-
-        dynamic_shapes = (
-            {
-                0: torch.export.Dim("batch_size", min=1, max=128),
-            },
-        )
-        exp_program = torch.export.export(
-            model, (input_list[0],), dynamic_shapes=dynamic_shapes
-        )
-
+        for batch_size in [8, 12, 12, 8, 8]:
+            args = [torch.rand((batch_size, 100)).to("cuda")]
+            kwargs = {
+                "b": torch.rand((1, 128)).to("cuda"),
+                "d": {
+                    "value": torch.rand(1).to("cuda"),
+                    "value2": torch.tensor(1.2).to("cuda"),
+                },
+                "e": [torch.rand(1).to("cuda"), torch.rand(1).to("cuda")],
+            }
+            input_list.append((args, kwargs))
+
+        kwarg_torchtrt_input = prepare_inputs(input_list[0][1])
+
+        compile_spec = {
+            "inputs": [
+                torchtrt.Input(
+                    min_shape=(1, 100),
+                    opt_shape=(64, 100),
+                    max_shape=(128, 100),
+                    dtype=torch.float32,
+                    name="x",
+                ),
+            ],
+            "kwarg_inputs": kwarg_torchtrt_input,
+            "device": torchtrt.Device("cuda:0"),
+            "enabled_precisions": {torch.float},
+            "pass_through_build_failures": True,
+            "min_block_size": 1,
+            "ir": "dynamo",
+            "cache_built_engines": False,
+            "reuse_cached_engines": False,
+            "use_explicit_typing": True,
+            "enable_weight_streaming": True,
+            "torch_executed_ops": {"torch.ops.aten.mul.Tensor"},
+            "use_python_runtime": use_python_runtime,
+        }
+        exp_program = torchtrt.dynamo.trace(model, **compile_spec)
         optimized_model = torchtrt.dynamo.compile(
             exp_program,
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            use_explicit_typing=True,
-            enable_weight_streaming=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=use_python_runtime,
+            **compile_spec,
         )
 
         # List of tuples representing different configurations for three features:
@@ -361,12 +379,12 @@ def test_trt_model(enable_weight_streaming, optimized_model, input_list):
             for i in range(len(input_list)):
                 if enable_weight_streaming and i == 4:
                     weight_streaming_ctx.device_budget = int(streamable_budget * 0.6)
-                out_list.append(optimized_model(input_list[i]))
+                out_list.append(optimized_model(*input_list[i][0], **input_list[i][1]))
             return out_list
 
         ref_out_list = []
         for i in range(len(input_list)):
-            ref_out_list.append(model(input_list[i]))
+            ref_out_list.append(model(*input_list[i][0], **input_list[i][1]))
 
         pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
             optimized_model

From 77c42413d9ded80aa20d922ca455b504fb4ee611 Mon Sep 17 00:00:00 2001
From: kee hyun an
Date: Sun, 2 Mar 2025 02:51:58 +0000
Subject: [PATCH 3/3] chore: inputs -> arg_inputs

---
 tests/py/dynamo/runtime/test_004_weight_streaming.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tests/py/dynamo/runtime/test_004_weight_streaming.py b/tests/py/dynamo/runtime/test_004_weight_streaming.py
index 67d69df381..d453f91c3f 100644
--- a/tests/py/dynamo/runtime/test_004_weight_streaming.py
+++ b/tests/py/dynamo/runtime/test_004_weight_streaming.py
@@ -291,10 +291,6 @@ def test_weight_streaming_cudagraphs(self, _, use_python_runtime):
             ("cpp_runtime", False),
         ]
     )
-    @unittest.skipIf(
-        os.environ.get("CI_BUILD") == "1",
-        "Skipping test due to CI resource constraints",
-    )
     def test_runtime_state_change(self, _, use_python_runtime):
         class SampleModel(torch.nn.Module):
             def __init__(self):
@@ -333,7 +329,7 @@ def forward(self, x, b=None, c=None, d=None, e=[]):
         kwarg_torchtrt_input = prepare_inputs(input_list[0][1])
 
         compile_spec = {
-            "inputs": [
+            "arg_inputs": [
                 torchtrt.Input(
                     min_shape=(1, 100),
                     opt_shape=(64, 100),
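A note on the mechanism PATCH 1/3 builds on: torch.utils._pytree's tree_flatten turns a nested (args, kwargs) structure into a flat list of leaf tensors plus a TreeSpec, and tree_unflatten inverts that using the spec — here the _in_spec recorded on the compiled GraphModule. Below is a minimal, standalone sketch of that round-trip, illustrative only: the shapes and values are invented, and .contiguous() stands in for the .cuda()/copy_ steps so the sketch runs without a GPU.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

# A call structure like the one the updated test passes to forward().
args = (torch.randn(8, 100),)
kwargs = {
    "b": torch.randn(1, 128),
    "d": {"value": torch.rand(1), "value2": torch.tensor(1.2)},
    "e": [torch.rand(1), torch.rand(1)],
}

# What the wrapper's forward() does on entry: one flat leaf list plus a spec.
flat_inputs, in_spec = tree_flatten((args, kwargs))

# Per-leaf transform, analogous to convert_input_to_cuda_tensor and the
# _input_buffers copies (stand-in op so the sketch is CPU-runnable).
flat_inputs = [t.contiguous() for t in flat_inputs]

# What _unflatten_inputs does with compiled_module._in_spec: rebuild the
# exact (args, kwargs) layout the module was exported with.
rebuilt_args, rebuilt_kwargs = tree_unflatten(flat_inputs, in_spec)
assert rebuilt_kwargs["d"]["value2"].item() == kwargs["d"]["value2"].item()

Because the static _input_buffers stay in this flat-list order, rebuilding (args, kwargs) from them with the stored _in_spec before recording means the CUDA graph is captured, and later replayed, against the same structured inputs the module was traced with.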