12
12
"empty_one_dimension" ,
13
13
[1 ],
14
14
None ,
15
- None ,
16
- None ,
17
15
),
18
16
(
19
17
"empty_two_dimension" ,
20
18
[1 , 2 ],
21
19
None ,
22
- None ,
23
- None ,
24
20
),
25
21
(
26
22
"empty_three_dimension" ,
27
23
[2 , 3 , 4 ],
28
24
None ,
29
- None ,
30
- None ,
31
25
),
32
26
(
33
27
"empty_one_dimension_dtype" ,
34
28
[1 ],
35
29
torch .float32 ,
36
- None ,
37
- None ,
38
30
),
39
31
(
40
32
"empty_two_dimension_dtype" ,
41
33
[2 , 3 ],
42
34
torch .float32 ,
43
- None ,
44
- None ,
45
- ),
46
- (
47
- "empty_one_dimension_dtype_device" ,
48
- [1 ],
49
- torch .float32 ,
50
- "cuda" ,
51
- None ,
52
- ),
53
- (
54
- "empty_two_dimension_dtype_device" ,
55
- [2 , 3 ],
56
- torch .float32 ,
57
- "cuda" ,
58
- None ,
59
35
),
60
36
(
61
- "empty_four_dimension_memformat " ,
37
+ "empty_four_dimension_dtype " ,
62
38
[1 , 2 , 2 , 1 ],
63
39
torch .float32 ,
64
- "cuda" ,
65
- torch .channels_last ,
66
40
),
67
41
(
68
- "empty_five_dimension_memformat " ,
42
+ "empty_five_dimension_dtype " ,
69
43
[1 , 2 , 2 , 2 , 1 ],
70
44
torch .float32 ,
71
- "cuda" ,
72
- torch .channels_last_3d ,
73
45
),
74
46
]
75
47
76
48
77
49
class TestEmptyConverter (DispatchTestCase ):
78
50
@parameterized .expand (
79
- [
80
- (empty_op [0 ], empty_op [1 ], empty_op [2 ], empty_op [3 ], empty_op [4 ])
81
- for empty_op in empty_ops
82
- ]
51
+ [(empty_op [0 ], empty_op [1 ], empty_op [2 ]) for empty_op in empty_ops ]
83
52
)
84
- def test_empty (self , name , shape_or_input , data_type , device , memory_format ):
53
+ def test_empty (self , name , shape_or_input , data_type ):
85
54
class TestModule (nn .Module ):
86
55
def __init__ (self ):
87
56
super ().__init__ ()
@@ -91,42 +60,31 @@ def forward(self, x):
91
60
return torch .ops .aten .empty .memory_format (
92
61
shape_or_input ,
93
62
dtype = data_type ,
94
- memory_format = memory_format ,
95
- device = device ,
96
63
)
97
64
98
65
empty_model = TestModule ()
99
66
100
67
inputs = [torch .randint (1 , 3 , shape_or_input , dtype = torch .int32 )]
101
68
comparator_shape_dtype_device = (
102
- lambda x , y , check_dtype , check_device : x .shape == y .shape
69
+ lambda x , y , check_dtype : x .shape == y .shape
103
70
and (x .stride () == y .stride ())
104
71
and (x .dtype == y .dtype if check_dtype else True )
105
- and (x .get_device () == y .get_device () if check_device else True )
106
72
)
107
73
expected_ops = []
108
- if "device" in name :
109
- self .run_test_compare_tensor_attributes_only (
110
- empty_model ,
111
- inputs ,
112
- expected_ops ,
113
- [(comparator_shape_dtype_device , [True , True ])],
114
- use_dynamo_tracer = True ,
115
- )
116
- elif "dtype" in name :
74
+ if "dtype" in name :
117
75
self .run_test_compare_tensor_attributes_only (
118
76
empty_model ,
119
77
inputs ,
120
78
expected_ops ,
121
- [(comparator_shape_dtype_device , [True , False ])],
79
+ [(comparator_shape_dtype_device , [True ])],
122
80
use_dynamo_tracer = True ,
123
81
)
124
82
else :
125
83
self .run_test_compare_tensor_attributes_only (
126
84
empty_model ,
127
85
inputs ,
128
86
expected_ops ,
129
- [(comparator_shape_dtype_device , [False , False ])],
87
+ [(comparator_shape_dtype_device , [False ])],
130
88
use_dynamo_tracer = True ,
131
89
)
132
90
0 commit comments