diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index a8662a3d6f63b..5209e1145506b 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -1516,7 +1516,7 @@ struct GeneralizePadOpPattern : public OpRewritePattern { }; /// Rewrites a tensor::PackOp into a sequence of: -/// * tensor::PadOp + linalg::TransposeOp + tensor::ExtractSliceOp + +/// * tensor::PadOp + linalg::TransposeOp + /// tensor::EmptyOp + tensor::InsertSliceOp ops. /// /// Required that all the outer dims of the input tensor::PackOp are 1. @@ -1537,10 +1537,6 @@ struct GeneralizePadOpPattern : public OpRewritePattern { /// ^bb0(...): /// tensor.yield %arg2 : f32 /// } : tensor<5x1xf32> to tensor -/// // ExtractSliceOp -/// %extracted_slice = tensor.extract_slice %padded[0, 0] [%tile_dim_1, 2] [1, -/// 1] -/// : tensor to tensor /// // EmptyOp + TransposeOp /// %empty = tensor.empty(%arg3) : tensor<2x?xf32> /// %transposed = linalg.transpose diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 64096954f56b9..ed9ebca4f306a 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1153,71 +1153,66 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite( Location loc = packOp.getLoc(); Value input = getPackOpSourceOrPaddedSource(rewriter, packOp); - auto inputShape = packOp.getSourceType().getShape(); DenseMap dimAndTileMapping = packOp.getDimAndTileMapping(); int64_t srcRank = packOp.getSourceRank(); - int64_t destRank = packOp.getDestRank(); - size_t numTiles = destRank - srcRank; - - // 1. 
Use rank-reduced tensor.extract_slice op to extract the tile: - // %extracted_tile = tensor.extract_slice(%pack_op_input) - SmallVector readOffsets(srcRank, zeroIdxAttr); - SmallVector readStrides(srcRank, oneIdxAttr); + int64_t numTiles = destRank - srcRank; - // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as - // all outer dims are 1. - SmallVector extractSliceSizes(srcRank - numTiles, oneIdxAttr); - // The shape of the output for ExtractSliceOp. All leading unit dims are - // effectively rank-reduced, hence skipped. - SmallVector outputShapeForExtractSlice; + if (!llvm::all_of(packOp.getInnerDimsPos(), + [&srcRank, &numTiles](int64_t dimPos) { + return dimPos >= (srcRank - numTiles - 1); + })) + return rewriter.notifyMatchFailure( + packOp, "Attempting to tile non-trailing source dims!"); - // Extract the trailing sizes and shape dims for ExtractSliceOp. These should - // be equal to the inner tile sizes. + // 1. Extract the inner tile sizes. + // Where possible, values are replaced with constant attributes (to match the + // behaviour of `getPackOpSourceOrPaddedSource`). + SmallVector tileSizes; for (auto i : llvm::seq(0, srcRank)) { if (dimAndTileMapping.count(i)) { - auto [tileSize, tileSizeOfr] = + // Rather than taking the tile size as is, extract the actual constant + // value Attribute where possible, e.g.: + // [Value: %tile_size = arith.constant 8 : index] --> [Attribute: 8] + auto [_, tileSize] = getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter); - extractSliceSizes.push_back(tileSizeOfr); - outputShapeForExtractSlice.push_back(tileSize); + tileSizes.push_back(tileSize); } } - Type elemType = packOp.getSourceType().getElementType(); - auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType); - - Value tile = rewriter.create( - loc, readType, input, readOffsets, extractSliceSizes, readStrides); - - // 2. Transpose the tile to match the inner tile order: + // 2. 
Transpose the input to match the inner tile order: // %init = tensor.empty() - // %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init) - // NOTE: Outer dims are 1 and hence effectively ignored. - SmallVector perm = getPackUnpackRankReducedPerm( - inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm()); + // %transposed_tile = linalg.transpose ins(%source_or_padded_source), + // outs(%init) + // Two assumptions are made: + // 1. All outer dims are 1 - the corresponding transposition doesn't matter. + // 2. Inner dim positions correspond to the trailing `numTiles` dims. + SmallVector tilesPermNormalized = + getPackUnpackNormalizedPerm(srcRank, packOp.getInnerDimsPos()); + SmallVector srcPermForTranspose; + for (int64_t i = 0; i < (srcRank - numTiles); i++) + srcPermForTranspose.push_back(i); + + srcPermForTranspose.append(SmallVector(packOp.getInnerDimsPos())); LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n"; - llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL();); + llvm::interleaveComma(srcPermForTranspose, DBGS() << "perm: "); + DBGSNL();); // 2.1 Create tensor.empty (init value for TransposeOp) - SmallVector transShapeForEmptyOp; - - // Acquire tensor shape required to create EmptyOp. This will match the inner - // tile sizes. 
- size_t idx = numTiles; - while (idx != 0) { - transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]); - idx--; - } + SmallVector transShapeForEmptyOp(srcRank - numTiles, + oneIdxAttr); + transShapeForEmptyOp.append(tileSizes); - applyPermutationToVector(transShapeForEmptyOp, perm); - Value empty = - rewriter.create(loc, transShapeForEmptyOp, elemType); + applyPermutationToVector(transShapeForEmptyOp, + srcPermForTranspose); + Value empty = rewriter.create( + loc, transShapeForEmptyOp, packOp.getSourceType().getElementType()); // 2.2 Create linalg.transpose - auto transposedOp = - rewriter.create(loc, tile, empty, perm); + auto transposedOp = rewriter.create(loc, input, empty, + srcPermForTranspose); // 3. Insert the inner tile to the destination: // %inserted_tile = tensor.insert_slice(%transposed_tile) diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir index d0c53ae468001..1fae311467bcf 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack-tile.mlir @@ -9,19 +9,17 @@ func.func @KCRS_to_KCRSsr(%arg0: tensor<1x1x128x64xf32>, %arg1: tensor<1x1x4x8x8 // CHECK: func.func @KCRS_to_KCRSsr // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %{{.+}} = scf.for %[[R:[a-zA-Z0-9]+]] = -// CHECK: %{{.+}} = scf.for %[[S:[a-zA-Z0-9]+]] = +// CHECK: scf.for %[[R:[a-zA-Z0-9]+]] = +// CHECK: scf.for %[[S:[a-zA-Z0-9]+]] {{.*}} iter_args(%[[ITER_SLICE:.*]] = // CHECK: %[[IN_R:.+]] = affine.apply #[[MAP0]](%[[R]]) // CHECK: %[[IN_S:.+]] = affine.apply #[[MAP2]](%[[S]]) // CHECK: %[[SRC_SLICE:.+]] = tensor.extract_slice %[[SRC]] // CHECK-SAME: [0, 0, %[[IN_R]], %[[IN_S]]] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC_SLICE]] -// CHECK-SAME: [0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] : tensor<1x1x32x8xf32> to tensor<32x8xf32> -// CHECK: %[[EMPTY:.+]] = 
tensor.empty() : tensor<8x32xf32> -// CHECK: %[[TRANSP:.+]] = linalg.transpose -// CHECK-SAME: ins(%[[TILE]] -// CHECK-SAME: outs(%[[EMPTY]] -// CHECK-SAME: permutation = [1, 0] +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<1x1x8x32xf32> +// CHECK: %[[TRANSP:.*]] = linalg.transpose +// CHECK-SAME: ins(%[[SRC_SLICE]] : tensor<1x1x32x8xf32>) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>) +// CHECK-SAME: permutation = [0, 1, 3, 2] // CHECK: %{{.+}} = tensor.insert_slice %[[TRANSP]] into %{{.+}} module attributes {transform.with_named_sequence} { diff --git a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir index 8abf7a11bed5c..f4b1d9a55f091 100644 --- a/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir +++ b/mlir/test/Dialect/Linalg/generalize-tensor-pack.mlir @@ -63,8 +63,7 @@ func.func @simple_pad_and_pack_dynamic_tile(%input: tensor<5x1xf32>, %output: te // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] { // CHECK: tensor.yield %[[PAD_VAL]] : f32 // CHECK-NOT: linalg.transpose -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], 2] [1, 1] : tensor to tensor -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor into tensor<1x1x?x2xf32> +// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], 2] [1, 1, 1, 1] : tensor into tensor<1x1x?x2xf32> // CHECK: return %[[RES]] : tensor<1x1x?x2xf32> func.func @simple_pad_and_pack_dynamic_tile_cst(%input: tensor<5x1xf32>, %output: tensor<1x1x?x2xf32>, %pad: f32) -> tensor<1x1x?x2xf32> { @@ -95,10 +94,10 @@ func.func @simple_pad_and_pack_dynamic_tile_transpose(%input: tensor<5x1xf32>, % // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] { // CHECK: tensor.yield %[[PAD_VAL]] : f32 // CHECK-NEXT: } : tensor<5x1xf32> to tensor -// 
CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_1]], 2] [1, 1] : tensor to tensor // CHECK: %[[EMPTY:.*]] = tensor.empty(%[[TILE_DIM_1]]) : tensor<2x?xf32> // CHECK: %[[TR:.*]] = linalg.transpose -// CHECK-SAME: ins(%[[SLICE]] : tensor) outs(%[[EMPTY]] : tensor<2x?xf32>) +// CHECK-SAME: ins(%[[PAD:.*]] : tensor) +// CHECK-SAME: outs(%[[EMPTY]] : tensor<2x?xf32>) // CHECK-SAME: permutation = [1, 0] // CHECK: %[[RES:.*]] = tensor.insert_slice %[[TR]] into %[[DEST]][0, 0, 0, 0] [1, 1, 2, %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x2x?xf32> // CHECK: return %[[RES]] : tensor<1x1x2x?xf32> @@ -128,10 +127,10 @@ func.func @simple_pad_and_pack_scalable_tile(%input: tensor<5x1xf32>, %output: t // CHECK: %[[PAD:.+]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH]], 1] { // CHECK: tensor.yield %[[PAD_VAL]] : f32 // CHECK-NOT: linalg.transpose -// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[PAD:.+]][0, 0] {{\[}}%[[C8_VS]], 2] [1, 1] : tensor to tensor -// CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor into tensor<1x1x?x2xf32> +// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD:.*]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[C8_VS]], 2] [1, 1, 1, 1] : tensor into tensor<1x1x?x2xf32> // CHECK: return %[[RES]] : tensor<1x1x?x2xf32> + /// Same as example above, but with both tile sizes dynamic. 
func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: tensor<1x1x?x?xf32>, %pad: f32, %tile_dim_0: index, %tile_dim_1: index) -> tensor<1x1x?x?xf32> { @@ -149,8 +148,7 @@ func.func @simple_pad_and_pack_dynamic_tiles(%input: tensor<5x1xf32>, %output: t // CHECK: %[[PAD:.*]] = tensor.pad %[[SRC]] low[0, 0] high{{\[}}%[[PAD_HIGH_1]], %[[PAD_HIGH_2]]] { // CHECK: tensor.yield %[[PAD_VAL]] : f32 // CHECK-NOT: linalg.transpose -// CHECK: %[[SLICE:.*]] = tensor.extract_slice %[[PAD:.*]][0, 0] {{\[}}%[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1] : tensor to tensor -// CHECK: %[[RES:.*]] = tensor.insert_slice %[[SLICE]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor into tensor<1x1x?x?xf32> +// CHECK: %[[RES:.*]] = tensor.insert_slice %[[PAD]] into %[[DEST]][0, 0, 0, 0] [1, 1, %[[TILE_DIM_0]], %[[TILE_DIM_1]]] [1, 1, 1, 1] : tensor into tensor<1x1x?x?xf32> // CHECK: return %[[RES]] : tensor<1x1x?x?xf32> // ----- @@ -170,12 +168,13 @@ func.func @simple_pad_and_pack_dynamic_tile_not_all_dims_tiled(%input: tensor<1x // CHECK: ^bb0(%[[VAL_6:.*]]: index, %[[VAL_7:.*]]: index, %[[VAL_8:.*]]: index, %[[VAL_9:.*]]: index): // CHECK: tensor.yield %[[VAL_2]] : f32 // CHECK: } : tensor<1x1x5x1xf32> to tensor<1x1x?x2xf32> -// CHECK: %[[VAL_10:.*]] = tensor.extract_slice %[[VAL_11:.*]][0, 0, 0, 0] [1, 1, %[[VAL_3]], 2] [1, 1, 1, 1] : tensor<1x1x?x2xf32> to tensor -// CHECK: %[[VAL_12:.*]] = tensor.empty(%[[VAL_3]]) : tensor<2x?xf32> -// CHECK: %[[VAL_13:.*]] = linalg.transpose ins(%[[VAL_10]] : tensor) outs(%[[VAL_12]] : tensor<2x?xf32>) permutation = [1, 0] -// CHECK: %[[VAL_14:.*]] = tensor.insert_slice %[[VAL_13]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<2x?xf32> into tensor<1x1x1x1x2x?xf32> -// CHECK: return %[[VAL_14]] : tensor<1x1x1x1x2x?xf32> -// CHECK: } +// CHECK: %[[VAL_10:.*]] = tensor.empty(%[[VAL_3]]) : tensor<1x1x2x?xf32> +// CHECK: %[[VAL_11:.*]] = 
linalg.transpose +// CHECK-SAME: ins(%[[VAL_12:.*]] : tensor<1x1x?x2xf32>) +// CHECK-SAME: outs(%[[VAL_10]] : tensor<1x1x2x?xf32>) +// CHECK-SAME: permutation = [0, 1, 3, 2] +// CHECK: %[[VAL_13:.*]] = tensor.insert_slice %[[VAL_11]] into %[[VAL_1]][0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 2, %[[VAL_3]]] [1, 1, 1, 1, 1, 1] : tensor<1x1x2x?xf32> into tensor<1x1x1x1x2x?xf32> +// CHECK: return %[[VAL_13]] : tensor<1x1x1x1x2x?xf32> // ----- @@ -218,12 +217,11 @@ func.func @simple_KCRS_to_KRSCsr(%arg0: tensor<1x1x32x8xf32>, %arg1: tensor<1x1x // CHECK-LABEL: func.func @simple_KCRS_to_KRSCsr // CHECK-SAME: %[[SRC:[a-zA-Z0-9]+]] // CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]] -// CHECK: %[[TILE:.+]] = tensor.extract_slice %[[SRC]][0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1] -// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<8x32xf32> +// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<1x1x8x32xf32> // CHECK: %[[TRANSP:.+]] = linalg.transpose -// CHECK-SAME: ins(%[[TILE]] : tensor<32x8xf32>) -// CHECK-SAME: outs(%[[EMPTY]] : tensor<8x32xf32>) -// CHECK-SAME: permutation = [1, 0] +// CHECK-SAME: ins(%[[SRC]] : tensor<1x1x32x8xf32> +// CHECK-SAME: outs(%[[EMPTY]] : tensor<1x1x8x32xf32>) +// CHECK-SAME: permutation = [0, 1, 3, 2] // CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[TRANSP]] into %[[DEST]] // CHECK-SAME: [0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1] // CHECK: return %[[INSERT]]