@@ -468,18 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468468
469469// -----
470470
471- // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472- // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
473- func.func @fold_dynamic_subview_with_memref_load_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0: index ) -> f32 {
471+ // CHECK-DAG: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
472+ // CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
473+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
474+ func.func @fold_dynamic_subview_with_memref_load_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0: index ) -> f32 {
474475 %c0 = arith.constant 0 : index
475476 %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
476477 %0 = memref.load %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
477478 return %0 : f32
478479}
479- // CHECK: %[[C0:.*]] = arith.constant 0 : index
480- // CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
481- // CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
482- // CHECK: return %[[VAL_0]] : f32
480+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
481+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
482+ // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
483+ // CHECK-NEXT: return %[[VAL1]] : f32
484+
485+ // -----
486+
487+ // CHECK-DAG: #[[MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
488+ // CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
489+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
490+ func.func @fold_dynamic_subview_with_memref_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index , %sz0 : index ) {
491+ %c0 = arith.constant 0 : index
492+ %c1f32 = arith.constant 1.0 : f32
493+ %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
494+ memref.store %c1f32 , %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
495+ return
496+ }
497+ // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
498+ // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
499+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
500+ // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
501+ // CHECK-NEXT: return
502+
503+ // -----
504+
505+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
506+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
507+ // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
508+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
509+ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim (%alloc: memref <2048 x16 xf32 >, %c10: index , %c5: index , %c0: index , %sz0: index ) {
510+ %subview = memref.subview %alloc [%c5 , 0 ] [%c10 , 16 ] [1 , 1 ] : memref <2048 x16 xf32 > to memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
511+ %expand_shape = memref.expand_shape %subview [[0 ], [1 , 2 , 3 ]] output_shape [1 , 16 , %sz0 , 1 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>> into memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
512+ %dim = memref.dim %expand_shape , %c0 : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
513+
514+ affine.for %arg6 = 0 to %dim step 64 {
515+ affine.for %arg7 = 0 to 16 step 16 {
516+ %dummy_load = affine.load %expand_shape [%arg6 , 0 , %arg7 , %arg7 ] : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
517+ affine.store %dummy_load , %subview [%arg6 , %arg7 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
518+ }
519+ }
520+ return
521+ }
522+ // CHECK-NEXT: memref.subview
523+ // CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
524+ // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
525+ // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
526+ // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
527+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
528+ // CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
529+ // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
530+ // CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
531+ // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
483532
484533// -----
485534
0 commit comments