 from __future__ import annotations

 import logging
-from typing import List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple

 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
 from torch_tensorrt.dynamo import partitioning

 logger = logging.getLogger(__name__)


+def _unflatten_inputs(
+    flattened_inputs: Sequence[torch_tensorrt.Input],
+    compiled_module: torch.fx.GraphModule,
+) -> Tuple[Any, Any]:
+    """
+    Reconstruct the original (args, kwargs) inputs using tree_unflatten and tree_map
+
+    Args:
+        flattened_inputs: Flattened input tensors to process
+        compiled_module: The compiled GraphModule containing input specifications
+
+    Returns:
+        Tuple of (args, kwargs) containing reconstructed input tensors
+    """
+
+    def convert_input_to_cuda_tensor(input: Any) -> torch.Tensor:
+        if isinstance(input, torch_tensorrt.Input):
+            return input.torch_tensor.cuda()
+        else:
+            raise RuntimeError("Input is not a torch_tensorrt.Input")
+
+    # Reconstruct the (args, kwargs) structure that was flattened during export
+    pytree_inputs = tree_unflatten(flattened_inputs, compiled_module._in_spec)
+    # Apply the tensor conversion to the reconstructed structure
+    processed_inputs = tree_map(convert_input_to_cuda_tensor, pytree_inputs)
+
+    # Since inputs were originally flattened from (args, kwargs),
+    # processed_inputs is now that same tuple structure
+    return processed_inputs[0], processed_inputs[1]
+
+
 class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules

@@ -43,14 +75,15 @@ def warm_up(self) -> None:
         Warm up is necessary to ensure that memory allocations and initializations
         are not recorded in cuda graphs
         """
+
         with torch_tensorrt.logging.errors():
             with unset_fake_temporarily():
-                inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
+                args, kwargs = _unflatten_inputs(self.inputs, self.compiled_module)
                 s = torch.cuda.Stream()
                 s.wait_stream(torch.cuda.current_stream())
                 with torch.cuda.stream(s):
                     for _ in range(3):
-                        self.compiled_module(*inputs_tensor)
+                        self.compiled_module(*args, **kwargs)
                 torch.cuda.current_stream().wait_stream(s)

     def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
@@ -77,15 +110,18 @@ def __del__(self) -> None:
     def set_use_output_allocator(self, enable: bool) -> None:
         self.use_output_allocator_outputs = enable

-    def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+    def forward(
+        self, *args: Any, **kwargs: Any
+    ) -> torch.Tensor | Tuple[torch.Tensor, ...]:
+        inputs, _ = tree_flatten((args, kwargs))
         cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
         if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
             need_cudagraphs_record = shape_changed or self.is_weight_streaming_set
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
+                self._input_buffers = [None] * len(inputs)

             self.is_weight_streaming_set = False
             # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -98,10 +134,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 for i in inputs
             ]
             assert len(contiguous_inputs) == len(
-                self.inputs
-            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+                inputs
+            ), f"Wrong number of inputs, expect {len(inputs)} get {len(contiguous_inputs)}."

-            for i, _ in enumerate(self.inputs):
+            for i, _ in enumerate(inputs):
                 if not contiguous_inputs[i].is_cuda:
                     logger.warning(
                         f"Detected input[{i}] is not on a cuda device. "
@@ -116,8 +152,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     )

                 assert (
-                    contiguous_inputs[i].dtype == self.inputs[i].dtype
-                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+                    contiguous_inputs[i].dtype == inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {inputs[i].dtype}, got {contiguous_inputs[i].dtype}."

                 if need_cudagraphs_record:
                     # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
@@ -126,6 +162,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                 else:
                     self._input_buffers[i].copy_(contiguous_inputs[i])

+            if need_cudagraphs_record:
+                # Reconstruct the original args and kwargs structure from static input buffers
+                # using the input specification stored during module compilation
+                args, kwargs = tree_unflatten(
+                    self._input_buffers, self.compiled_module._in_spec
+                )
+
             self._caller_stream = torch.cuda.current_stream()
             if (
                 self._engine_stream == torch.cuda.default_stream()
@@ -139,9 +182,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 self.cudagraph = torch.cuda.CUDAGraph()
                 with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
-                    self._output_buffers = self.compiled_module(
-                        *self._input_buffers
-                    )
+                    self._output_buffers = self.compiled_module(*args, **kwargs)

             self.cudagraph.replay()  # type: ignore
             self._caller_stream.wait_stream(self._engine_stream)
@@ -158,4 +199,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if self.cudagraph:
                 self.cudagraph.reset()
                 self.cudagraph = None
-            return self.compiled_module(*inputs)
+            return self.compiled_module(*args, **kwargs)
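
For context, the change leans on torch's pytree utilities: `tree_flatten` turns an `(args, kwargs)` structure into a flat list of leaves plus a spec, and `tree_unflatten` inverts it. A minimal standalone sketch of that round trip (the example tensors and names are illustrative, not taken from this PR):

```python
import torch
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

# Example (args, kwargs) pair, mirroring what forward() receives
args = (torch.ones(2), torch.zeros(2))
kwargs = {"scale": torch.tensor(0.5)}

# Flatten to a list of leaf tensors plus a spec describing the structure;
# the spec plays the same role as compiled_module._in_spec above
leaves, spec = tree_flatten((args, kwargs))

# tree_unflatten inverts the flattening; tree_map applies a function leaf-wise
rebuilt_args, rebuilt_kwargs = tree_map(lambda t: t * 2, tree_unflatten(leaves, spec))
assert rebuilt_kwargs["scale"].item() == 1.0
```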
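The recording path above follows the standard CUDA Graphs pattern: capture work against static device buffers once, then copy new inputs into those buffers in place and replay. A minimal sketch of that pattern (assumes a CUDA device; `static_in`/`static_out` are illustrative names, not from this PR):

```python
import torch

static_in = torch.zeros(4, device="cuda")

# Warm up on a side stream so allocations are not captured, as warm_up() does
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_out = static_in * 2
torch.cuda.current_stream().wait_stream(s)

# Record once against the static buffer
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in * 2

# Replay with new data by updating the captured buffer in place
static_in.copy_(torch.arange(4.0, device="cuda"))
g.replay()  # static_out now holds [0., 2., 4., 6.]
```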