Skip to content

Commit 0387757

Browse files
authored
empty_memory_format evaluator (#2745)
1 parent 0c38016 commit 0387757

File tree

2 files changed

+131
-0
lines changed

2 files changed

+131
-0
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from torch_tensorrt.dynamo.conversion.impl.elementwise import sub, trunc_div
2020
from torch_tensorrt.fx.types import TRTTensor
21+
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2122

2223
_LOGGER: logging.Logger = logging.getLogger(__name__)
2324

@@ -165,3 +166,43 @@ def aten_ops_randperm(
165166
name: str,
166167
) -> Union[TRTTensor, Sequence[TRTTensor]]:
167168
return np.random.permutation(args[0])
169+
170+
171+
def empty_validator(empty_node: Node) -> bool:
    """Capability validator for ``aten.empty.memory_format``.

    Returns False (rejecting conversion) when the node specifies any of the
    kwargs this evaluator does not support yet: device, layout, memory_format.
    """
    for unsupported_kwarg in ("device", "layout", "memory_format"):
        value = empty_node.kwargs.get(unsupported_kwarg, None)
        if value is not None:
            _LOGGER.debug(
                f"Currently we don't support specifying {unsupported_kwarg}, got {value}."
            )
            return False
    return True
187+
188+
189+
@dynamo_tensorrt_converter(
    torch.ops.aten.empty.memory_format, capability_validator=empty_validator
)
def aten_ops_empty(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    """Evaluate ``aten.empty.memory_format`` as an uninitialized numpy array.

    ``args[0]`` is the requested shape; the optional ``dtype`` kwarg selects
    the element type. device/layout/memory_format kwargs are rejected upstream
    by ``empty_validator``, so they are not handled here.
    """
    dtype = kwargs.get("dtype")
    if dtype is not None:
        np_dtype = unified_dtype_converter(dtype, Frameworks.NUMPY)
    else:
        # torch.empty with dtype=None uses the global default dtype, which is
        # torch.float32, whereas np.empty would silently default to float64.
        # Use float32 so the evaluated tensor's dtype matches eager PyTorch.
        np_dtype = np.float32
    return np.empty(tuple(args[0]), dtype=np_dtype)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch_tensorrt
5+
from parameterized import parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
8+
from .harness import DispatchTestCase
9+
10+
# (test name, requested shape, optional dtype) cases for the empty converter.
empty_ops = [
    ("empty_one_dimension", [1], None),
    ("empty_two_dimension", [1, 2], None),
    ("empty_three_dimension", [2, 3, 4], None),
    ("empty_one_dimension_dtype", [1], torch.float32),
    ("empty_two_dimension_dtype", [2, 3], torch.float32),
    ("empty_four_dimension_dtype", [1, 2, 2, 1], torch.float32),
    ("empty_five_dimension_dtype", [1, 2, 2, 2, 1], torch.float32),
]
47+
48+
49+
class TestEmptyConverter(DispatchTestCase):
    """Exercises the aten.empty.memory_format evaluator across shapes and dtypes."""

    # empty_ops already holds 3-tuples, so it can be passed to expand directly.
    @parameterized.expand(empty_ops)
    def test_empty(self, name, shape_or_input, data_type):
        class EmptyModule(nn.Module):
            def forward(self, x):
                # Tie the leading dimension to the runtime input so the traced
                # graph depends on the actual input shape.
                shape_or_input[0] = x.shape[0]
                return torch.ops.aten.empty.memory_format(
                    shape_or_input,
                    dtype=data_type,
                )

        empty_model = EmptyModule()

        inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]

        def compare_shape_stride_dtype(x, y, check_dtype):
            # empty() contents are uninitialized, so only tensor metadata
            # (shape, stride, and optionally dtype) is meaningful to compare.
            if x.shape != y.shape:
                return False
            if x.stride() != y.stride():
                return False
            return x.dtype == y.dtype if check_dtype else True

        # Cases whose name mentions dtype pin an explicit dtype, so the
        # comparator should also verify dtype equality.
        check_dtype = "dtype" in name
        self.run_test_compare_tensor_attributes_only(
            empty_model,
            inputs,
            [],
            [(compare_shape_stride_dtype, [check_dtype])],
            use_dynamo_tracer=True,
        )
87+
88+
89+
if __name__ == "__main__":
    # Allow running this test module directly via PyTorch's internal test runner.
    run_tests()

0 commit comments

Comments
 (0)