Skip to content

Commit 7be80f4

Browse files
committed
fix: Reorganize folders in latest implementation
- Update test references and imports accordingly
1 parent 910e948 commit 7be80f4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

55 files changed

+213
-250
lines changed

py/torch_tensorrt/_Input.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ def __init__(self, *args, **kwargs):
6868
- Input(shape=(1,3,32,32), dtype=torch_tensorrt.dtype.int32, format=torch_tensorrt.TensorFormat.NCHW)
6969
- Input(min_shape=(1,3,32,32), opt_shape=[2,3,32,32], max_shape=(3,3,32,32)) #Implicitly dtype=torch_tensorrt.dtype.float32, format=torch_tensorrt.TensorFormat.NCHW
7070
"""
71+
# Compatibility code for switching over from InputTensorSpec
72+
if "shape" in kwargs and "shape_ranges" in kwargs:
73+
assert (
74+
len(kwargs["shape_ranges"]) == 1 and len(kwargs["shape_ranges"][0]) == 3
75+
)
76+
del kwargs["shape"]
77+
78+
kwargs["min_shape"] = kwargs["shape_ranges"][0][0]
79+
kwargs["opt_shape"] = kwargs["shape_ranges"][0][1]
80+
kwargs["max_shape"] = kwargs["shape_ranges"][0][2]
81+
7182
if len(args) == 1:
7283
if not Input._supported_input_size_type(args[0]):
7384
raise TypeError(

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from torch_tensorrt._util import sanitized_torch_version
33

44
if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
5-
from .converters import *
65
from ._settings import *
6+
from .conversion import *
77
from .aten_tracer import trace
88
from .converter_registry import (
99
DYNAMO_CONVERTERS,
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from .SourceIR import SourceIR
2+
from .aten_ops_converters import *
13
from .trt_interpreter import *
24
from .conversion import *
35
from .truncate_long_and_double import repair_long_or_double_inputs

py/torch_tensorrt/dynamo/converters/aten_ops_converters.py renamed to py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import logging
22
from typing import Dict, Sequence, Tuple, Union
33
import torch
4+
import tensorrt as trt
45
from torch_tensorrt.fx.converters import acc_ops_converters
56
from ..converter_registry import dynamo_tensorrt_converter
67
from torch.fx.node import Argument, Target
78

89
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
9-
from torch_tensorrt.dynamo.converters import SourceIR
10-
from torch_tensorrt.dynamo.converters import impl
10+
from torch_tensorrt.dynamo.conversion import SourceIR, impl
1111

1212
_LOGGER: logging.Logger = logging.getLogger(__name__)
1313

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
import torch
22

3+
from torch_tensorrt.fx.types import (
4+
TRTDataType,
5+
TRTNetwork,
6+
TRTTensor,
7+
)
8+
39

410
def dynamic_unsupported(node: torch.fx.Node) -> bool:
511
# Validate that none of the inputs to the node have Dynamic shapes
@@ -28,3 +34,63 @@ def dynamic_unsupported(node: torch.fx.Node) -> bool:
2834
return False
2935

3036
return True
37+
38+
39+
def cast_trt_tensor(
40+
network: TRTNetwork,
41+
input_val: TRTTensor,
42+
dtype: TRTDataType,
43+
name: str,
44+
) -> TRTTensor:
45+
"""
46+
Given a TRT Tensor, convert that Tensor to the specified dtype
47+
Adds an Identity layer to the network which performs the conversion
48+
Args:
49+
network (TRTNetwork): A TensorRT network
50+
input_val (TRTTensor): A TRT Tensor to cast to a new data type
51+
dtype (TRTDataType): The TRTDataType to cast the input Tensor to
52+
name (str): Name of the calling layer
53+
Returns:
54+
A TensorRT ITensor which has been casted to the specified dtype
55+
"""
56+
#
57+
if input_val.dtype != dtype:
58+
identity_layer = network.add_identity(input_val)
59+
identity_layer.set_output_type(0, dtype)
60+
identity_layer.name = (
61+
f"Cast ITensor {input_val.name} from {input_val.dtype} to {dtype} - {name}"
62+
)
63+
return identity_layer.get_output(0)
64+
else:
65+
return input_val
66+
67+
68+
def broadcastable(
69+
a: TRTTensor,
70+
b: TRTTensor,
71+
) -> bool:
72+
"Check if two tensors are broadcastable according to torch rules"
73+
a_shape = tuple(a.shape)
74+
b_shape = tuple(b.shape)
75+
# check from the trailing
76+
diff = len(a_shape) - len(b_shape)
77+
if diff == 0:
78+
return True
79+
if diff > 0:
80+
max = len(a_shape)
81+
min = len(b_shape)
82+
greater_tensor = a_shape
83+
lesser_tensor = b_shape
84+
elif diff < 0:
85+
max = len(b_shape)
86+
min = len(a_shape)
87+
greater_tensor = b_shape
88+
lesser_tensor = a_shape
89+
j = min - 1
90+
for i in range(max - 1, diff - 1, -1):
91+
if not (
92+
greater_tensor[i] != lesser_tensor[j]
93+
and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
94+
):
95+
return False
96+
return True

py/torch_tensorrt/dynamo/converters/impl/activation.py renamed to py/torch_tensorrt/dynamo/conversion/impl/activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
set_layer_name,
1313
get_trt_plugin,
1414
)
15-
from torch_tensorrt.dynamo.converters import SourceIR
15+
from torch_tensorrt.dynamo.conversion import SourceIR
1616

1717
from torch_tensorrt.fx.types import (
1818
TRTNetwork,

py/torch_tensorrt/dynamo/converters/impl/condition/ops.py renamed to py/torch_tensorrt/dynamo/conversion/impl/condition/ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
from torch.fx.node import Target
77

88
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
9-
from torch_tensorrt.dynamo.converters import SourceIR
10-
from torch_tensorrt.dynamo.converters.converter_utils import broadcastable
9+
from torch_tensorrt.dynamo.conversion import SourceIR
10+
from torch_tensorrt.dynamo.conversion.converter_utils import broadcastable
1111
from torch_tensorrt.fx.converters.converter_utils import (
1212
broadcast,
1313
get_trt_tensor,
1414
set_layer_name,
1515
)
16-
from torch_tensorrt.dynamo.converters.impl.slice import expand
16+
from torch_tensorrt.dynamo.conversion.impl.slice import expand
1717

1818

1919
def where(

0 commit comments

Comments
 (0)