From 1bc2d8eaced67b9e2e4a6893e18db49a76a4f61b Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 31 Dec 2024 14:02:45 +0000 Subject: [PATCH 1/2] [mlir][tensor] Introduce `FoldTensorCastUnPackOp` This patch specializes `FoldTensorCastProducerOp` for `tensor::UnPackOp` by introducing a dedicated pattern: `FoldTensorCastUnPackOp`. This change mirrors a similar update made for `tensor::PackOp` in #114559. Below is the updated rationale for `tensor::UnPackOp`. Currently, `FoldTensorCastProducerOp` incorrectly folds the following: ```mlir %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> %unpack = tensor.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %res : tensor<1x1x?x1xi32> -> tensor<7x?xi32> ``` as: ```mlir %unpack = tensor.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %res : tensor<1x1x8x1xi32> -> tensor<7x?xi32> ``` i.e. the source type becomes static while the inner tile size `%c8` stays dynamic. This leads to an Op verification failure because the folder does not update the inner tile sizes in the unpack Op. This patch resolves the issue. Additional Changes: * invalid.mlir: Fixes a typo. * TensorOps.cpp: Removes unnecessary `(void)tileSize` and adds extra comments following this discussion: https://github.com/llvm/llvm-project/pull/115772. 
--- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 89 +++++++++++++++++++++- mlir/test/Dialect/Tensor/canonicalize.mlir | 21 +++++ mlir/test/Dialect/Tensor/invalid.mlir | 2 +- 3 files changed, 107 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index f79c774ceb3e9..aeb11186c124d 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4837,15 +4837,17 @@ struct FoldTensorCastPackOp : public OpRewritePattern { // Already a constant newMixedTileSizes.push_back(std::get<1>(it)); } else { - int64_t tileSize = getConstantIntValue(std::get<1>(it)).value(); - assert(tileSize == shape && "tile size and dim size don't match!"); - (void)tileSize; + assert(getConstantIntValue(std::get<1>(it)).value() == shape && + "tile size and dim size don't match!"); newMixedTileSizes.push_back( (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); } } // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. PackOp newOp = rewriter.create( op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); @@ -4865,6 +4867,83 @@ struct FoldTensorCastPackOp : public OpRewritePattern { } }; +/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the +/// `tensor.cast` has source that is more static than the consuming op. +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> +/// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = tensor.unpack %0 ... 
tensor<1x1x8x1xi32> -> tensor<7x?xi32> +/// ``` +struct FoldTensorCastUnPackOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(UnPackOp op, + PatternRewriter &rewriter) const override { + if (!foldTensorCastPrecondition(op)) + return failure(); + + SmallVector newResultTypes(op->getResultTypes()); + SmallVector newOperands = getNewOperands(op, newResultTypes); + Value sourceTensor = newOperands[0]; + + // Get the updated mixed-tile-sizes attribute. + SmallVector newMixedTileSizes; + for (auto it : llvm::zip(cast(sourceTensor.getType()) + .getShape() + .take_back(op.getMixedTiles().size()), + op.getMixedTiles())) { + int64_t shape = std::get<0>(it); + // If the current source shape is dynamic, just preserve this mixed + // size. + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + // If the current source is static, update the dynamic mixed-size + // (provided the original value is dynamic). + if (Attribute attr = + llvm::dyn_cast_if_present(std::get<1>(it))) { + // Already a constant + newMixedTileSizes.push_back(std::get<1>(it)); + } else { + assert(getConstantIntValue(std::get<1>(it)).value() == shape && + "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + + // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. + UnPackOp newOp = rewriter.create( + op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(), + newMixedTileSizes, op.getOuterDimsPerm()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + + // Replace op. + Value oldResult = op.getResult(); + Value newResult = newOp.getResult(); + Value replacement = (newResult.getType() != oldResult.getType()) + ? 
rewriter.create( + op->getLoc(), oldResult.getType(), newResult) + : newResult; + + rewriter.replaceOp(op, {replacement}); + + return success(); + } +}; + /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if /// the `tensor.cast` has source that is more static than the consuming op. /// @@ -4890,7 +4969,8 @@ struct FoldTensorCastProducerOp PatternRewriter &rewriter) const override { // Reject tensor::PackOp - there's dedicated pattern for that instead. - if (!foldTensorCastPrecondition(op) || dyn_cast(*op)) + if (!foldTensorCastPrecondition(op) || + isa(*op)) return failure(); SmallVector newResultTypes(op->getResultTypes()); @@ -4923,6 +5003,7 @@ struct FoldTensorCastProducerOp void TensorDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add(getContext()); + results.add(getContext()); results.add(getContext()); } diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index e8fc4ce834e18..88e3691e2d629 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x %0:2 = test.destination_style_op ins(%cast : tensor) outs(%cast_0 : tensor) -> tensor, index return %0#1 : index } + // ----- // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size @@ -2814,6 +2815,26 @@ func.func @fold_cast_pack_dynamic_tile_size( // ----- +// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { +// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: return %[[RES]] : tensor<7x?xi32> +func.func @fold_cast_unpack_dynamic_tile_size( + %src: tensor<1x1x8x1xi32>, + %res: tensor<7x?xi32>) -> tensor<7x?xi32> 
{ + + %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %unpack = tensor.unpack %cast + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + return %unpack : tensor<7x?xi32> +} + +// ----- + // CHECK-LABEL: func.func @pack_dont_drop_attributes( // CHECK: tensor.pack {{.*}} {test_attr} func.func @pack_dont_drop_attributes(%arg0: tensor, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> { diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir index 83cb4b9d4ab24..1de3e281bc462 100644 --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor // ----- -func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> { +func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> { // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}} %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32> return %0 : tensor<256x128xf32> From 10d26d94d6095cbd3f7a61ae517b5019d5ebbaaa Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Thu, 2 Jan 2025 17:19:34 +0000 Subject: [PATCH 2/2] fixup! 
[mlir][tensor] Introduce `FoldTensorCastUnPackOp` Address PR comments --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 92 ++++++++++------------ mlir/test/Dialect/Tensor/canonicalize.mlir | 8 +- 2 files changed, 47 insertions(+), 53 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index aeb11186c124d..24a1d55315319 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4795,6 +4795,44 @@ static SmallVector getNewOperands(DestinationStyleOpInterface op, return newOperands; } +// Given the (potentially) updated packed type, `newPackedTy`, generates an +// updated mixed-tile-sizes attribute. A tile size is updated only +// when: +// * a dim from newPackedTy is static, and +// * the corresponding size from mixedTiles is still dynamic. +// Otherwise, the original tile size is preserved. +// Note - packed-type-dim and mixed-tile-size should always match! +static SmallVector +getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, + SmallVector mixedTiles) { + SmallVector newMixedTileSizes; + for (auto it : llvm::zip(cast(newPackedTy) + .getShape() + .take_back(mixedTiles.size()), + mixedTiles)) { + int64_t shape = std::get<0>(it); + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + // If the current result dim is static, update the dynamic mixed-size + // (provided the original value is dynamic). 
+ OpFoldResult tile = std::get<1>(it); + if (Attribute attr = llvm::dyn_cast_if_present(tile)) { + // Already a constant + newMixedTileSizes.push_back(tile); + } else { + assert(getConstantIntValue(tile).value() == shape && + "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + + return newMixedTileSizes; +} + /// Folds a tensor.cast op into a consuming tensor::PackOp op if the /// `tensor.cast` has source that is more static than the consuming op. /// @@ -4821,28 +4859,8 @@ struct FoldTensorCastPackOp : public OpRewritePattern { SmallVector newOperands = getNewOperands(op, newResultTypes); // Get the updated mixed-tile-sizes attribute. - SmallVector newMixedTileSizes; - for (auto it : llvm::zip(cast(newResultTypes[0]) - .getShape() - .take_back(op.getMixedTiles().size()), - op.getMixedTiles())) { - int64_t shape = std::get<0>(it); - if (shape == ShapedType::kDynamic) { - newMixedTileSizes.push_back(std::get<1>(it)); - continue; - } - - if (Attribute attr = - llvm::dyn_cast_if_present(std::get<1>(it))) { - // Already a constant - newMixedTileSizes.push_back(std::get<1>(it)); - } else { - assert(getConstantIntValue(std::get<1>(it)).value() == shape && - "tile size and dim size don't match!"); - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); - } - } + SmallVector newMixedTileSizes = + getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles()); // Clone op. // TODO: Strictly speaking, discardable attributes should be _discarded_ at @@ -4873,7 +4891,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern { /// Example: /// ```mlir /// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> -/// %2 = tensor.unpack %1 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +/// %2 = tensor.unpack %1 ... 
: tensor<1x1x?x1xi32> -> tensor<7x?xi32> /// ``` /// /// folds into: @@ -4894,32 +4912,8 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { Value sourceTensor = newOperands[0]; // Get the updated mixed-tile-sizes attribute. - SmallVector newMixedTileSizes; - for (auto it : llvm::zip(cast(sourceTensor.getType()) - .getShape() - .take_back(op.getMixedTiles().size()), - op.getMixedTiles())) { - int64_t shape = std::get<0>(it); - // If the current source shape is dynamic, just preserve this mixed - // size. - if (shape == ShapedType::kDynamic) { - newMixedTileSizes.push_back(std::get<1>(it)); - continue; - } - - // If the current source is static, update the dynamic mixed-size - // (provided the original value is dynamic). - if (Attribute attr = - llvm::dyn_cast_if_present(std::get<1>(it))) { - // Already a constant - newMixedTileSizes.push_back(std::get<1>(it)); - } else { - assert(getConstantIntValue(std::get<1>(it)).value() == shape && - "tile size and dim size don't match!"); - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); - } - } + SmallVector newMixedTileSizes = getNewMixedTileSizes( + rewriter, sourceTensor.getType(), op.getMixedTiles()); // Clone op. 
// TODO: Strictly speaking, discardable attributes should be _discarded_ at diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 88e3691e2d629..01d14871072cd 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2795,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x // CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> { // CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) // CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] -// CHECK-SAME: some_attr +// CHECK-SAME: test_attr // CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32> // CHECK: return %[[PACK]] : tensor<1x1x8x1xi32> func.func @fold_cast_pack_dynamic_tile_size( @@ -2808,7 +2808,7 @@ func.func @fold_cast_pack_dynamic_tile_size( %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] - into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32> + into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32> %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32> return %res : tensor<1x1x8x1xi32> } @@ -2818,7 +2818,7 @@ func.func @fold_cast_pack_dynamic_tile_size( // CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( // CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, // CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { -// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {some_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> // CHECK: return %[[RES]] : tensor<7x?xi32> func.func @fold_cast_unpack_dynamic_tile_size( %src: tensor<1x1x8x1xi32>, @@ -2829,7 +2829,7 @@ func.func @fold_cast_unpack_dynamic_tile_size( %unpack = 
tensor.unpack %cast inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] - into %res {some_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> return %unpack : tensor<7x?xi32> }