11// RUN: mlir-opt %s -test-vector-to-vector-lowering | FileCheck %s
22
3- // CHECK-LABEL: func @maskedload0(
4- // CHECK-SAME: %[[A0:.*]]: memref<?xf32>,
5- // CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
6- // CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
7- // CHECK-NEXT: %[[T:.*]] = vector.load %[[A0]][%[[C]]] : memref<?xf32>, vector<16xf32>
8- // CHECK-NEXT: return %[[T]] : vector<16xf32>
9- func.func @maskedload0 (%base: memref <?xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
3+ //-----------------------------------------------------------------------------
4+ // [Pattern: MaskedLoadFolder]
5+ //-----------------------------------------------------------------------------
6+
7+ // CHECK-LABEL: func @fold_maskedload_all_true_dynamic(
8+ // CHECK-SAME: %[[BASE:.*]]: memref<?xf32>,
9+ // CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
10+ // CHECK-DAG: %[[IDX:.*]] = arith.constant 0 : index
11+ // CHECK-NEXT: %[[LOAD:.*]] = vector.load %[[BASE]][%[[IDX]]] : memref<?xf32>, vector<16xf32>
12+ // CHECK-NEXT: return %[[LOAD]] : vector<16xf32>
13+ func.func @fold_maskedload_all_true_dynamic (%base: memref <?xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
1014 %c0 = arith.constant 0 : index
1115 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
1216 %ld = vector.maskedload %base [%c0 ], %mask , %pass_thru
1317 : memref <?xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
1418 return %ld : vector <16 xf32 >
1519}
1620
17- // CHECK-LABEL: func @maskedload1 (
18- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
19- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) -> vector<16xf32> {
20- // CHECK-DAG: %[[C :.*]] = arith.constant 0 : index
21- // CHECK-NEXT: %[[T :.*]] = vector.load %[[A0 ]][%[[C ]]] : memref<16xf32>, vector<16xf32>
22- // CHECK-NEXT: return %[[T ]] : vector<16xf32>
23- func.func @maskedload1 (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
21+ // CHECK-LABEL: func @fold_maskedload_all_true_static (
22+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
23+ // CHECK-SAME: %[[PASS_THRU :.*]]: vector<16xf32>) -> vector<16xf32> {
24+ // CHECK-DAG: %[[IDX :.*]] = arith.constant 0 : index
25+ // CHECK-NEXT: %[[LOAD :.*]] = vector.load %[[BASE ]][%[[IDX ]]] : memref<16xf32>, vector<16xf32>
26+ // CHECK-NEXT: return %[[LOAD ]] : vector<16xf32>
27+ func.func @fold_maskedload_all_true_static (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
2428 %c0 = arith.constant 0 : index
2529 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
2630 %ld = vector.maskedload %base [%c0 ], %mask , %pass_thru
2731 : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
2832 return %ld : vector <16 xf32 >
2933}
3034
31- // CHECK-LABEL: func @maskedload2 (
32- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
33- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) -> vector<16xf32> {
34- // CHECK-NEXT: return %[[A1 ]] : vector<16xf32>
35- func.func @maskedload2 (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
35+ // CHECK-LABEL: func @fold_maskedload_all_false_static (
36+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
37+ // CHECK-SAME: %[[PASS_THRU :.*]]: vector<16xf32>) -> vector<16xf32> {
38+ // CHECK-NEXT: return %[[PASS_THRU ]] : vector<16xf32>
39+ func.func @fold_maskedload_all_false_static (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
3640 %c0 = arith.constant 0 : index
3741 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
3842 %ld = vector.maskedload %base [%c0 ], %mask , %pass_thru
3943 : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
4044 return %ld : vector <16 xf32 >
4145}
4246
43- // CHECK-LABEL: func @maskedload3 (
44- // CHECK-SAME: %[[A0 :.*]]: memref<?xf32>,
45- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) -> vector<16xf32> {
46- // CHECK-DAG: %[[C :.*]] = arith.constant 8 : index
47- // CHECK-NEXT: %[[T :.*]] = vector.load %[[A0 ]][%[[C ]]] : memref<?xf32>, vector<16xf32>
48- // CHECK-NEXT: return %[[T ]] : vector<16xf32>
49- func.func @maskedload3 (%base: memref <?xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
47+ // CHECK-LABEL: func @fold_maskedload_dynamic_non_zero_idx (
48+ // CHECK-SAME: %[[BASE :.*]]: memref<?xf32>,
49+ // CHECK-SAME: %[[PASS_THRU :.*]]: vector<16xf32>) -> vector<16xf32> {
50+ // CHECK-DAG: %[[IDX :.*]] = arith.constant 8 : index
51+ // CHECK-NEXT: %[[LOAD :.*]] = vector.load %[[BASE ]][%[[IDX ]]] : memref<?xf32>, vector<16xf32>
52+ // CHECK-NEXT: return %[[LOAD ]] : vector<16xf32>
53+ func.func @fold_maskedload_dynamic_non_zero_idx (%base: memref <?xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
5054 %c8 = arith.constant 8 : index
5155 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
5256 %ld = vector.maskedload %base [%c8 ], %mask , %pass_thru
5357 : memref <?xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
5458 return %ld : vector <16 xf32 >
5559}
5660
57- // CHECK-LABEL: func @maskedstore1(
58- // CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
59- // CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
60- // CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
61- // CHECK-NEXT: vector.store %[[A1]], %[[A0]][%[[C]]] : memref<16xf32>, vector<16xf32>
61+ //-----------------------------------------------------------------------------
62+ // [Pattern: MaskedStoreFolder]
63+ //-----------------------------------------------------------------------------
64+
65+ // CHECK-LABEL: func @fold_maskedstore_all_true(
66+ // CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
67+ // CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
68+ // CHECK-NEXT: %[[IDX:.*]] = arith.constant 0 : index
69+ // CHECK-NEXT: vector.store %[[VALUE]], %[[BASE]][%[[IDX]]] : memref<16xf32>, vector<16xf32>
6270// CHECK-NEXT: return
63- func.func @maskedstore1 (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
71+ func.func @fold_maskedstore_all_true (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
6472 %c0 = arith.constant 0 : index
6573 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
6674 vector.maskedstore %base [%c0 ], %mask , %value : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 >
6775 return
6876}
6977
70- // CHECK-LABEL: func @maskedstore2 (
71- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
72- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) {
78+ // CHECK-LABEL: func @fold_maskedstore_all_false (
79+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
80+ // CHECK-SAME: %[[VALUE :.*]]: vector<16xf32>) {
7381// CHECK-NEXT: return
74- func.func @maskedstore2 (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
82+ func.func @fold_maskedstore_all_false (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
7583 %c0 = arith.constant 0 : index
7684 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
7785 vector.maskedstore %base [%c0 ], %mask , %value : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 >
7886 return
7987}
8088
81- // CHECK-LABEL: func @gather1(
82- // CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
83- // CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
84- // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) -> vector<16xf32> {
89+ //-----------------------------------------------------------------------------
90+ // [Pattern: GatherFolder]
91+ //-----------------------------------------------------------------------------
92+
93+ /// There is no alternative (i.e. simpler) Op for this, hence no-fold.
94+
95+ // CHECK-LABEL: func @no_fold_gather_all_true(
96+ // CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
97+ // CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
98+ // CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
8599// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
86100// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
87- // CHECK-NEXT: %[[G:.*]] = vector.gather %[[A0 ]][%[[C]]] [%[[A1 ]]], %[[M]], %[[A2 ]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
101+ // CHECK-NEXT: %[[G:.*]] = vector.gather %[[BASE ]][%[[C]]] [%[[INDICES ]]], %[[M]], %[[PASS_THRU ]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
88102// CHECK-NEXT: return %[[G]] : vector<16xf32>
89- func.func @gather1 (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
103+ func.func @no_fold_gather_all_true (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
90104 %c0 = arith.constant 0 : index
91105 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
92106 %ld = vector.gather %base [%c0 ][%indices ], %mask , %pass_thru
93107 : memref <16 xf32 >, vector <16 xi32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
94108 return %ld : vector <16 xf32 >
95109}
96110
97- // CHECK-LABEL: func @gather2 (
98- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
99- // CHECK-SAME: %[[A1 :.*]]: vector<16xi32>,
100- // CHECK-SAME: %[[A2 :.*]]: vector<16xf32>) -> vector<16xf32> {
101- // CHECK-NEXT: return %[[A2 ]] : vector<16xf32>
102- func.func @gather2 (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
111+ // CHECK-LABEL: func @fold_gather_all_true (
112+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
113+ // CHECK-SAME: %[[INDICES :.*]]: vector<16xi32>,
114+ // CHECK-SAME: %[[PASS_THRU :.*]]: vector<16xf32>) -> vector<16xf32> {
115+ // CHECK-NEXT: return %[[PASS_THRU ]] : vector<16xf32>
116+ func.func @fold_gather_all_true (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
103117 %c0 = arith.constant 0 : index
104118 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
105119 %ld = vector.gather %base [%c0 ][%indices ], %mask , %pass_thru
106120 : memref <16 xf32 >, vector <16 xi32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
107121 return %ld : vector <16 xf32 >
108122}
109123
110- // CHECK-LABEL: func @scatter1(
111- // CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
112- // CHECK-SAME: %[[A1:.*]]: vector<16xi32>,
113- // CHECK-SAME: %[[A2:.*]]: vector<16xf32>) {
124+ //-----------------------------------------------------------------------------
125+ // [Pattern: ScatterFolder]
126+ //-----------------------------------------------------------------------------
127+
128+ /// There is no alternative (i.e. simpler) Op for this, hence no-fold.
129+
130+ // CHECK-LABEL: func @no_fold_scatter_all_true(
131+ // CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
132+ // CHECK-SAME: %[[INDICES:.*]]: vector<16xi32>,
133+ // CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
114134// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
115135// CHECK-NEXT: %[[M:.*]] = arith.constant dense<true> : vector<16xi1>
116- // CHECK-NEXT: vector.scatter %[[A0 ]][%[[C]]] [%[[A1 ]]], %[[M]], %[[A2 ]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
136+ // CHECK-NEXT: vector.scatter %[[BASE ]][%[[C]]] [%[[INDICES ]]], %[[M]], %[[VALUE ]] : memref<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
117137// CHECK-NEXT: return
118- func.func @scatter1 (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %value: vector <16 xf32 >) {
138+ func.func @no_fold_scatter_all_true (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %value: vector <16 xf32 >) {
119139 %c0 = arith.constant 0 : index
120140 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
121141 vector.scatter %base [%c0 ][%indices ], %mask , %value
122142 : memref <16 xf32 >, vector <16 xi32 >, vector <16 xi1 >, vector <16 xf32 >
123143 return
124144}
125145
126- // CHECK-LABEL: func @scatter2 (
127- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
128- // CHECK-SAME: %[[A1 :.*]]: vector<16xi32>,
129- // CHECK-SAME: %[[A2 :.*]]: vector<16xf32>) {
146+ // CHECK-LABEL: func @fold_scatter_all_false (
147+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
148+ // CHECK-SAME: %[[INDICES :.*]]: vector<16xi32>,
149+ // CHECK-SAME: %[[VALUE :.*]]: vector<16xf32>) {
130150// CHECK-NEXT: return
131- func.func @scatter2 (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %value: vector <16 xf32 >) {
151+ func.func @fold_scatter_all_false (%base: memref <16 xf32 >, %indices: vector <16 xi32 >, %value: vector <16 xf32 >) {
132152 %c0 = arith.constant 0 : index
133153 %0 = vector.type_cast %base : memref <16 xf32 > to memref <vector <16 xf32 >>
134154 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
@@ -137,50 +157,58 @@ func.func @scatter2(%base: memref<16xf32>, %indices: vector<16xi32>, %value: vec
137157 return
138158}
139159
140- // CHECK-LABEL: func @expand1(
141- // CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
142- // CHECK-SAME: %[[A1:.*]]: vector<16xf32>) -> vector<16xf32> {
160+ //-----------------------------------------------------------------------------
161+ // [Pattern: ExpandLoadFolder]
162+ //-----------------------------------------------------------------------------
163+
164+ // CHECK-LABEL: func @fold_expandload_all_true(
165+ // CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
166+ // CHECK-SAME: %[[PASS_THRU:.*]]: vector<16xf32>) -> vector<16xf32> {
143167// CHECK-DAG: %[[C:.*]] = arith.constant 0 : index
144- // CHECK-NEXT: %[[T:.*]] = vector.load %[[A0 ]][%[[C]]] : memref<16xf32>, vector<16xf32>
168+ // CHECK-NEXT: %[[T:.*]] = vector.load %[[BASE ]][%[[C]]] : memref<16xf32>, vector<16xf32>
145169// CHECK-NEXT: return %[[T]] : vector<16xf32>
146- func.func @expand1 (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
170+ func.func @fold_expandload_all_true (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
147171 %c0 = arith.constant 0 : index
148172 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
149173 %ld = vector.expandload %base [%c0 ], %mask , %pass_thru
150174 : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
151175 return %ld : vector <16 xf32 >
152176}
153177
154- // CHECK-LABEL: func @expand2 (
155- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
156- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) -> vector<16xf32> {
157- // CHECK-NEXT: return %[[A1 ]] : vector<16xf32>
158- func.func @expand2 (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
178+ // CHECK-LABEL: func @fold_expandload_all_false (
179+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
180+ // CHECK-SAME: %[[PASS_THRU :.*]]: vector<16xf32>) -> vector<16xf32> {
181+ // CHECK-NEXT: return %[[PASS_THRU ]] : vector<16xf32>
182+ func.func @fold_expandload_all_false (%base: memref <16 xf32 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
159183 %c0 = arith.constant 0 : index
160184 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
161185 %ld = vector.expandload %base [%c0 ], %mask , %pass_thru
162186 : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 > into vector <16 xf32 >
163187 return %ld : vector <16 xf32 >
164188}
165189
166- // CHECK-LABEL: func @compress1(
167- // CHECK-SAME: %[[A0:.*]]: memref<16xf32>,
168- // CHECK-SAME: %[[A1:.*]]: vector<16xf32>) {
190+ //-----------------------------------------------------------------------------
191+ // [Pattern: CompressStoreFolder]
192+ //-----------------------------------------------------------------------------
193+
194+ // CHECK-LABEL: func @fold_compressstore_all_true(
195+ // CHECK-SAME: %[[BASE:.*]]: memref<16xf32>,
196+ // CHECK-SAME: %[[VALUE:.*]]: vector<16xf32>) {
169197// CHECK-NEXT: %[[C:.*]] = arith.constant 0 : index
170- // CHECK-NEXT: vector.store %[[A1 ]], %[[A0 ]][%[[C]]] : memref<16xf32>, vector<16xf32>
198+ // CHECK-NEXT: vector.store %[[VALUE ]], %[[BASE ]][%[[C]]] : memref<16xf32>, vector<16xf32>
171199// CHECK-NEXT: return
172- func.func @compress1 (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
200+ func.func @fold_compressstore_all_true (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
173201 %c0 = arith.constant 0 : index
174202 %mask = vector.constant_mask [16 ] : vector <16 xi1 >
175203 vector.compressstore %base [%c0 ], %mask , %value : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 >
176204 return
177205}
178206
179- // CHECK-LABEL: func @compress2 (
180- // CHECK-SAME: %[[A0 :.*]]: memref<16xf32>,
181- // CHECK-SAME: %[[A1 :.*]]: vector<16xf32>) {
207+ // CHECK-LABEL: func @fold_compressstore_all_false (
208+ // CHECK-SAME: %[[BASE :.*]]: memref<16xf32>,
209+ // CHECK-SAME: %[[VALUE :.*]]: vector<16xf32>) {
182210// CHECK-NEXT: return
183- func.func @compress2 (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
211+ func.func @fold_compressstore_all_false (%base: memref <16 xf32 >, %value: vector <16 xf32 >) {
184212 %c0 = arith.constant 0 : index
185213 %mask = vector.constant_mask [0 ] : vector <16 xi1 >
186214 vector.compressstore %base [%c0 ], %mask , %value : memref <16 xf32 >, vector <16 xi1 >, vector <16 xf32 >
0 commit comments