Skip to content

Commit 0eae97f

Browse files
peri044gs-olive
andauthored
feat: Dynamo refactor (#2104)
Signed-off-by: Dheeraj Peri <[email protected]> Co-authored-by: gs-olive <[email protected]>
1 parent e884820 commit 0eae97f

File tree

79 files changed

+1225
-2595
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1225
-2595
lines changed

.circleci/config.yml

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ commands:
519519
command: |
520520
set -e
521521
mkdir -p /tmp/artifacts/test_results
522-
cd tests/py
522+
cd tests/py/ts/
523523
pytest --junitxml=/tmp/artifacts/test_results/api/api_test_results.xml api/
524524
pytest --junitxml=/tmp/artifacts/test_results/models/models_test_results.xml models/
525525
pytest --junitxml=/tmp/artifacts/test_results/integrations/integrations_test_results.xml integrations/
@@ -733,50 +733,47 @@ commands:
733733
# =================== FX tests end ======================== #
734734

735735
# =================== Dynamo tests start ======================== #
736-
test-dynamo-fx_ts:
737-
description: "Test the Dynamo fx_ts_compat path"
736+
737+
test-dynamo-torch_compile:
738+
description: "Test Dynamo torch_compile tests"
738739
steps:
739740
- run:
740-
name: Run Dynamo fx_ts_compat core tests
741+
name: Run Dynamo torch_compile tests
741742
command: |
742-
cd py/torch_tensorrt/dynamo/fx_ts_compat/test
743-
pushd core/
744-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/fx_ts_compat/test_results.xml
745-
popd
743+
cd tests/py/dynamo/backend/
744+
pytest --junitxml=/tmp/artifacts/test_results/dynamo/torch_compile/test_results.xml
746745
747746
- store_test_results:
748747
path: /tmp/artifacts
749748
- store_artifacts:
750749
path: /tmp/testlogs
751750

752-
test-dynamo-compile-core:
753-
description: "Test the Dynamo compile path"
751+
test-dynamo-models_torch_compile:
752+
description: "Test the Dynamo models via torch_compile path"
754753
steps:
755754
- run:
756-
name: Run Dynamo compile core tests
755+
name: Run Dynamo models via torch_compile path
757756
command: |
758-
cd py/torch_tensorrt/dynamo/backend
759-
pushd test/
760-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml
761-
popd
757+
cd tests/py/dynamo/models
758+
pip3 install timm
759+
pip3 install transformers
760+
pytest test_models.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir torch_compile
762761
763762
- store_test_results:
764763
path: /tmp/artifacts
765764
- store_artifacts:
766765
path: /tmp/testlogs
767766

768-
test-dynamo-compile:
769-
description: "Test the Dynamo compile path"
767+
test-dynamo-models_torch_export:
768+
description: "Test the Dynamo models via torch_export path"
770769
steps:
771770
- run:
772-
name: Run Dynamo compile E2E tests
771+
name: Run Dynamo models via torch_export path
773772
command: |
774-
cd py/torch_tensorrt/dynamo/
775-
pushd test/
773+
cd tests/py/dynamo/models
776774
pip3 install timm
777775
pip3 install transformers
778-
pytest --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo_compile
779-
popd
776+
pytest test_models_export.py --junitxml=/tmp/artifacts/test_results/dynamo/backend/test_results.xml --ir dynamo
780777
781778
- store_test_results:
782779
path: /tmp/artifacts
@@ -1039,9 +1036,9 @@ jobs:
10391036
command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
10401037
# We install torch after torch-trt because pip automatically enforces the version constraint otherwise
10411038
- dump-test-env
1042-
- test-dynamo-compile
1043-
- test-dynamo-compile-core
1044-
- test-dynamo-fx_ts
1039+
- test-dynamo-torch_compile
1040+
- test-dynamo-models_torch_compile
1041+
- test-dynamo-models_torch_export
10451042

10461043
package-x86_64-linux:
10471044
parameters:

py/torch_tensorrt/_Input.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,47 +302,58 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple:
302302
return result_domain
303303

304304
@classmethod
305-
def from_tensor(cls, t: torch.Tensor) -> "Input":
305+
def from_tensor(
306+
cls, t: torch.Tensor, disable_memory_format_check: bool = False
307+
) -> "Input":
306308
"""
307309
Produce a Input which contains the information of the given PyTorch tensor.
308310
309311
Args:
310312
tensor (torch.Tensor): A PyTorch tensor.
313+
disable_memory_format_check (bool): Whether to validate the memory formats of input tensors
311314
312315
Returns:
313316
A Input object.
314317
"""
315-
if not any(
316-
[
317-
t.is_contiguous(memory_format=torch.contiguous_format),
318-
t.is_contiguous(memory_format=torch.channels_last),
319-
]
318+
if not (
319+
t.is_contiguous(memory_format=torch.contiguous_format)
320+
or t.is_contiguous(memory_format=torch.channels_last)
321+
or disable_memory_format_check
320322
):
321323
raise ValueError(
322324
"Tensor does not have a supported memory format, supported formats are contiguous or channel_last"
323325
)
324326
frmt = (
325327
torch.contiguous_format
326-
if t.is_contiguous(memory_format=torch.contiguous_format)
328+
if (
329+
t.is_contiguous(memory_format=torch.contiguous_format)
330+
or disable_memory_format_check
331+
)
327332
else torch.channels_last
328333
)
329334
return cls(shape=t.shape, dtype=t.dtype, format=frmt)
330335

331336
@classmethod
332-
def from_tensors(cls, ts: torch.Tensor) -> List["Input"]:
337+
def from_tensors(
338+
cls, ts: torch.Tensor, disable_memory_format_check: bool = False
339+
) -> List["Input"]:
333340
"""
334341
Produce a list of Inputs which contain
335342
the information of all the given PyTorch tensors.
336343
337344
Args:
338345
tensors (Iterable[torch.Tensor]): A list of PyTorch tensors.
346+
disable_memory_format_check (bool): Whether to validate the memory formats of input tensors
339347
340348
Returns:
341349
A list of Inputs.
342350
"""
343351

344352
assert isinstance(ts, (list, tuple))
345-
return [cls.from_tensor(t) for t in ts]
353+
return [
354+
cls.from_tensor(t, disable_memory_format_check=disable_memory_format_check)
355+
for t in ts
356+
]
346357

347358
def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor:
348359
"""

py/torch_tensorrt/_compile.py

Lines changed: 53 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class _IRType(Enum):
1515

1616
ts = 0
1717
fx = 1
18-
fx_ts_compat = 2
19-
dynamo_compile = 3
18+
dynamo = 2
19+
torch_compile = 3
2020

2121

2222
class _ModuleType(Enum):
@@ -47,17 +47,17 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
4747

4848
ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
4949
ir_targets_fx = ir == "fx"
50-
ir_targets_dynamo_compile = ir == "dynamo_compile"
51-
ir_targets_fx_ts_compat = ir == "fx_ts_compat"
50+
ir_targets_dynamo = ir == "dynamo"
51+
ir_targets_torch_compile = ir == "torch_compile"
5252

5353
if module_is_tsable and ir_targets_torchscript:
5454
return _IRType.ts
5555
elif module_is_fxable and ir_targets_fx:
5656
return _IRType.fx
57-
elif module_is_fxable and ir_targets_fx_ts_compat:
58-
return _IRType.fx_ts_compat
59-
elif module_is_fxable and ir_targets_dynamo_compile:
60-
return _IRType.dynamo_compile
57+
elif module_is_fxable and ir_targets_dynamo:
58+
return _IRType.dynamo
59+
elif module_is_fxable and ir_targets_torch_compile:
60+
return _IRType.torch_compile
6161
else:
6262
if ir == "default":
6363
# Options are listed in order of preference
@@ -67,13 +67,13 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
6767
)
6868
return _IRType.ts
6969
elif module_is_fxable:
70-
raise ValueError(
71-
"Was given a torch.fx.GraphModule, fx is not currently supported by Torch-TensorRT"
70+
logging.log(
71+
logging.Level.Warning,
72+
"Input graph is a torch.fx.GraphModule but the ir provided is default (ts). Please set ir=dynamo to suppress the warning.",
7273
)
73-
# logging.log(logging.Level.Info, "ir was set to default, using TorchScript as fx")
74-
# return _IRType.fx
74+
return _IRType.dynamo
7575
else:
76-
raise ValueError("Module was provided with in an unsupported format")
76+
raise ValueError("Module was provided in an unsupported format")
7777
else:
7878
raise ValueError("Unknown ir was requested")
7979

@@ -156,18 +156,41 @@ def compile(
156156
dynamic_batch=False,
157157
**kwargs,
158158
)
159-
elif target_ir == _IRType.dynamo_compile:
159+
elif target_ir == _IRType.dynamo:
160+
from torch_tensorrt import Device
161+
from torch_tensorrt.dynamo.utils import prepare_inputs, prepare_device
162+
import collections.abc
163+
164+
if not isinstance(inputs, collections.abc.Sequence):
165+
inputs = [inputs]
166+
device = kwargs.get("device", Device._current_device())
167+
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
168+
module = torch_tensorrt.dynamo.trace(module, torch_inputs, **kwargs)
160169
return torch_tensorrt.dynamo.compile(
161-
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
162-
)
163-
elif target_ir == _IRType.fx_ts_compat:
164-
return torch_tensorrt.dynamo.fx_ts_compat.compile(
165-
module, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
170+
module,
171+
inputs=inputs,
172+
enabled_precisions=enabled_precisions,
173+
**kwargs,
166174
)
175+
elif target_ir == _IRType.torch_compile:
176+
return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
167177
else:
168178
raise RuntimeError("Module is an unknown format or the ir requested is unknown")
169179

170180

181+
def torch_compile(module, **kwargs):
182+
"""
183+
Returns a boxed model which is the output of torch.compile.
184+
This does not compile the model to TRT. Execute this model on
185+
sample inputs to compile the model to TRT.
186+
"""
187+
from torch_tensorrt.dynamo.backend import torch_tensorrt_backend
188+
189+
boxed_fn = torch.compile(module, backend=torch_tensorrt_backend, options={**kwargs})
190+
191+
return boxed_fn
192+
193+
171194
def convert_method_to_trt_engine(
172195
module: Any,
173196
method_name: str,
@@ -224,6 +247,16 @@ def convert_method_to_trt_engine(
224247
**kwargs,
225248
)
226249
elif target_ir == _IRType.fx:
227-
raise RuntimeError("fx is currently not supported")
250+
raise RuntimeError(
251+
"convert_method_to_trt_engine call is not supported for ir=fx"
252+
)
253+
elif target_ir == _IRType.dynamo:
254+
raise RuntimeError(
255+
"convert_method_to_trt_engine call is not supported for ir=dynamo."
256+
)
257+
elif target_ir == _IRType.torch_compile:
258+
raise RuntimeError(
259+
"convert_method_to_trt_engine call is not supported for ir=torch_compile"
260+
)
228261
else:
229262
raise RuntimeError("Module is an unknown format or the ir requested is unknown")

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@
22
from torch_tensorrt._util import sanitized_torch_version
33

44
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
5-
from torch_tensorrt.dynamo import fx_ts_compat
6-
from .backend import compile
5+
from ._settings import *
6+
from .compile import compile
7+
from .aten_tracer import trace

py/torch_tensorrt/dynamo/backend/_defaults.py renamed to py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from torch_tensorrt.fx.utils import LowerPrecision
1+
import torch
22

3-
4-
PRECISION = LowerPrecision.FP32
3+
PRECISION = torch.float32
54
DEBUG = False
65
WORKSPACE_SIZE = 0
76
MIN_BLOCK_SIZE = 5

py/torch_tensorrt/dynamo/backend/_settings.py renamed to py/torch_tensorrt/dynamo/_settings.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
from dataclasses import dataclass, field
22
from typing import Optional, Sequence
3-
4-
from torch_tensorrt.fx.utils import LowerPrecision
5-
from torch_tensorrt.dynamo.backend._defaults import (
3+
import torch
4+
from torch_tensorrt.dynamo._defaults import (
65
PRECISION,
76
DEBUG,
87
WORKSPACE_SIZE,
@@ -17,7 +16,7 @@
1716

1817
@dataclass
1918
class CompilationSettings:
20-
precision: LowerPrecision = PRECISION
19+
precision: torch.dtype = PRECISION
2120
debug: bool = DEBUG
2221
workspace_size: int = WORKSPACE_SIZE
2322
min_block_size: int = MIN_BLOCK_SIZE

0 commit comments

Comments
 (0)