Skip to content

Commit a135bab

Browse files
committed
Review comments- adding cases for stride, correcting validator and changing call to torch.ops.aten.empty.memory_format
1 parent 973e4ae commit a135bab

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

tests/py/dynamo/conversion/test_empty_aten.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
import torch
33
import torch.nn as nn
44
import torch_tensorrt
5-
from .harness import DispatchTestCase
65
from parameterized import parameterized
76
from torch.testing._internal.common_utils import run_tests
87

8+
from .harness import DispatchTestCase
9+
910
empty_ops = [
1011
(
1112
"empty_one_dimension",
@@ -73,9 +74,12 @@
7374
]
7475

7576

76-
class TestRandConverter(DispatchTestCase):
77+
class TestEmptyConverter(DispatchTestCase):
7778
@parameterized.expand(
78-
[(empty_op[0], empty_op[1], empty_op[2], empty_op[3], empty_op[4]) for empty_op in empty_ops]
79+
[
80+
(empty_op[0], empty_op[1], empty_op[2], empty_op[3], empty_op[4])
81+
for empty_op in empty_ops
82+
]
7983
)
8084
def test_empty(self, name, shape_or_input, data_type, device, memory_format):
8185
class TestModule(nn.Module):
@@ -84,7 +88,12 @@ def __init__(self):
8488

8589
def forward(self, x):
8690
shape_or_input[0] = x.shape[0]
87-
return torch.ops.aten.empty.memory_format(shape_or_input, dtype = data_type, memory_format = layout, device = device)
91+
return torch.ops.aten.empty.memory_format(
92+
shape_or_input,
93+
dtype=data_type,
94+
memory_format=memory_format,
95+
device=device,
96+
)
8897

8998
empty_model = TestModule()
9099

0 commit comments

Comments
 (0)