Skip to content

Commit e5afe0f

Browse files
authored
Dynamic support for split (#2871)
1 parent 60c26a2 commit e5afe0f

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -629,14 +629,19 @@ def aten_ops_softmax(
629629

630630

631631
@dynamo_tensorrt_converter(
632-
torch.ops.aten.split.Tensor, capability_validator=has_static_shapes_in_args([1])
632+
torch.ops.aten.split.Tensor,
633+
capability_validator=has_static_shapes_in_args([1]),
634+
supports_dynamic_shapes=True,
633635
)
634636
@dynamo_tensorrt_converter(
635-
torch.ops.aten.split.sizes, capability_validator=has_static_shapes_in_args([1])
637+
torch.ops.aten.split.sizes,
638+
capability_validator=has_static_shapes_in_args([1]),
639+
supports_dynamic_shapes=True,
636640
)
637641
@dynamo_tensorrt_converter(
638642
torch.ops.aten.split_with_sizes.default,
639643
capability_validator=has_static_shapes_in_args([1]),
644+
supports_dynamic_shapes=True,
640645
)
641646
def aten_ops_split(
642647
ctx: ConversionContext,

tests/py/dynamo/conversion/test_convolution_aten.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import torch
22
from parameterized import param, parameterized
33
from torch.testing._internal.common_utils import run_tests
4-
54
from torch_tensorrt import Input
65

76
from .harness import DispatchTestCase

tests/py/dynamo/conversion/test_split_aten.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def forward(self, input):
119119
@parameterized.expand(
120120
[
121121
("select_split_size_or_sections_dim_dynamic_shape", 2, 1),
122+
("select_split_size_or_sections_non_divisible_dim_dynamic_shape", 3, 1),
122123
]
123124
)
124125
def test_split_dynamic(self, _, split_size_or_tensor, dim):
@@ -132,9 +133,37 @@ def forward(self, input):
132133

133134
input_specs = [
134135
Input(
135-
shape=(1, 10, -1),
136136
dtype=torch.float32,
137-
shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
137+
min_shape=[1, 10, 1],
138+
opt_shape=[1, 10, 10],
139+
max_shape=[1, 10, 10],
140+
),
141+
]
142+
self.run_test_with_dynamic_shape(
143+
TestModule(),
144+
input_specs,
145+
)
146+
147+
@parameterized.expand(
148+
[
149+
("select_split_size_or_sections_dim_dynamic_shape_on_first_axis", 2, 1),
150+
]
151+
)
152+
def test_split_dynamic_first_axis_dynamic(self, _, split_size_or_tensor, dim):
153+
class TestModule(torch.nn.Module):
154+
def __init__(self):
155+
super().__init__()
156+
157+
def forward(self, input):
158+
out = torch.ops.aten.split.Tensor(input, split_size_or_tensor, dim)
159+
return out
160+
161+
input_specs = [
162+
Input(
163+
dtype=torch.float32,
164+
min_shape=[1, 10, 10],
165+
opt_shape=[3, 10, 10],
166+
max_shape=[5, 10, 10],
138167
),
139168
]
140169
self.run_test_with_dynamic_shape(

0 commit comments

Comments
 (0)