@@ -6,7 +6,6 @@
 
 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,119 +77,63 @@ def compile( |
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
     elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )
 
+    compilation_options = {
+        "precision": precision,
+        "debug": debug,
+        "workspace_size": workspace_size,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+    }
+
     if kwargs.get("ir", "dynamo") == "torch_compile":
-        custom_backend = create_backend(
-            precision=lower_precision,
-            debug=debug,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-            **kwargs,
+        model = torch.compile(
+            gm,
+            backend=torch_tensorrt_backend,
+            options={**compilation_options, **kwargs},
         )
-        model = torch.compile(gm, backend=custom_backend)
         # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
+        model(*torch_inputs)
         return model
 
     else:
-        settings = CompilationSettings(
-            debug=debug,
-            precision=lower_precision,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-        )
-
-        model = trace(gm, inputs, **kwargs)
+        settings = CompilationSettings(**compilation_options)
+        model = trace(gm, torch_inputs, **kwargs)
 
         if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
+            model = lower_model(model, torch_inputs)
+            return _compile_module(model, torch_inputs, settings)
         else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
+            split_result = lower_model_using_trt_splitter(model, torch_inputs)
+            trt_module = _compile_graph(split_result, torch_inputs, settings)
 
             return trt_module
 
 
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
-
-
 def _compile_graph(
     split_result: TRTSplitter,
     inputs: Any,
@@ -234,7 +177,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs): |
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)
 
     return lowered_model
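
For reference, a minimal sketch of how the rewired "torch_compile" path might be driven from user code. This is illustrative only: MyModel and the input shape are assumptions; ir="torch_compile" takes the torch.compile branch above, and enabled_precisions={torch.float16} maps to precision = torch.float16 in the precision selection logic.

import torch
import torch_tensorrt

model = MyModel().eval().cuda()  # hypothetical user module
example_inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# Keyword arguments such as debug and min_block_size are collected into
# the compilation_options dict assembled in compile() above and passed
# to torch.compile via its `options` argument.
trt_model = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=example_inputs,
    enabled_precisions={torch.float16},
    min_block_size=1,
    debug=True,
)

trt_model(*example_inputs)  # run inference; compile() already triggered the build

Note the design choice visible in the diff: instead of currying a backend with functools.partial (the removed create_backend), the options now flow through torch.compile's native `options` dict, so CompilationSettings and the torch.compile path share one source of truth.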