@@ -979,3 +979,27 @@ func.func @vector_scalable_extract(%sv: vector<[8]xi32>) {
979979 %2 = vector.scalable.extract %sv [4 ] : vector <4 xi32 > from vector <[8 ]xi32 >
980980 return
981981 }
982+
983+ #matmat_accesses = [
984+ affine_map <(i , j , k ) -> (i , k )>,
985+ affine_map <(i , j , k ) -> (k , j )>,
986+ affine_map <(i , j , k ) -> (i , j )>
987+ ]
988+ #matmat_trait = {
989+ indexing_maps = #matmat_accesses ,
990+ iterator_types = [" parallel" , " parallel" , " reduction" ]
991+ }
992+ // CHECK-LABEL: func.func @contraction_masked_scalable(
993+ // CHECK-SAME: %[[A:.*]]: vector<3x4xf32>,
994+ // CHECK-SAME: %[[B:.*]]: vector<4x[8]xf32>,
995+ // CHECK-SAME: %[[C:.*]]: vector<3x[8]xf32>,
996+ // CHECK-SAME: %[[M:.*]]: vector<3x[8]x4xi1>) -> vector<3x[8]xf32> {
997+ func.func @contraction_masked_scalable (%A: vector <3 x4 xf32 >,
998+ %B: vector <4 x[8 ]xf32 >,
999+ %C: vector <3 x[8 ]xf32 >,
1000+ %M : vector <3 x[8 ]x4 xi1 >) -> vector <3 x[8 ]xf32 > {
1001+ // CHECK: vector.mask %[[M]] { vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[A]], %[[B]], %[[C]] : vector<3x4xf32>, vector<4x[8]xf32> into vector<3x[8]xf32> } : vector<3x[8]x4xi1> -> vector<3x[8]xf32>
1002+ %0 = vector.mask %M { vector.contract #matmat_trait %A , %B , %C : vector <3 x4 xf32 >, vector <4 x[8 ]xf32 > into vector <3 x[8 ]xf32 > }
1003+ : vector <3 x[8 ]x4 xi1 > -> vector <3 x[8 ]xf32 >
1004+ return %0 : vector <3 x[8 ]xf32 >
1005+ }
0 commit comments