@@ -926,6 +926,24 @@ func.func @bubble_up_pack_through_collapse(%1: tensor<?x16x4xf32>, %dim : index)
926926
927927// -----
928928
// Test input: a tensor.pack whose source is produced by tensor.collapse_shape,
// where the pack carries inner_dims_pos / inner_tiles but no outer_dims_perm
// attribute. The expectations that follow this function require the pack to be
// applied to the un-collapsed argument first, with the collapse_shape moved
// after it.
//   %1   : the 3-D source tensor (dynamic leading dimension).
//   %dim : dynamic size used for the packed destination tensor.
929+ func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm (%1: tensor <?x16 x4 xf32 >, %dim : index ) -> tensor <?x4 x8 x1 xf32 > {
// Fold source dims 0 and 1 into one dynamic dimension; dim 2 is kept as is.
930+ %collapsed = tensor.collapse_shape %1 [[0 , 1 ], [2 ]] : tensor <?x16 x4 xf32 > into tensor <?x4 xf32 >
// Destination of the pack; its only dynamic extent comes from %dim.
931+ %2 = tensor.empty (%dim ) : tensor <?x4 x8 x1 xf32 >
// Tile collapsed dims 0 and 1 by 8 and 1 respectively; no outer permutation is given.
932+ %pack = tensor.pack %collapsed inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 1 ] into %2 : tensor <?x4 xf32 > -> tensor <?x4 x8 x1 xf32 >
933+ func.return %pack : tensor <?x4 x8 x1 xf32 >
934+ }
935+ // CHECK-LABEL: func.func @bubble_up_pack_through_collapse_empty_outer_dims_perm
936+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
937+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
938+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
939+ // CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x16x4xf32>
940+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x2x4x8x1xf32>
941+ // CHECK: %[[PACK:.+]] = tensor.pack %[[ARG0]] inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %[[EMPTY]] : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
942+ // CHECK: %[[COLLAPSED:.+]] = tensor.collapse_shape %[[PACK]] {{\[}}[0, 1], [2], [3], [4]] : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
943+ // CHECK: return %[[COLLAPSED]] : tensor<?x4x8x1xf32>
944+
945+ // -----
946+
929947func.func @bubble_up_permuted_pack_through_collapse (%1: tensor <4 x192 x16 x256 xf32 >) -> tensor <4 x32 x3072 x8 x1 xf32 > {
930948 %collapsed = tensor.collapse_shape %1 [[0 ], [1 , 2 ], [3 ]] : tensor <4 x192 x16 x256 xf32 > into tensor <4 x3072 x256 xf32 >
931949 %2 = tensor.empty () : tensor <4 x32 x3072 x8 x1 xf32 >
@@ -1269,6 +1287,27 @@ func.func @push_down_unpack_through_expand(%5: tensor<?x32x8x8xf32>, %dim: index
12691287
12701288// -----
12711289
// Test input: a tensor.unpack whose result feeds tensor.expand_shape, where the
// unpack carries inner_dims_pos / inner_tiles but no outer_dims_perm attribute.
// The expectations that follow this function require the expand_shape to be
// applied to the packed argument first, with the unpack moved after it.
//   %5   : the packed 4-D source tensor (dynamic leading dimension).
//   %dim : dynamic size used for the unpacked destination tensor.
//   %sz0 : dynamic leading extent passed to expand_shape's output_shape.
1290+ func.func @push_down_unpack_through_expand_empty_outer_dims_perm (%5: tensor <?x32 x8 x8 xf32 >, %dim: index , %sz0: index ) -> tensor <?x256 x256 xf32 > {
// Destination of the unpack; its only dynamic extent comes from %dim.
1291+ %6 = tensor.empty (%dim ) : tensor <?x256 xf32 >
// Undo 8x8 tiling on dims 0 and 1; no outer permutation is given.
1292+ %unpack = tensor.unpack %5 inner_dims_pos = [0 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <?x32 x8 x8 xf32 > -> tensor <?x256 xf32 >
// Split unpacked dim 0 into two dims (result dims 0 and 1); dim 1 becomes result dim 2.
1293+ %expanded = tensor.expand_shape %unpack [[0 , 1 ], [2 ]] output_shape [%sz0 , 256 , 256 ] : tensor <?x256 xf32 > into tensor <?x256 x256 xf32 >
1294+ func.return %expanded : tensor <?x256 x256 xf32 >
1295+ }
1296+ // CHECK-LABEL: func.func @push_down_unpack_through_expand_empty_outer_dims_perm
1297+ // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
1298+ // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
1299+ // CHECK: %[[C32:.+]] = arith.constant 32 : index
1300+ // CHECK: %[[C0:.+]] = arith.constant 0 : index
1301+ // CHECK: %[[DIM0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x32x8x8xf32>
1302+ // CHECK: %[[SZ0:.+]] = arith.divui %[[DIM0]], %[[C32]] : index
1303+ // CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1], [2], [3], [4]] output_shape [%[[SZ0]], 32, 32, 8, 8] : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
1304+ // CHECK: %[[DIM:.+]] = tensor.dim %[[EXPANDED]], %[[C0]] : tensor<?x32x32x8x8xf32>
1305+ // CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?x256x256xf32>
1306+ // CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
1307+ // CHECK: return %[[UNPACK]] : tensor<?x256x256xf32>
1308+
1309+ // -----
1310+
12721311func.func @push_down_permuted_unpack_through_expand (%5: tensor <4 x32 x384 x8 x8 xf32 >) -> tensor <4 x12 x256 x256 xf32 > {
12731312 %6 = tensor.empty () : tensor <4 x3072 x256 xf32 >
12741313 %unpack = tensor.unpack %5 outer_dims_perm = [0 , 2 , 1 ] inner_dims_pos = [2 , 1 ] inner_tiles = [8 , 8 ] into %6 : tensor <4 x32 x384 x8 x8 xf32 > -> tensor <4 x3072 x256 xf32 >
0 commit comments