
Commit 9ac3898

fix: Replay all FX changes in Dynamo
- Add multiple fixes so the FX changes appear in the Dynamo directory, using the Dynamo registry
- All converters with open PRs are linked and shown
- Update references, imports, code, merges, and rebases accordingly
- Add new test cases to Dynamo for the converters
1 parent d209b19 commit 9ac3898


57 files changed, +4421 -0 lines changed

.circleci/config.yml

Lines changed: 16 additions & 0 deletions
@@ -780,6 +780,21 @@ commands:
      - store_artifacts:
          path: /tmp/testlogs

+  test-dynamo-converters:
+    description: "Test the Dynamo aten converters"
+    steps:
+      - run:
+          name: Run Dynamo converter tests
+          command: |
+            cd tests/py/dynamo/converters
+            TESTS_TO_RUN=$(circleci tests glob "test_*.py" | circleci tests split --split-by=timings)
+            pytest --junitxml=/tmp/artifacts/test_results/dynamo/converters/test_results.xml $TESTS_TO_RUN
+
+      - store_test_results:
+          path: /tmp/artifacts
+      - store_artifacts:
+          path: /tmp/testlogs
+
  # =================== Dynamo tests end ======================== #

  # Define a job to be invoked later in a workflow.
@@ -1036,6 +1051,7 @@ jobs:
          command: pip3 install --pre /tmp/dist/x86_64-linux/*cp39-cp39*.whl
      # We install torch after torch-trt because pip automatically enforces the version constraint otherwise
      - dump-test-env
+     - test-dynamo-converters
      - test-dynamo-torch_compile
      - test-dynamo-models_torch_compile
      - test-dynamo-models_torch_export
py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
from torch_tensorrt._util import sanitized_torch_version

if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
+    from .converters import *
    from ._settings import *
    from .aten_tracer import trace
    from .converter_registry import (
py/torch_tensorrt/dynamo/converters/SourceIR.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+from enum import Enum, auto
+
+
+class SourceIR(Enum):
+    NN = auto()
+    ACC = auto()
+    ATEN = auto()
+    PRIM = auto()
+    TORCHTRT_LOWERED = auto()
+    UNKNOWN = auto()
+
+    def __str__(self):
+        if self == SourceIR.NN:
+            return "nn"
+        elif self == SourceIR.ACC:
+            return "acc"
+        elif self == SourceIR.ATEN:
+            return "aten"
+        elif self == SourceIR.PRIM:
+            return "prim"
+        elif self == SourceIR.TORCHTRT_LOWERED:
+            return "torchtrt_lowered"
+        else:
+            return "unknown_ir"
py/torch_tensorrt/dynamo/converters/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .SourceIR import SourceIR
+from .aten_ops_converters import *
py/torch_tensorrt/dynamo/converters/aten_ops_converters.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+import torch
+from torch_tensorrt.fx.converters import acc_ops_converters
+from ..converter_registry import dynamo_tensorrt_converter
+from torch.fx.node import Argument, Target
+
+from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
+from torch_tensorrt.dynamo.converters import SourceIR
+from torch_tensorrt.dynamo.converters import impl
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def or_none(args, i):
+    return args[i] if len(args) > i else None
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.batch_norm)
+def aten_ops_batch_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.batch_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+        args[5],
+        args[6],
+        args[7],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.div.default)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode)
+@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor)
+def aten_ops_div(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "other": args[1],
+    }
+    rounding_mode = kwargs.get("rounding_mode")
+    if rounding_mode is None:
+        return acc_ops_converters.acc_ops_div(network, target, None, kwargs_new, name)
+    elif rounding_mode == "floor":
+        return acc_ops_converters.acc_ops_floor_div(
+            network, target, None, kwargs_new, name
+        )
+    elif rounding_mode == "trunc":
+        return impl.elementwise.trunc_div(
+            network, target, SourceIR.ATEN, name, args[0], args[1]
+        )
+    else:
+        raise RuntimeError(
+            f"Target {target} does not support rounding mode {rounding_mode}"
+        )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Scalar)
+@dynamo_tensorrt_converter(torch.ops.aten.fmod.Tensor)
+def aten_ops_fmod(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.fmod(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.gelu.default)
+def aten_ops_gelu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.activation.gelu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.matmul)
+@dynamo_tensorrt_converter(torch.ops.aten.mm.default)
+def aten_ops_matmul(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.matmul.matrix_multiply(
+        network, target, SourceIR.ATEN, name, args[0], args[1]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.layer_norm.default)
+def aten_ops_layernorm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.layer_norm(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
+def aten_ops_relu(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return impl.activation.relu(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.rsqrt.default)
+def aten_ops_rsqrt(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+
+    return impl.elementwise.rsqrt(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dim)
+@dynamo_tensorrt_converter(torch.ops.aten.squeeze.dims)
+def aten_ops_squeeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.squeeze.squeeze(network, target, SourceIR.ATEN, name, args[0], args[1])
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.unsqueeze.default)
+def aten_ops_unsqueeze(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unsqueeze.unsqueeze(
+        network, target, SourceIR.ATEN, name, input_t=args[0], dim=args[1]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.rsub.Tensor)
+def aten_ops_rsub(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    alpha = None
+    if "alpha" in kwargs:
+        alpha = kwargs["alpha"]
+    return impl.elementwise.rsub(
+        network, target, SourceIR.ATEN, name, args[0], args[1], alpha
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten._softmax.default)
+def aten_ops_softmax(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.normalization.softmax(
+        network, target, SourceIR.ATEN, name, args[0], args[1]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.where.self)
+def aten_ops_where(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.condition.where(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[1],
+        args[2],
+        args[0],
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.clamp.default)
+def aten_ops_clamp(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.elementwise.clamp(
+        network,
+        target,
+        SourceIR.ACC,
+        name,
+        input_val=args[0],
+        min_val=or_none(args, 1),
+        max_val=or_none(args, 2),
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.select.int)
+def aten_ops_select(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.select.select(
+        network, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+    )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.slice.Tensor)
+def aten_ops_slice(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.slice.slice_op(
+        network,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+        args[1],
+        args[2],
+        args[3],
+        args[4],
+    )
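
The commit message also mentions new test cases under tests/py/dynamo/converters, which are hidden in this view. The sketch below shows the general shape such a test might take, assuming a DispatchTestCase harness analogous to the existing FX converter tests; the harness import, class name, and run_test signature are assumptions, not taken from this diff.

# Hypothetical converter test, modeled on the FX-style DispatchTestCase harness.
# The harness module, class name, and run_test signature are assumptions.
import torch
import torch.nn as nn
from torch.testing._internal.common_utils import run_tests

from harness import DispatchTestCase  # assumed shared test harness for this suite


class TestReLUConverter(DispatchTestCase):
    def test_relu(self):
        class ReLU(nn.Module):
            def forward(self, x):
                # lowers to torch.ops.aten.relu.default, handled by aten_ops_relu above
                return torch.relu(x)

        inputs = [torch.randn(1, 3, 8, 8)]
        # run_test is assumed to trace the module, convert it with the registered
        # Dynamo converters, and compare TensorRT output against eager PyTorch.
        self.run_test(ReLU(), inputs, expected_ops={torch.ops.aten.relu.default})


if __name__ == "__main__":
    run_tests()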
