@@ -468,16 +468,67 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
468468
469469// -----
470470
471- // CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
472- func.func @fold_dynamic_subview_with_memref_load_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) -> f32 {
471+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
472+ // CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
473+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index) -> f32
474+ func.func @fold_dynamic_subview_with_memref_load_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) -> f32 {
// Test input: %arg0 (16x?) is reshaped to 1x16x?x1 by memref.expand_shape and
// a single element is loaded through the expanded view. The two unit
// dimensions are indexed with the constant %c0; the other two with the
// function arguments %arg1 and %arg2.
473475 %c0 = arith.constant 0 : index
474476 %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
475477 %0 = memref.load %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
// The CHECK lines after this function expect the load to be issued directly
// on %arg0, with no memref.expand_shape left in the output.
476478 return %0 : f32
477479}
478- // CHECK: %[[EXPAND_SHAPE:.+]] = memref.expand_shape {{.+}} : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
479- // CHECK: %[[LOAD:.+]] = memref.load %[[EXPAND_SHAPE]]
480- // CHECK: return %[[LOAD]]
480+ // CHECK: %[[C1:.*]] = arith.constant 1 : index
481+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
482+ // CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
483+ // CHECK-NEXT: return %[[VAL1]] : f32
484+
485+ // -----
486+
487+ // CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1] -> (s1 * s0)>
488+ // CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
489+ // CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
490+ func.func @fold_dynamic_subview_with_memref_store_expand_shape (%arg0 : memref <16 x?xf32 , strided <[16 , 1 ]>>, %arg1 : index , %arg2 : index ) {
// Test input: store counterpart of the load test. %arg0 (16x?) is reshaped to
// 1x16x?x1 by memref.expand_shape and the f32 constant 1.0 is stored through
// the expanded view at [%c0, %arg1, %arg2, %c0].
491+ %c0 = arith.constant 0 : index
492+ %c1f32 = arith.constant 1.0 : f32
493+ %expand_shape = memref.expand_shape %arg0 [[0 , 1 ], [2 , 3 ]] : memref <16 x?xf32 , strided <[16 , 1 ]>> into memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
494+ memref.store %c1f32 , %expand_shape [%c0 , %arg1 , %arg2 , %c0 ] : memref <1 x16 x?x1 xf32 , strided <[256 , 16 , 1 , 1 ]>>
// The CHECK lines after this function expect the store to target %arg0
// directly, with no memref.expand_shape left in the output.
495+ return
496+ }
497+ // CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
498+ // CHECK-NEXT: %[[C1:.*]] = arith.constant 1 : index
499+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP]]()[%[[ARG2]], %[[C1]]]
500+ // CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[VAL0]]] : memref<16x?xf32, strided<[16, 1]>>
501+ // CHECK-NEXT: return
502+
503+ // -----
504+
505+ // CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
506+ // CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> (d0 * 3)>
507+ // CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
508+ // CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
509+ func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim (%alloc: memref <2048 x16 xf32 >, %c10: index , %c5: index , %c0: index ) {
// Test input. Despite their names, %c10, %c5 and %c0 are function arguments
// (dynamic index values), not constants.
// %subview takes a dynamically sized (%c10 rows) window of %alloc starting at
// row %c5; %expand_shape then splits its 16-wide dimension into 1x8x2.
510+ %subview = memref.subview %alloc [%c5 , 0 ] [%c10 , 16 ] [1 , 1 ] : memref <2048 x16 xf32 > to memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
511+ %expand_shape = memref.expand_shape %subview [[0 ], [1 , 2 , 3 ]] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>> into memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
// The queried dimension index is the argument %c0, and the result is used as
// the outer loop bound, so %expand_shape keeps a use besides the affine.load.
512+ %dim = memref.dim %expand_shape , %c0 : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
513+
514+ affine.for %arg6 = 0 to %dim step 64 {
515+ affine.for %arg7 = 0 to 16 step 16 {
// Load through the expanded view and store back through the subview; the
// CHECK lines after this function expect both accesses to be rewritten to
// address %alloc directly via affine.apply-computed indices.
516+ %dummy_load = affine.load %expand_shape [%arg6 , 0 , %arg7 , %arg7 ] : memref <?x1 x8 x2 xf32 , strided <[16 , 16 , 2 , 1 ], offset : ?>>
517+ affine.store %dummy_load , %subview [%arg6 , %arg7 ] : memref <?x16 xf32 , strided <[16 , 1 ], offset : ?>>
518+ }
519+ }
520+ return
521+ }
522+ // CHECK-NEXT: memref.subview
523+ // CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
524+ // CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
525+ // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
526+ // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
527+ // CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
528+ // CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]](%[[ARG5]])
529+ // CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
530+ // CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
531+ // CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
481532
482533// -----
483534
0 commit comments