11from __future__ import annotations
22
33import logging
4- from typing import List , Optional , Sequence , Tuple
4+ from typing import Any , List , Optional , Sequence , Tuple
55
66import torch
77import torch_tensorrt
88from torch .fx .experimental .proxy_tensor import unset_fake_temporarily
9+ from torch .utils ._pytree import tree_flatten , tree_map , tree_unflatten
910from torch_tensorrt .dynamo import partitioning
1011
1112logger = logging .getLogger (__name__ )
1213
1314
def _unflatten_inputs(
    flattened_inputs: Sequence[torch_tensorrt.Input],
    compiled_module: torch.fx.GraphModule,
) -> Tuple[Any, Any]:
    """
    Rebuild the nested call structure from flat input specs as CUDA tensors.

    The flat sequence is unflattened with ``compiled_module._in_spec`` and
    every leaf of the resulting structure is replaced by that spec's
    ``torch_tensor`` moved to the GPU via ``.cuda()``.

    Args:
        flattened_inputs: Flat sequence of ``torch_tensorrt.Input`` specs
        compiled_module: GraphModule whose ``_in_spec`` describes the nesting
            (presumably recorded when the module was exported -- verify
            against the caller)

    Returns:
        The first two elements of the rebuilt structure, expected to be the
        ``(args, kwargs)`` pair

    Raises:
        RuntimeError: If any leaf is not a ``torch_tensorrt.Input``
    """

    def _spec_to_cuda_tensor(leaf: Any) -> torch.Tensor:
        # Reject anything that is not an Input spec before touching its tensor
        if not isinstance(leaf, torch_tensorrt.Input):
            raise RuntimeError("Input is not a torch_tensorrt.Input")
        return leaf.torch_tensor.cuda()

    # Restore the nesting first, then convert each leaf spec in place
    restored = tree_map(
        _spec_to_cuda_tensor,
        tree_unflatten(flattened_inputs, compiled_module._in_spec),
    )

    # Index rather than unpack so that only the first two entries are read
    positional, keyword = restored[0], restored[1]
    return positional, keyword
44+
45+
1446class CudaGraphsTorchTensorRTModule (torch .nn .Module ): # type: ignore[misc]
1547 """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
1648
@@ -42,14 +74,15 @@ def warm_up(self) -> None:
4274 Warm up is necessary to ensure that memory allocations and initializations
4375 are not recorded in cuda graphs
4476 """
77+
4578 with torch_tensorrt .logging .errors ():
4679 with unset_fake_temporarily ():
47- inputs_tensor = [ spec . torch_tensor . cuda () for spec in self .inputs ]
80+ args , kwargs = _unflatten_inputs ( self .inputs , self . compiled_module )
4881 s = torch .cuda .Stream ()
4982 s .wait_stream (torch .cuda .current_stream ())
5083 with torch .cuda .stream (s ):
5184 for _ in range (3 ):
52- self .compiled_module (* inputs_tensor )
85+ self .compiled_module (* args , ** kwargs )
5386 torch .cuda .current_stream ().wait_stream (s )
5487
5588 def validate_input_shapes (self , inputs : Sequence [torch .Tensor ]) -> bool :
@@ -73,15 +106,18 @@ def __del__(self) -> None:
73106 if self .cudagraph :
74107 self .cudagraph .reset ()
75108
76- def forward (self , * inputs : torch .Tensor ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
109+ def forward (
110+ self , * args : Any , ** kwargs : Any
111+ ) -> torch .Tensor | Tuple [torch .Tensor , ...]:
112+ inputs , _ = tree_flatten ((args , kwargs ))
77113 cudagraphs_enabled = torch_tensorrt .runtime .get_whole_cudagraphs_mode ()
78114 if cudagraphs_enabled :
79115 shape_changed = self .validate_input_shapes (inputs )
80116 need_cudagraphs_record = shape_changed or self .is_weight_streaming_set
81117 if need_cudagraphs_record :
82118 if self .cudagraph :
83119 self .cudagraph .reset ()
84- self ._input_buffers = [None ] * len (self . inputs )
120+ self ._input_buffers = [None ] * len (inputs )
85121
86122 self .is_weight_streaming_set = False
87123 # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
@@ -94,10 +130,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
94130 for i in inputs
95131 ]
96132 assert len (contiguous_inputs ) == len (
97- self . inputs
98- ), f"Wrong number of inputs, expect { len (self . inputs )} get { len (contiguous_inputs )} ."
133+ inputs
134+ ), f"Wrong number of inputs, expect { len (inputs )} get { len (contiguous_inputs )} ."
99135
100- for i , _ in enumerate (self . inputs ):
136+ for i , _ in enumerate (inputs ):
101137 if not contiguous_inputs [i ].is_cuda :
102138 logger .warning (
103139 f"Detected input[{ i } ] is not on a cuda device. "
@@ -112,15 +148,22 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
112148 )
113149
114150 assert (
115- contiguous_inputs [i ].dtype == self .inputs [i ].dtype
116- ), f"Dtype mismatch for { i } th input. Expect { self .inputs [i ].dtype } , got { contiguous_inputs [i ].dtype } ."
151+ contiguous_inputs [i ].dtype == inputs [i ].dtype
152+ ), f"Dtype mismatch for { i } th input. Expect { inputs [i ].dtype } , got { contiguous_inputs [i ].dtype } ."
153+
154+ if need_cudagraphs_record :
155+ # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
156+ # Clone is required to avoid re-using user-provided GPU memory
157+ self ._input_buffers [i ] = contiguous_inputs [i ].clone ()
158+ else :
159+ self ._input_buffers [i ].copy_ (contiguous_inputs [i ])
117160
118161 if need_cudagraphs_record :
119- # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
120- # Clone is required to avoid re-using user-provided GPU memory
121- self . _input_buffers [ i ] = contiguous_inputs [ i ]. clone ()
122- else :
123- self . _input_buffers [ i ]. copy_ ( contiguous_inputs [ i ] )
162+ # Reconstruct the original args and kwargs structure from static input buffers
163+ # using the input specification stored during module compilation
164+ args , kwargs = tree_unflatten (
165+ self . _input_buffers , self . compiled_module . _in_spec
166+ )
124167
125168 self ._caller_stream = torch .cuda .current_stream ()
126169 if (
@@ -135,9 +178,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
135178 if need_cudagraphs_record :
136179 self .cudagraph = torch .cuda .CUDAGraph ()
137180 with torch .cuda .graph (self .cudagraph , stream = self ._engine_stream ):
138- self ._output_buffers = self .compiled_module (
139- * self ._input_buffers
140- )
181+ self ._output_buffers = self .compiled_module (* args , ** kwargs )
141182
142183 self .cudagraph .replay () # type: ignore
143184 self ._caller_stream .wait_stream (self ._engine_stream )
@@ -154,4 +195,4 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
154195 if self .cudagraph :
155196 self .cudagraph .reset ()
156197 self .cudagraph = None
157- return self .compiled_module (* inputs )
198+ return self .compiled_module (* args , ** kwargs )
0 commit comments