|
1 | | -import torch |
2 | | -import logging |
3 | | -import collections.abc |
4 | | -import torch_tensorrt |
5 | | -from functools import partial |
6 | | - |
7 | | -from typing import Any, Optional, Sequence |
8 | | -from torch_tensorrt import EngineCapability, Device |
9 | | -from torch_tensorrt.fx.utils import LowerPrecision |
10 | | - |
11 | | -from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device |
12 | | -from torch_tensorrt.dynamo.backend.backends import torch_tensorrt_backend |
13 | | -from torch_tensorrt.dynamo._defaults import ( |
14 | | - PRECISION, |
15 | | - DEBUG, |
16 | | - WORKSPACE_SIZE, |
17 | | - MIN_BLOCK_SIZE, |
18 | | - PASS_THROUGH_BUILD_FAILURES, |
19 | | - MAX_AUX_STREAMS, |
20 | | - VERSION_COMPATIBLE, |
21 | | - OPTIMIZATION_LEVEL, |
22 | | - USE_PYTHON_RUNTIME, |
23 | | -) |
24 | | - |
25 | | - |
26 | | -logger = logging.getLogger(__name__) |
27 | | - |
28 | | - |
def compile(
    gm: torch.nn.Module,
    inputs: Any,
    *,
    device=None,
    disable_tf32=False,
    sparse_weights=False,
    enabled_precisions=None,
    refit=False,
    debug=DEBUG,
    capability=EngineCapability.default,
    num_avg_timing_iters=1,
    workspace_size=WORKSPACE_SIZE,
    dla_sram_size=1048576,
    dla_local_dram_size=1073741824,
    dla_global_dram_size=536870912,
    calibrator=None,
    truncate_long_and_double=False,
    require_full_compilation=False,
    min_block_size=MIN_BLOCK_SIZE,
    torch_executed_ops=None,
    torch_executed_modules=None,
    pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams=MAX_AUX_STREAMS,
    version_compatible=VERSION_COMPATIBLE,
    optimization_level=OPTIMIZATION_LEVEL,
    use_python_runtime=USE_PYTHON_RUNTIME,
    **kwargs,
):
    """Compile a module with the experimental Torch-TensorRT Dynamo backend.

    Builds a backend from the supported subset of arguments, wraps *gm* with
    ``torch.compile``, and triggers compilation once by invoking the wrapped
    module on *inputs*.

    Args:
        gm: Module (or GraphModule) to compile.
        inputs: A single input or a sequence of inputs used both to prepare
            the compilation and to trigger it.
        device: Target device. Defaults to the current device, resolved at
            call time (``None`` sentinel avoids snapshotting the device at
            import time).
        enabled_precisions: Collection of precisions (``torch.float16`` /
            ``torch.float32`` or the ``torch_tensorrt.dtype`` equivalents).
            Defaults to an empty set, which selects ``PRECISION``.
        torch_executed_ops: Operations forced to run in Torch. Defaults to
            an empty list (``None`` sentinel avoids a shared mutable default).
        torch_executed_modules: Accepted for interface compatibility; not
            currently forwarded to the backend.
        (Remaining keyword arguments mirror the TorchScript frontend; only
        the subset named in the warning below is honored by this path.)

    Returns:
        The ``torch.compile``-wrapped module, already compiled once.

    Raises:
        ValueError: If an unsupported precision is requested.
    """
    if debug:
        logger.setLevel(logging.DEBUG)

    # logger.warn is deprecated; logging.Logger.warning is the supported API.
    logger.warning(
        "The Dynamo backend is an experimental feature, for which only the "
        + "following arguments are supported: "
        + "{enabled_precisions, debug, workspace_size, min_block_size, "
        + "torch_executed_ops, pass_through_build_failures}"
    )

    # Resolve None sentinels: fresh containers per call, and the device is
    # queried at call time rather than frozen at function-definition time.
    if device is None:
        device = Device._current_device()
    if enabled_precisions is None:
        enabled_precisions = set()
    if torch_executed_ops is None:
        torch_executed_ops = []
    if torch_executed_modules is None:
        torch_executed_modules = []

    # Normalize a single input into a one-element sequence.
    if not isinstance(inputs, collections.abc.Sequence):
        inputs = [inputs]

    inputs = prepare_inputs(inputs, prepare_device(device))

    # Allow a bare precision value in place of a collection.
    if not isinstance(enabled_precisions, collections.abc.Collection):
        enabled_precisions = [enabled_precisions]

    # Parse user-specified enabled precisions (FP16 takes priority over FP32).
    if (
        torch.float16 in enabled_precisions
        or torch_tensorrt.dtype.half in enabled_precisions
    ):
        lower_precision = LowerPrecision.FP16
    elif (
        torch.float32 in enabled_precisions
        or torch_tensorrt.dtype.float in enabled_precisions
    ):
        lower_precision = LowerPrecision.FP32
    elif len(enabled_precisions) == 0:
        logger.info(f"No precision specified, defaulting to {PRECISION}")
        lower_precision = PRECISION
    else:
        raise ValueError(
            f"Precision {enabled_precisions} not supported in the Dynamo Path"
        )

    custom_backend = create_backend(
        precision=lower_precision,
        debug=debug,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        pass_through_build_failures=pass_through_build_failures,
        max_aux_streams=max_aux_streams,
        version_compatible=version_compatible,
        optimization_level=optimization_level,
        use_python_runtime=use_python_runtime,
        **kwargs,
    )

    model = torch.compile(gm, backend=custom_backend)

    # Ensure compilation occurs by calling the function with provided inputs
    model(*inputs)

    return model
115 | | - |
116 | | - |
117 | | -from torch_tensorrt.fx.utils import LowerPrecision |
118 | | - |
119 | | -logger = logging.getLogger(__name__) |
120 | | - |
121 | | - |
def create_backend(
    precision: LowerPrecision = PRECISION,
    debug: bool = DEBUG,
    workspace_size: int = WORKSPACE_SIZE,
    min_block_size: int = MIN_BLOCK_SIZE,
    torch_executed_ops: Sequence[str] = (),
    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
    version_compatible: bool = VERSION_COMPATIBLE,
    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
    **kwargs,
):
    """Create torch.compile backend given specified arguments

    Args:
        precision: Model Layer precision
        debug: Whether to print out verbose debugging information
        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
        min_block_size: Minimum number of operators per TRT-Engine Block
        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage.
            Defaults to an empty tuple — immutable, so the default is safe to share
            across calls (the prior ``set()`` default was mutable and, as a set,
            did not satisfy the ``Sequence[str]`` annotation).
        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
        version_compatible: Provide version forward-compatibility for engine plan files
        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
            searching for more optimization options. TRT defaults to 3
        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
            argument as None
    Returns:
        Backend for torch.compile
    """
    # Bind all settings into a partial so torch.compile receives a callable
    # with the standard (gm, example_inputs) backend signature.
    return partial(
        torch_tensorrt_backend,
        debug=debug,
        precision=precision,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        pass_through_build_failures=pass_through_build_failures,
        max_aux_streams=max_aux_streams,
        version_compatible=version_compatible,
        optimization_level=optimization_level,
        use_python_runtime=use_python_runtime,
        **kwargs,
    )
| 1 | +from .backends import torch_tensorrt_backend |
| 2 | +from .compile import compile |
0 commit comments