@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector is a
74- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
75- // of the vector have no effect on the memref area read even if they
76- // span a non-contiguous part of the memref.
77-
78- func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
79- %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
80-
81- %c0 = arith.constant 0 : index
82- %cst = arith.constant 0 : i8
83- %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
84- memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
85- return %res : vector <1 x1 x2 x2 xi8 >
86- }
87-
88- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
89- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
90- // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
91- // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
92- // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
93- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
94- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
95- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
96- // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
97- // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
98-
99- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
100- // CHECK-128B: memref.collapse_shape
101-
102- // -----
103-
10473// The shape of the memref and the vector don't match, but the vector is a
10574// contiguous subset of the memref, so "flattenable".
10675
107- func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
76+ func.func @transfer_read_dims_mismatch_contiguous (
10877 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x3 x2 xi8 > {
10978
11079 %c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
11483 return %res : vector <2 x3 x2 xi8 >
11584}
11685
117- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
86+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous (
11887// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
11988// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
12089// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
12695// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
12796// CHECK: return %[[VEC]] : vector<2x3x2xi8>
12897
129- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
98+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
99+ // CHECK-128B: memref.collapse_shape
100+
101+ // -----
102+
103+ // The shape of the memref and the vector don't match, but the mismatch is only
104+ // at the leading unit dimensions of the vector.
105+
106+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
107+ %mem : memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x4 x3 x2 xi8 > {
108+
109+ %c0 = arith.constant 0 : index
110+ %cst = arith.constant 0 : i8
111+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 , %c0 ], %cst :
112+ memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x4 x3 x2 xi8 >
113+ return %res : vector <1 x1 x4 x3 x2 xi8 >
114+ }
115+
116+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
117+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
118+ // CHECK-SAME: -> vector<1x1x4x3x2xi8>
119+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
120+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
121+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
122+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
123+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
124+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
125+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
126+ // CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
127+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
128+ // CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
129+
130+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
131+ // CHECK-128B: memref.collapse_shape
132+
133+ // -----
134+
135+ // The memref is non-contiguous, but the vector is a contiguous subset of the
136+ // memref, so "flattenable". The leading unit dimensions of the vector have no
137+ // effect on the memref area read even if they span the non-contiguous part of
138+ // the memref.
139+
140+ func.func @transfer_read_non_contiguous_unit_dims (
141+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x3 x2 xi8 > {
142+
143+ %c0 = arith.constant 0 : index
144+ %cst = arith.constant 0 : i8
145+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
146+ memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x3 x2 xi8 >
147+ return %res : vector <1 x1 x3 x2 xi8 >
148+ }
149+
150+ // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
151+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
152+ // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
153+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
154+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
155+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
156+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
157+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
158+ // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
159+ // CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
160+
161+ // CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
130162// CHECK-128B: memref.collapse_shape
131163
164+
132165// -----
133166
134167func.func @transfer_read_dims_mismatch_non_zero_indices (
@@ -418,61 +451,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
418451// -----
419452
420453// The shape of the memref and the vector don't match, but the vector is a
421- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
422- // of the vector have no effect on the memref area written even if they
423- // span a non-contiguous part of the memref.
454+ // contiguous subset of the memref, so "flattenable".
424455
425- func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
426- %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
427- %vec : vector <1 x 1 x 2 x 2 x i8 >) {
456+ func.func @transfer_write_dims_mismatch_contiguous (
457+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
458+ %vec : vector <2 x 2 x i8 >) {
428459
429460 %c0 = arith.constant 0 : index
430461 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
431- vector <1 x 1 x 2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
462+ vector <2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
432463 return
433464}
434465
435- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
436- // CHECK-SAME: %[[MEM:.* ]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?> >,
437- // CHECK-SAME: %[[VEC:.* ]]: vector<1x1x2x2xi8>) {
438- // CHECK: %[[C0:.* ]] = arith.constant 0 : index
439- // CHECK: %[[COLLAPSED:.* ]] = memref.collapse_shape %[[MEM]]
466+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
467+ // CHECK-SAME: %[[MEM:.+ ]]: memref<5x4x3x2xi8, {{.+}} >,
468+ // CHECK-SAME: %[[VEC:.+ ]]: vector<2x2xi8>
469+ // CHECK: %[[C0:.+ ]] = arith.constant 0 : index
470+ // CHECK: %[[COLLAPSED_MEM:.+ ]] = memref.collapse_shape %[[MEM]]
440471// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
441- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
442- // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
443- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
444- // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
472+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
473+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
474+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
475+ // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
476+
477+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
478+ // CHECK-128B: memref.collapse_shape
479+
480+ // -----
481+
482+ // The shape of the memref and the vector don't match, but the mismatch is only
483+ // at the leading unit dimensions of the vector.
484+
485+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
486+ %mem : memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>,
487+ %vec : vector <1 x1 x4 x3 x2 xi8 >) {
488+
489+ %c0 = arith.constant 0 : index
490+ vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 , %c0 ] :
491+ vector <1 x1 x4 x3 x2 xi8 >, memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>
492+
493+ return
494+ }
495+
496+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
497+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
498+ // CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
499+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
500+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
501+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
502+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
503+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
504+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
505+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
506+ // CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
445507
446508// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
447509// CHECK-128B: memref.collapse_shape
448510
449511// -----
450512
451- // The shape of the memref and the vector don't match, but the vector is a
452- // contiguous subset of the memref, so "flattenable".
513+ // The memref is non-contiguous, but the vector is a contiguous subset of the
514+ // memref, so "flattenable". The leading unit dimensions of the vector have no
515+ // effect on the memref area written even if they span the non-contiguous part of
516+ // the memref.
453517
454- func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
455- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
456- %vec : vector <2 x 2 x i8 >) {
518+ func.func @transfer_write_non_contiguous_unit_dims (
519+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
520+ %vec : vector <1 x 1 x 3 x 2 x i8 >) {
457521
458522 %c0 = arith.constant 0 : index
459523 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
460- vector <2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
524+ vector <1 x 1 x 3 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
461525 return
462526}
463527
464- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
465- // CHECK-SAME: %[[MEM:.+ ]]: memref<5x4x3x2xi8, {{.+}} >,
466- // CHECK-SAME: %[[VEC:.+ ]]: vector<2x2xi8>
467- // CHECK: %[[C0:.+ ]] = arith.constant 0 : index
468- // CHECK: %[[COLLAPSED_MEM:.+ ]] = memref.collapse_shape %[[MEM]]
528+ // CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
529+ // CHECK-SAME: %[[MEM:.* ]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?> >,
530+ // CHECK-SAME: %[[VEC:.* ]]: vector<1x1x3x2xi8>) {
531+ // CHECK: %[[C0:.* ]] = arith.constant 0 : index
532+ // CHECK: %[[COLLAPSED:.* ]] = memref.collapse_shape %[[MEM]]
469533// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
470- // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}} >
471- // CHECK: %[[VEC_1D:.+ ]] = vector.shape_cast %[[VEC]] : vector<2x2xi8 > to vector<4xi8 >
472- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM ]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
473- // CHECK-SAME: : vector<4xi8 >, memref<5x4x6xi8, {{.+}} >
534+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?> >
535+ // CHECK: %[[VEC_1D:.* ]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8 > to vector<6xi8 >
536+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED ]][%[[C0]], %[[C0]], %[[C0]]]
537+ // CHECK-SAME: {in_bounds = [true]} : vector<6xi8 >, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?> >
474538
475- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
539+ // CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims (
476540// CHECK-128B: memref.collapse_shape
477541
478542// -----
@@ -718,4 +782,3 @@ func.func @negative_out_of_bound_transfer_write(
718782// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
719783// CHECK-128B-NOT: memref.collapse_shape
720784// CHECK-128B-NOT: vector.shape_cast
721-
0 commit comments