Commit a4ed7d8
[mlir][linalg] Allow pack consumer fusion if the tile size is greater than dimension size.
This only happens when the tile size is greater than or equal to the dimension size. In that case the extracted slice covers the whole dimension, i.e., it is a full slice, so the op is fusible. Such IR can be generated during the TileAndFuse process, and it is hard to fix in that driver, so we enable the naive fusion for this case.

Signed-off-by: hanhanW <[email protected]>
1 parent: 5138b61 · commit: a4ed7d8
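For intuition about the single-iteration case this enables: the new test below packs a 4x4 source with inner tile size 16 along dim 0, so the scf.forall loop runs exactly once and the affine.min clamps the slice size to the full dimension. A minimal standalone sketch of that size computation, mirroring the map affine_map<(d0) -> (-d0 + 4, 16)> from the test (hypothetical helper, not part of the patch):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Slice size taken at induction value `iv`, mirroring
    // affine_map<(d0) -> (-d0 + 4, 16)>: min(dimSize - iv, tileSize).
    static int64_t sliceSize(int64_t iv, int64_t dimSize, int64_t tileSize) {
      return std::min(dimSize - iv, tileSize);
    }

    int main() {
      // Dim size 4, tile size 16: the only iteration is iv = 0.
      int64_t size = sliceSize(/*iv=*/0, /*dimSize=*/4, /*tileSize=*/16);
      std::cout << size << "\n"; // Prints 4: the slice spans the whole
                                 // dimension, which makes the pack fusible.
    }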

2 files changed: +51 −1 lines changed


mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 1 addition & 1 deletion
@@ -918,7 +918,7 @@ struct PackOpTiling
        int64_t destDimSize = outerShapeWithoutTranspose[dim];
        bool isTiled = failed(cstTileSize) ||
                       ShapedType::isDynamic(srcDimSize) ||
-                      cstTileSize.value() != srcDimSize;
+                      cstTileSize.value() < srcDimSize;
        if (!isTiled) {
          outerDimOffsets.push_back(offsets[dim]);
          if (ShapedType::isStatic(destDimSize)) {
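The relaxed predicate treats a constant tile size that meets or exceeds the source dimension size as "not tiled", i.e., a full slice. A standalone sketch of the old vs. new check, assuming the tile size resolved to a constant (hypothetical simplification of the code above; the ShapedType::isDynamic guard and the rest of PackOpTiling are elided):

    #include <cstdint>
    #include <iostream>
    #include <optional>

    // Simplified stand-ins for the `isTiled` predicate above. An empty
    // optional models a tile size that failed to fold to a constant
    // (the failed(cstTileSize) case).
    static bool isTiledOld(std::optional<int64_t> cstTileSize, int64_t srcDimSize) {
      return !cstTileSize || *cstTileSize != srcDimSize;
    }
    static bool isTiledNew(std::optional<int64_t> cstTileSize, int64_t srcDimSize) {
      return !cstTileSize || *cstTileSize < srcDimSize;
    }

    int main() {
      // Tile size 16 over a dimension of size 4, as in the new test below.
      std::cout << isTiledOld(16, 4) << "\n"; // 1: old check saw this as tiled
      std::cout << isTiledNew(16, 4) << "\n"; // 0: full slice, fusion allowed
    }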

mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir

Lines changed: 50 additions & 0 deletions
@@ -451,6 +451,56 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+#map = affine_map<(d0) -> (-d0 + 4, 16)>
+func.func @fuse_pack_consumer_if_single_iteration(%arg0: tensor<4x4xf32>) -> tensor<1x4x16x1xf32> {
+  %0 = tensor.empty() : tensor<1x4x16x1xf32>
+  %1 = tensor.empty() : tensor<4x4xf32>
+  %2 = scf.forall (%arg1) = (0) to (4) step (16) shared_outs(%arg2 = %1) -> (tensor<4x4xf32>) {
+    %3 = affine.min #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<4x4xf32> to tensor<?x4xf32>
+    %4 = linalg.exp ins(%extracted_slice : tensor<?x4xf32>) outs(%extracted_slice_0 : tensor<?x4xf32>) -> tensor<?x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %4 into %arg2[%arg1, 0] [%3, 4] [1, 1] : tensor<?x4xf32> into tensor<4x4xf32>
+    }
+  }
+  %cst = arith.constant 0.000000e+00 : f32
+  %pack = linalg.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1] into %0 : tensor<4x4xf32> -> tensor<1x4x16x1xf32>
+  return %pack : tensor<1x4x16x1xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer, %fused_consumer = transform.test.fuse_consumer %0 in(%1) : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+// CHECK: #[[MAP:.*]] = affine_map<(d0) -> (-d0 + 4, 16)>
+// CHECK:      func.func @fuse_pack_consumer_if_single_iteration(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-DAG:    %[[PACK_INIT:.*]] = tensor.empty() : tensor<1x4x16x1xf32>
+// CHECK-DAG:    %[[ELEM_INIT:.*]] = tensor.empty() : tensor<4x4xf32>
+// CHECK-DAG:    %[[PAD_VAL:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:        %{{.*}}:2 = scf.forall (%[[IV:.*]]) = (0) to (4) step (16)
+// CHECK-SAME:     shared_outs(%[[ELEM_OUT_ARG:.*]] = %[[ELEM_INIT]], %[[PACK_OUT_ARG:.*]] = %[[PACK_INIT]])
+// CHECK-DAG:      %[[SIZE:.+]] = affine.min #[[MAP]](%[[IV]])
+// CHECK-DAG:      %[[ELEM_SRC:.*]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK-DAG:      %[[ELEM_DEST:.*]] = tensor.extract_slice %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK:          %[[ELEM:.*]] = linalg.exp
+// CHECK-SAME:       ins(%[[ELEM_SRC]]
+// CHECK-SAME:       outs(%[[ELEM_DEST]]
+// CHECK-DAG:      %[[TILED_PACK_DEST:.*]] = tensor.extract_slice %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+// CHECK:          %[[PACK:.*]] = linalg.pack %[[ELEM]]
+// CHECK-SAME:       padding_value(%[[PAD_VAL]] : f32)
+// CHECK-SAME:       outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 1]
+// CHECK-SAME:       into %[[TILED_PACK_DEST]]
+// CHECK:          scf.forall.in_parallel {
+// CHECK:            tensor.parallel_insert_slice %[[ELEM]] into %[[ELEM_OUT_ARG]][%[[IV]], 0] [%[[SIZE]], 4] [1, 1]
+// CHECK:            tensor.parallel_insert_slice %[[PACK]] into %[[PACK_OUT_ARG]][%[[IV]], 0, 0, 0] [1, 4, 16, 1] [1, 1, 1, 1]
+
+// -----
 
 func.func @fuse_perfect_tiling_pack_consumer_with_outer_dims_perm(%arg0: tensor<64x32xf32>, %arg1: tensor<64x32xf32>, %arg2: tensor<2x64x16x1xf32>) -> tensor<2x64x16x1xf32> {
   %0 = scf.forall (%arg3) = (0) to (32) step (16) shared_outs(%arg4 = %arg1) -> (tensor<64x32xf32>) {
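To exercise the new test locally, the usual lit workflow should apply (the authoritative RUN line sits at the top of the test file), for example: llvm-lit -v mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir from a built llvm-project checkout.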
