Skip to content

Commit ad5b467

Browse files
committed
empty_memory_format evaluator
1 parent 431e9a9 commit ad5b467

File tree

2 files changed

+176
-1
lines changed

2 files changed

+176
-1
lines changed

py/torch_tensorrt/dynamo/conversion/ops_evaluators.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
dynamo_tensorrt_converter,
1212
)
1313
from torch_tensorrt.fx.types import TRTTensor
14+
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
1415

1516
_LOGGER: logging.Logger = logging.getLogger(__name__)
1617

@@ -88,7 +89,6 @@ def aten_ops_randn(
8889
) -> Union[TRTTensor, Sequence[TRTTensor]]:
8990
return np.random.randn(*args[0])
9091

91-
9292
def randperm_validator(randperm_node: Node) -> bool:
9393
dtype = randperm_node.kwargs.get("dtype", None)
9494
layout = randperm_node.kwargs.get("layout", None)
@@ -118,3 +118,52 @@ def aten_ops_randperm(
118118
name: str,
119119
) -> Union[TRTTensor, Sequence[TRTTensor]]:
120120
return np.random.permutation(args[0])
121+
122+
def empty_validator(empty_node: Node) -> bool:
123+
layout = empty_node.kwargs.get("layout", None)
124+
pin_memory = empty_node.kwargs.get("pin_memory", None)
125+
memory_format = empty_node.kwargs.get("memory_format", None)
126+
if layout is not None:
127+
_LOGGER.debug(f"Currently we don't support specifying layout, got {layout}.")
128+
return False
129+
return True
130+
131+
132+
@dynamo_tensorrt_converter(
133+
torch.ops.aten.empty.memory_format, capability_validator=empty_validator
134+
)
135+
def aten_ops_empty(
136+
ctx: ConversionContext,
137+
target: Target,
138+
args: Tuple[Argument, ...],
139+
kwargs: Dict[str, Argument],
140+
name: str,
141+
) -> Union[TRTTensor, Sequence[TRTTensor]]:
142+
empty_np_tensor = None
143+
memory_format = kwargs.get("memory_format")
144+
if kwargs.get("dtype") is not None:
145+
empty_np_tensor = np.empty(
146+
tuple(args[0]),
147+
dtype=unified_dtype_converter(kwargs.get("dtype"), Frameworks.NUMPY),
148+
)
149+
else:
150+
# default returns np.float64. Verify the correctness of this
151+
empty_np_tensor = np.empty(tuple(args[0]))
152+
153+
empty_tensor = torch.Tensor(empty_np_tensor)
154+
# device
155+
if kwargs.get("device") is not None:
156+
empty_tensor = empty_tensor.to(device=kwargs.get("device"))
157+
158+
# memory_format. default is torch.contiguous_format
159+
if memory_format == torch.channels_last:
160+
# shape of args[0] must be 4
161+
empty_tensor = empty_tensor.to(memory_format=torch.channels_last)
162+
elif memory_format == torch.channels_last_3d:
163+
# shape of args[0] must be 5
164+
empty_tensor = empty_tensor.to(memory_format=torch.channels_last_3d)
165+
166+
return empty_tensor
167+
168+
169+
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
import numpy as np
2+
import torch
3+
import torch.nn as nn
4+
import torch_tensorrt
5+
from harness import DispatchTestCase
6+
from parameterized import parameterized
7+
from torch.testing._internal.common_utils import run_tests
8+
9+
empty_ops = [
10+
(
11+
"empty_one_dimension",
12+
[1],
13+
None,
14+
None,
15+
None,
16+
),
17+
(
18+
"empty_two_dimension",
19+
[1, 2],
20+
None,
21+
None,
22+
None,
23+
),
24+
(
25+
"empty_three_dimension",
26+
[2, 3, 4],
27+
None,
28+
None,
29+
None,
30+
),
31+
(
32+
"empty_one_dimension_dtype",
33+
[1],
34+
torch.float32,
35+
None,
36+
None,
37+
),
38+
(
39+
"empty_two_dimension_dtype",
40+
[2, 3],
41+
torch.float32,
42+
None,
43+
None,
44+
),
45+
(
46+
"empty_one_dimension_dtype_device",
47+
[1],
48+
torch.float32,
49+
"cuda",
50+
None,
51+
),
52+
(
53+
"empty_two_dimension_dtype_device",
54+
[2, 3],
55+
torch.float32,
56+
"cuda",
57+
None,
58+
),
59+
(
60+
"empty_four_dimension_memformat",
61+
[1, 2, 2, 1],
62+
torch.float32,
63+
"cuda",
64+
torch.channels_last,
65+
),
66+
(
67+
"empty_five_dimension_memformat",
68+
[1, 2, 2, 2, 1],
69+
torch.float32,
70+
"cuda",
71+
torch.channels_last_3d,
72+
),
73+
]
74+
75+
76+
class TestRandConverter(DispatchTestCase):
77+
@parameterized.expand(
78+
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3]) for empty_op in empty_ops]
79+
)
80+
def test_empty(self, name, shape_or_input, data_type, device):
81+
class TestModule(nn.Module):
82+
def __init__(self):
83+
super().__init__()
84+
85+
def forward(self, x):
86+
shape_or_input[0] = x.shape[0]
87+
return torch.empty(shape_or_input)
88+
89+
empty_model = TestModule()
90+
91+
inputs = [torch.randint(1, 3, shape_or_input, dtype=torch.int32)]
92+
comparator_shape_dtype_device = (
93+
lambda x, y, check_dtype, check_device: x.shape == y.shape
94+
and (x.stride() == y.stride())
95+
and (x.dtype == y.dtype if check_dtype else True)
96+
and (x.get_device() == y.get_device() if check_device else True)
97+
)
98+
expected_ops = []
99+
if "device" in name:
100+
self.run_test_compare_tensor_attributes_only(
101+
empty_model,
102+
inputs,
103+
expected_ops,
104+
[(comparator_shape_dtype_device, [True, True])],
105+
use_dynamo_tracer=True,
106+
)
107+
elif "dtype" in name:
108+
self.run_test_compare_tensor_attributes_only(
109+
empty_model,
110+
inputs,
111+
expected_ops,
112+
[(comparator_shape_dtype_device, [True, False])],
113+
use_dynamo_tracer=True,
114+
)
115+
else:
116+
self.run_test_compare_tensor_attributes_only(
117+
empty_model,
118+
inputs,
119+
expected_ops,
120+
[(comparator_shape_dtype_device, [False, False])],
121+
use_dynamo_tracer=True,
122+
)
123+
124+
125+
if __name__ == "__main__":
126+
run_tests()

0 commit comments

Comments
 (0)