@@ -882,6 +882,16 @@ func.func @gather_and_scatter2d(%base: memref<?x?xf32>, %v: vector<16xi32>, %mas
882
882
return
883
883
}
884
884
885
+ // CHECK-LABEL: @gather_multi_dims
886
+ func.func @gather_multi_dims (%base: memref <?xf32 >, %v: vector <2 x16 xi32 >, %mask: vector <2 x16 xi1 >, %pass_thru: vector <2 x16 xf32 >) -> vector <2 x16 xf32 > {
887
+ %c0 = arith.constant 0 : index
888
+ // CHECK: %[[X:.*]] = vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
889
+ %0 = vector.gather %base [%c0 ][%v ], %mask , %pass_thru : memref <?xf32 >, vector <2 x16 xi32 >, vector <2 x16 xi1 >, vector <2 x16 xf32 > into vector <2 x16 xf32 >
890
+ // CHECK: vector.scatter %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %[[X]] : memref<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32>
891
+ vector.scatter %base [%c0 ][%v ], %mask , %0 : memref <?xf32 >, vector <2 x16 xi32 >, vector <2 x16 xi1 >, vector <2 x16 xf32 >
892
+ return %0 : vector <2 x16 xf32 >
893
+ }
894
+
885
895
// CHECK-LABEL: @gather_on_tensor
886
896
func.func @gather_on_tensor (%base: tensor <?xf32 >, %v: vector <16 xi32 >, %mask: vector <16 xi1 >, %pass_thru: vector <16 xf32 >) -> vector <16 xf32 > {
887
897
%c0 = arith.constant 0 : index
@@ -890,14 +900,6 @@ func.func @gather_on_tensor(%base: tensor<?xf32>, %v: vector<16xi32>, %mask: vec
890
900
return %0 : vector <16 xf32 >
891
901
}
892
902
893
- // CHECK-LABEL: @gather_multi_dims
894
- func.func @gather_multi_dims (%base: tensor <?xf32 >, %v: vector <2 x16 xi32 >, %mask: vector <2 x16 xi1 >, %pass_thru: vector <2 x16 xf32 >) -> vector <2 x16 xf32 > {
895
- %c0 = arith.constant 0 : index
896
- // CHECK: vector.gather %{{.*}}[%{{.*}}] [%{{.*}}], %{{.*}}, %{{.*}} : tensor<?xf32>, vector<2x16xi32>, vector<2x16xi1>, vector<2x16xf32> into vector<2x16xf32>
897
- %0 = vector.gather %base [%c0 ][%v ], %mask , %pass_thru : tensor <?xf32 >, vector <2 x16 xi32 >, vector <2 x16 xi1 >, vector <2 x16 xf32 > into vector <2 x16 xf32 >
898
- return %0 : vector <2 x16 xf32 >
899
- }
900
-
901
903
// CHECK-LABEL: @expand_and_compress
902
904
func.func @expand_and_compress (%base: memref <?xf32 >, %mask: vector <16 xi1 >, %pass_thru: vector <16 xf32 >) {
903
905
%c0 = arith.constant 0 : index
0 commit comments