@@ -70,41 +70,10 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector is a
74- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
75- // of the vector have no effect on the memref area read even if they
76- // span a non-contiguous part of the memref.
77-
78- func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
79- %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
80-
81- %c0 = arith.constant 0 : index
82- %cst = arith.constant 0 : i8
83- %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
84- memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
85- return %res : vector <1 x1 x2 x2 xi8 >
86- }
87-
88- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
89- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
90- // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
91- // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
92- // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
93- // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
94- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
95- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<4xi8>
96- // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
97- // CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
98-
99- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
100- // CHECK-128B: memref.collapse_shape
101-
102- // -----
103-
10473// The shape of the memref and the vector don't match, but the vector is a
10574// contiguous subset of the memref, so "flattenable"
10675
107- func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
76+ func.func @transfer_read_dims_mismatch_contiguous (
10877 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <2 x3 x2 xi8 > {
10978
11079 %c0 = arith.constant 0 : index
@@ -114,7 +83,7 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
11483 return %res : vector <2 x3 x2 xi8 >
11584}
11685
117- // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims (
86+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous (
11887// CHECK-SAME: %[[MEM:.+]]: memref<5x4x3x2xi8, {{.+}}>) -> vector<2x3x2xi8> {
11988// CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
12089// CHECK: %[[C0:.+]] = arith.constant 0 : index
@@ -126,9 +95,73 @@ func.func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
12695// CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<12xi8> to vector<2x3x2xi8>
12796// CHECK: return %[[VEC]] : vector<2x3x2xi8>
12897
129- // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_non_unit_dims(
98+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
99+ // CHECK-128B: memref.collapse_shape
100+
101+ // -----
102+
103+ // The shape of the memref and the vector don't match, but the mismatch is only
104+ // at the leading unit dimensions of the vector.
105+
106+ func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
107+ %mem : memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x4 x3 x2 xi8 > {
108+
109+ %c0 = arith.constant 0 : index
110+ %cst = arith.constant 0 : i8
111+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 , %c0 ], %cst :
112+ memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x4 x3 x2 xi8 >
113+ return %res : vector <1 x1 x4 x3 x2 xi8 >
114+ }
115+
116+ // CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
117+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>)
118+ // CHECK-SAME: -> vector<1x1x4x3x2xi8>
119+ // CHECK: %[[C0_I8:.+]] = arith.constant 0 : i8
120+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
121+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
122+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
123+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
124+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
125+ // CHECK: %[[VEC_1D:.+]] = vector.transfer_read %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]], %[[C0_I8]]
126+ // CHECK-SAME: {in_bounds = [true]} : memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>, vector<24xi8>
127+ // CHECK: %[[VEC:.+]] = vector.shape_cast %[[VEC_1D]] : vector<24xi8> to vector<1x1x4x3x2xi8>
128+ // CHECK: return %[[VEC]] : vector<1x1x4x3x2xi8>
129+
130+ // CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous_unit_dims(
131+ // CHECK-128B: memref.collapse_shape
132+
133+ // -----
134+
135+ // The memref is non-contiguous, but the vector is a contiguous subset of the
136+ // memref, so "flattenable". The leading unit dimensions of the vector have no
137+ // effect on the memref area read even if they span the non-contiguous part of
138+ // the memref.
139+
140+ func.func @transfer_read_non_contiguous_unit_dims (
141+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x3 x2 xi8 > {
142+
143+ %c0 = arith.constant 0 : index
144+ %cst = arith.constant 0 : i8
145+ %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
146+ memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x3 x2 xi8 >
147+ return %res : vector <1 x1 x3 x2 xi8 >
148+ }
149+
150+ // CHECK-LABEL: func.func @transfer_read_non_contiguous_unit_dims(
151+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>>) -> vector<1x1x3x2xi8> {
152+ // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
153+ // CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
154+ // CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
155+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
156+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
157+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>, vector<6xi8>
158+ // CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<6xi8> to vector<1x1x3x2xi8>
159+ // CHECK: return %[[VAL_5]] : vector<1x1x3x2xi8>
160+
161+ // CHECK-128B-LABEL: func @transfer_read_non_contiguous_unit_dims(
130162// CHECK-128B: memref.collapse_shape
131163
164+
132165// -----
133166
134167func.func @transfer_read_dims_mismatch_non_zero_indices (
@@ -414,61 +447,92 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
414447// -----
415448
416449// The shape of the memref and the vector don't match, but the vector is a
417- // contiguous subset of the memref, so "flattenable". The leading unit dimensions
418- // of the vector have no effect on the memref area written even if they
419- // span a non-contiguous part of the memref.
450+ // contiguous subset of the memref, so "flattenable".
420451
421- func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
422- %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
423- %vec : vector <1 x 1 x 2 x 2 x i8 >) {
452+ func.func @transfer_write_dims_mismatch_contiguous (
453+ %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
454+ %vec : vector <2 x 2 x i8 >) {
424455
425456 %c0 = arith.constant 0 : index
426457 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
427- vector <1 x 1 x 2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
458+ vector <2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
428459 return
429460}
430461
431- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
432- // CHECK-SAME: %[[MEM:.* ]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?> >,
433- // CHECK-SAME: %[[VEC:.* ]]: vector<1x1x2x2xi8>) {
434- // CHECK: %[[C0:.* ]] = arith.constant 0 : index
435- // CHECK: %[[COLLAPSED:.* ]] = memref.collapse_shape %[[MEM]]
462+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous
463+ // CHECK-SAME: %[[MEM:.+ ]]: memref<5x4x3x2xi8, {{.+}} >,
464+ // CHECK-SAME: %[[VEC:.+ ]]: vector<2x2xi8>
465+ // CHECK: %[[C0:.+ ]] = arith.constant 0 : index
466+ // CHECK: %[[COLLAPSED_MEM:.+ ]] = memref.collapse_shape %[[MEM]]
436467// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
437- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
438- // CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
439- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
440- // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
468+ // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}}>
469+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<2x2xi8> to vector<4xi8>
470+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
471+ // CHECK-SAME: : vector<4xi8>, memref<5x4x6xi8, {{.+}}>
472+
473+ // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
474+ // CHECK-128B: memref.collapse_shape
475+
476+ // -----
477+
478+ // The shape of the memref and the vector don't match, but the mismatch is only
479+ // at the leading unit dimensions of the vector.
480+
481+ func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
482+ %mem : memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>,
483+ %vec : vector <1 x1 x4 x3 x2 xi8 >) {
484+
485+ %c0 = arith.constant 0 : index
486+ vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 , %c0 ] :
487+ vector <1 x1 x4 x3 x2 xi8 >, memref <6 x5 x4 x3 x2 xi8 , strided <[120 , 24 , 6 , 2 , 1 ], offset : ?>>
488+
489+ return
490+ }
491+
492+ // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims(
493+ // CHECK-SAME: %[[MEM:.+]]: memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
494+ // CHECK-SAME: %[[VEC:.+]]: vector<1x1x4x3x2xi8>
495+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
496+ // CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[MEM]]
497+ // CHECK-SAME{LITERAL}: [[0], [1], [2, 3, 4]]
498+ // CHECK-SAME: : memref<6x5x4x3x2xi8, strided<[120, 24, 6, 2, 1], offset: ?>>
499+ // CHECK-SAME: into memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
500+ // CHECK: %[[VEC_1D:.+]] = vector.shape_cast %[[VEC]] : vector<1x1x4x3x2xi8> to vector<24xi8>
501+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
502+ // CHECK-SAME: {in_bounds = [true]} : vector<24xi8>, memref<6x5x24xi8, strided<[120, 24, 1], offset: ?>>
441503
442504// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
443505// CHECK-128B: memref.collapse_shape
444506
445507// -----
446508
447- // The shape of the memref and the vector don't match, but the vector is a
448- // contiguous subset of the memref, so "flattenable".
509+ // The memref is non-contiguous, but the vector is a contiguous subset of the
510+ // memref, so "flattenable". The leading unit dimensions of the vector have no
511+ // effect on the memref area read even if they span the non-contiguous part of
512+ // the memref.
449513
450- func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
451- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
452- %vec : vector <2 x 2 x i8 >) {
514+ func.func @transfer_write_non_contiguous_unit_dims (
515+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
516+ %vec : vector <1 x 1 x 3 x 2 x i8 >) {
453517
454518 %c0 = arith.constant 0 : index
455519 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
456- vector <2 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
520+ vector <1 x 1 x 3 x 2 x i8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
457521 return
458522}
459523
460- // CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims
461- // CHECK-SAME: %[[MEM:.+ ]]: memref<5x4x3x2xi8, {{.+}} >,
462- // CHECK-SAME: %[[VEC:.+ ]]: vector<2x2xi8>
463- // CHECK: %[[C0:.+ ]] = arith.constant 0 : index
464- // CHECK: %[[COLLAPSED_MEM:.+ ]] = memref.collapse_shape %[[MEM]]
524+ // CHECK-LABEL: func.func @transfer_write_non_contiguous_unit_dims
525+ // CHECK-SAME: %[[MEM:.* ]]: memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?> >,
526+ // CHECK-SAME: %[[VEC:.* ]]: vector<1x1x3x2xi8>) {
527+ // CHECK: %[[C0:.* ]] = arith.constant 0 : index
528+ // CHECK: %[[COLLAPSED:.* ]] = memref.collapse_shape %[[MEM]]
465529// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
466- // CHECK-SAME: : memref<5x4x3x2xi8, {{.+}}> into memref<5x4x6xi8, {{.+}} >
467- // CHECK: %[[VEC_1D:.+ ]] = vector.shape_cast %[[VEC]] : vector<2x2xi8 > to vector<4xi8 >
468- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED_MEM ]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]}
469- // CHECK-SAME: : vector<4xi8 >, memref<5x4x6xi8, {{.+}} >
530+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48, 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48, 6, 1], offset: ?> >
531+ // CHECK: %[[VEC_1D:.* ]] = vector.shape_cast %[[VEC]] : vector<1x1x3x2xi8 > to vector<6xi8 >
532+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED ]][%[[C0]], %[[C0]], %[[C0]]]
533+ // CHECK-SAME: {in_bounds = [true]} : vector<6xi8 >, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?> >
470534
471- // CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
535+ // CHECK-128B-LABEL: func @transfer_write_non_contiguous_unit_dims (
472536// CHECK-128B: memref.collapse_shape
473537
474538// -----
@@ -714,4 +778,3 @@ func.func @negative_out_of_bound_transfer_write(
714778// CHECK-128B-LABEL: func.func @negative_out_of_bound_transfer_write
715779// CHECK-128B-NOT: memref.collapse_shape
716780// CHECK-128B-NOT: vector.shape_cast
717-
0 commit comments