@@ -70,28 +70,29 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
7070
7171// -----
7272
73- // The shape of the memref and the vector don't match, but the vector,
74- // ignoring the unit dimensions, is a contiguous subset of the memref,
75- // so "flattenable"
73+ // The shape of the memref and the vector don't match, but the vector is a
74+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
75+ // of the vector have no effect on the memref area read even if they
76+ // span a non-contiguous part of the memref.
7677
7778func.func @transfer_read_dims_mismatch_contiguous_unit_dims (
78- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
79+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>) -> vector <1 x1 x2 x2 xi8 > {
7980
8081 %c0 = arith.constant 0 : index
8182 %cst = arith.constant 0 : i8
8283 %res = vector.transfer_read %mem [%c0 , %c0 , %c0 , %c0 ], %cst :
83- memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
84+ memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>, vector <1 x1 x2 x2 xi8 >
8485 return %res : vector <1 x1 x2 x2 xi8 >
8586}
8687
8788// CHECK-LABEL: func.func @transfer_read_dims_mismatch_contiguous_unit_dims(
88- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
89+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>) -> vector<1x1x2x2xi8> {
8990// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i8
9091// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
9192// CHECK: %[[VAL_3:.*]] = memref.collapse_shape %[[MEM]]
9293// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
93- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
94- // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>, vector<4xi8>
94+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
95+ // CHECK: %[[VAL_4:.*]] = vector.transfer_read %[[VAL_3]][%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]], %[[VAL_1]] {in_bounds = [true]} : memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>, vector<4xi8>
9596// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
9697// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>
9798
@@ -416,31 +417,40 @@ func.func @transfer_write_dims_match_contiguous_empty_stride(
416417
417418// -----
418419
420+ // The shape of the memref and the vector don't match, but the vector is a
421+ // contiguous subset of the memref, so "flattenable". The leading unit dimensions
422+ // of the vector have no effect on the memref area written even if they
423+ // span a non-contiguous part of the memref.
424+
419425func.func @transfer_write_dims_mismatch_contiguous_unit_dims (
420- %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
426+ %mem : memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>,
421427 %vec : vector <1 x1 x2 x2 xi8 >) {
422428
423429 %c0 = arith.constant 0 : index
424430 vector.transfer_write %vec , %mem [%c0 , %c0 , %c0 , %c0 ] :
425- vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>
431+ vector <1 x1 x2 x2 xi8 >, memref <5 x4 x3 x2 xi8 , strided <[48 , 6 , 2 , 1 ], offset : ?>>
426432 return
427433}
428434
429435// CHECK-LABEL: func.func @transfer_write_dims_mismatch_contiguous_unit_dims
430- // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>>,
436+ // CHECK-SAME: %[[MEM:.*]]: memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>>,
431437// CHECK-SAME: %[[VEC:.*]]: vector<1x1x2x2xi8>) {
432438// CHECK: %[[C0:.*]] = arith.constant 0 : index
433439// CHECK: %[[COLLAPSED:.*]] = memref.collapse_shape %[[MEM]]
434440// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
435- // CHECK-SAME: : memref<5x4x3x2xi8, strided<[24 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[24 , 6, 1], offset: ?>>
441+ // CHECK-SAME: : memref<5x4x3x2xi8, strided<[48 , 6, 2, 1], offset: ?>> into memref<5x4x6xi8, strided<[48 , 6, 1], offset: ?>>
436442// CHECK: %[[VEC_1D:.*]] = vector.shape_cast %[[VEC]] : vector<1x1x2x2xi8> to vector<4xi8>
437- // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[24, 6, 1], offset: ?>>
443+ // CHECK: vector.transfer_write %[[VEC_1D]], %[[COLLAPSED]][%[[C0]], %[[C0]], %[[C0]]]
444+ // CHECK-SAME: {in_bounds = [true]} : vector<4xi8>, memref<5x4x6xi8, strided<[48, 6, 1], offset: ?>>
438445
439446// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous_unit_dims(
440447// CHECK-128B: memref.collapse_shape
441448
442449// -----
443450
451+ // The shape of the memref and the vector don't match, but the vector is a
452+ // contiguous subset of the memref, so "flattenable".
453+
444454func.func @transfer_write_dims_mismatch_contiguous_non_unit_dims (
445455 %mem : memref <5 x4 x3 x2 xi8 , strided <[24 , 6 , 2 , 1 ], offset : ?>>,
446456 %vec : vector <2 x2 xi8 >) {
0 commit comments