|
12 | 12 | import tensorrt as trt |
13 | 13 | import torch |
14 | 14 | from torch._subclasses.fake_tensor import FakeTensor |
| 15 | +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily |
15 | 16 | from torch_tensorrt._Device import Device |
16 | 17 | from torch_tensorrt._enums import dtype |
17 | 18 | from torch_tensorrt._features import ENABLED_FEATURES |
@@ -256,48 +257,54 @@ def prepare_inputs( |
256 | 257 | inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any], |
257 | 258 | disable_memory_format_check: bool = False, |
258 | 259 | ) -> Any: |
259 | | - if inputs is None: |
260 | | - return None |
261 | | - |
262 | | - elif isinstance(inputs, Input): |
263 | | - return inputs |
| 260 | + """ |
| 261 | + We take a nested group of torch.Tensors or scalars and convert them into torchtrt.Input's |
| 262 | + """ |
| 263 | + # Any tensors created inside this call will be FakeTensors if it's inside a torch.compile session |
| 264 | + # So, we disable fake mode temporarily. |
| 265 | + with unset_fake_temporarily(): |
| 266 | + if inputs is None: |
| 267 | + return None |
264 | 268 |
|
265 | | - elif isinstance(inputs, (torch.Tensor, int, float, bool)): |
266 | | - return Input.from_tensor( |
267 | | - torch.tensor(inputs), |
268 | | - disable_memory_format_check=disable_memory_format_check, |
269 | | - ) |
| 269 | + elif isinstance(inputs, Input): |
| 270 | + return inputs |
270 | 271 |
|
271 | | - elif isinstance(inputs, (list, tuple)): |
272 | | - torchtrt_input_list = [] |
273 | | - for input_obj in inputs: |
274 | | - torchtrt_input = prepare_inputs( |
275 | | - input_obj, disable_memory_format_check=disable_memory_format_check |
| 272 | + elif isinstance(inputs, (torch.Tensor, int, float, bool)): |
| 273 | + return Input.from_tensor( |
| 274 | + torch.tensor(inputs), |
| 275 | + disable_memory_format_check=disable_memory_format_check, |
276 | 276 | ) |
277 | | - torchtrt_input_list.append(torchtrt_input) |
278 | | - |
279 | | - return ( |
280 | | - torchtrt_input_list |
281 | | - if isinstance(inputs, list) |
282 | | - else tuple(torchtrt_input_list) |
283 | | - ) |
284 | 277 |
|
285 | | - elif isinstance(inputs, dict): |
286 | | - torchtrt_inputs_dict: Dict[Any, Any] = dict() |
| 278 | + elif isinstance(inputs, (list, tuple)): |
| 279 | + torchtrt_input_list = [] |
| 280 | + for input_obj in inputs: |
| 281 | + torchtrt_input = prepare_inputs( |
| 282 | + input_obj, disable_memory_format_check=disable_memory_format_check |
| 283 | + ) |
| 284 | + torchtrt_input_list.append(torchtrt_input) |
287 | 285 |
|
288 | | - for key, input_obj in inputs.items(): |
289 | | - torchtrt_input = prepare_inputs( |
290 | | - input_obj, disable_memory_format_check=disable_memory_format_check |
| 286 | + return ( |
| 287 | + torchtrt_input_list |
| 288 | + if isinstance(inputs, list) |
| 289 | + else tuple(torchtrt_input_list) |
291 | 290 | ) |
292 | | - torchtrt_inputs_dict[key] = torchtrt_input |
293 | 291 |
|
294 | | - return torchtrt_inputs_dict |
| 292 | + elif isinstance(inputs, dict): |
| 293 | + torchtrt_inputs_dict: Dict[Any, Any] = dict() |
295 | 294 |
|
296 | | - else: |
297 | | - raise ValueError( |
298 | | - f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " |
299 | | - + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" |
300 | | - ) |
| 295 | + for key, input_obj in inputs.items(): |
| 296 | + torchtrt_input = prepare_inputs( |
| 297 | + input_obj, disable_memory_format_check=disable_memory_format_check |
| 298 | + ) |
| 299 | + torchtrt_inputs_dict[key] = torchtrt_input |
| 300 | + |
| 301 | + return torchtrt_inputs_dict |
| 302 | + |
| 303 | + else: |
| 304 | + raise ValueError( |
| 305 | + f"Invalid input type {type(inputs)} encountered in the dynamo_compile input parsing. " |
| 306 | + + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" |
| 307 | + ) |
301 | 308 |
|
302 | 309 |
|
303 | 310 | def parse_complex_tensor_structs( |
|
0 commit comments