From d40f7052348001349164d13a50c2beff164373e8 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Fri, 1 Nov 2024 15:59:47 +0000 Subject: [PATCH 1/3] [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) Currently, `FoldTensorCastProducerOp` incorrectly folds the following: ```mlir %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32> %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32> ``` as (note the static trailing dim in the result and dynamic tile dimension that corresponds to that): ```mlir %res = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x8x1xi32> ``` This triggers an Op verification failure and is due to the fact that the folder does not update the inner tile sizes in the pack Op. This PR addresses that. Note, supporting other Ops with size-like attributes is left as a TODO; --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 46 +++++++++++++++++++++- mlir/test/Dialect/Tensor/canonicalize.mlir | 23 ++++++++++- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index c2d6bc610cd92..406b557b0f0e3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4756,8 +4756,50 @@ struct FoldTensorCastProducerOp newResultTypes[dpsInitIdx++] = newOperands.back().getType(); } - // Clone op. - Operation *newOp = clone(rewriter, op, newResultTypes, newOperands); + // For ops that have sizes-like attribute, update these accordingly. + // For now, only `tensor.pack` is supported. + // TODO: Generalize to make it work with other ops as well (e.g. + // `tensor.unpack`) + SmallVector newMixedTileSizes; + if (auto pack = dyn_cast_or_null(*op)) { + for (auto it : llvm::zip(cast(newResultTypes[0]) + .getShape() + .take_back(pack.getMixedTiles().size()), + pack.getMixedTiles())) { + + int64_t shape = std::get<0>(it); + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + if (Attribute attr = + llvm::dyn_cast_if_present(std::get<1>(it))) { + // Already a constant + newMixedTileSizes.push_back(std::get<1>(it)); + } else { + auto tileSize = getConstantIntValue(std::get<1>(it)); + assert(tileSize == shape && "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + } + + // Clone op. For ops that have sizes-like attribute, make sure to udpate + // those as well. For now, only `tensor.pack` is supported. + // TODO: Generalize to make it work with other ops as well (e.g. + // `tensor.unpack`) + // Operation *newOp; + Operation *newOp; + if (auto pack = dyn_cast_or_null(*op)) { + newOp = rewriter.create( + pack.getLoc(), newOperands[0], newOperands[1], pack.getInnerDimsPos(), + newMixedTileSizes, pack.getPaddingValue(), pack.getOuterDimsPerm()); + } else { + newOp = clone(rewriter, op, newResultTypes, newOperands); + } + SmallVector replacements; replacements.reserve(newOp->getNumResults()); for (auto [oldResult, newResult] : diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 693079c3aa2fa..ebcc69250ad56 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2718,18 +2718,37 @@ func.func @dim_out_of_bounds() -> vector<7xi32> { // ----- -// CHECK-LABEL: func.func @test_destination_multiple_result( +// CHECK-LABEL: func.func @fold_cast_multiple_results( // CHECK-SAME: %[[ARG1:.*]]: tensor<2x2xf32>, // CHECK-SAME: %[[ARG2:.*]]: tensor<2x2xf32>) -> index { // CHECK: %[[RES:.*]]:2 = test.destination_style_op ins(%[[ARG1]] : tensor<2x2xf32>) // CHECK-SAME: outs(%[[ARG2]] : tensor<2x2xf32>) -> tensor<2x2xf32>, index // CHECK: return %[[RES]]#1 : index -func.func @test_destination_multiple_result(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index { +func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> index { %cast = tensor.cast %arg0 : tensor<2x2xf32> to tensor %cast_0 = tensor.cast %arg1 : tensor<2x2xf32> to tensor %0:2 = test.destination_style_op ins(%cast : tensor) outs(%cast_0 : tensor) -> tensor, index return %0#1 : index } +// ----- + +// CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size +// CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>, +// CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> { +// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32> +// CHECK: return %[[PACK]] : tensor<1x1x8x1xi32> +func.func @fold_cast_pack_dynamic_tile_size( + %dest: tensor<1x1x8x1xi32>, + %src: tensor<7x?xi32>, + %pad: i32) -> tensor<1x1x8x1xi32> { + + %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32> + %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32> + return %res : tensor<1x1x8x1xi32> +} // ----- From b7b56b1f7b308f854424da0ee0b927401b5ae4d0 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 4 Nov 2024 19:01:39 +0000 Subject: [PATCH 2/3] fixup! [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) Split into two patterns. --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 187 ++++++++++++++--------- 1 file changed, 114 insertions(+), 73 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 406b557b0f0e3..2f0d7d441e19c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4698,6 +4698,114 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// // Common Canonicalizers and Folders. //===----------------------------------------------------------------------===// +bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { + // InsertSliceOp has its own logic about folding tensor.cast ops. + if (isa(op.getOperation())) + return false; + + // Exclude DPS ops that are also LoopLike from this interface as they + // might need special handling of attached regions. + if (isa(op.getOperation())) + return false; + + // If no operand comes from a tensor::CastOp and can be folded then fail. + bool hasTensorCastOperand = + llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { + if (llvm::isa(opOperand.get())) + return false; + auto castOp = opOperand.get().getDefiningOp(); + return castOp && canFoldIntoConsumerOp(castOp); + }); + + return hasTensorCastOperand; +} + +static SmallVector getNewOperands(DestinationStyleOpInterface op, + SmallVector &newResTy) { + SmallVector newOperands; + newOperands.reserve(op->getNumOperands()); + + // Assumes that the result has dpsInits followed by nonDpsInits. + int64_t dpsInitIdx = 0; + for (OpOperand &opOperand : op->getOpOperands()) { + auto tensorCastOp = opOperand.get().getDefiningOp(); + bool fold = canFoldIntoConsumerOp(tensorCastOp); + newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); + if (op.isDpsInit(&opOperand) && + !llvm::isa(newOperands.back().getType())) + newResTy[dpsInitIdx++] = newOperands.back().getType(); + } + return newOperands; +} + +/// Folds a tensor.cast op into a consuming tensor::PackOp op if the +/// `tensor.cast` has source that is more static than the consuming op. +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor +/// %2 = tensor.pack %1 ... : tensor ... +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = tensor.pack %0 ... : tensor<8x16xf32> ... +/// ``` +struct FoldTensorCastPackOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PackOp op, + PatternRewriter &rewriter) const override { + if (!foldTensorCastPrecondition(op)) + return failure(); + + SmallVector newResultTypes(op->getResultTypes()); + SmallVector newOperands = getNewOperands(op, newResultTypes); + + // Get the updated mixed-tile-sizes attribute. + SmallVector newMixedTileSizes; + for (auto it : llvm::zip(cast(newResultTypes[0]) + .getShape() + .take_back(op.getMixedTiles().size()), + op.getMixedTiles())) { + int64_t shape = std::get<0>(it); + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + if (Attribute attr = + llvm::dyn_cast_if_present(std::get<1>(it))) { + // Already a constant + newMixedTileSizes.push_back(std::get<1>(it)); + } else { + auto tileSize = getConstantIntValue(std::get<1>(it)); + assert(tileSize == shape && "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + + // Clone op. + PackOp newOp = rewriter.create( + op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), + newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); + + SmallVector replacements; + replacements.reserve(newOp->getNumResults()); + for (auto [oldResult, newResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + newResult.getType() != oldResult.getType() + ? replacements.push_back(rewriter.create( + op->getLoc(), oldResult.getType(), newResult)) + : replacements.push_back(newResult); + } + rewriter.replaceOp(op, replacements); + + return success(); + } +}; /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if /// the `tensor.cast` has source that is more static than the consuming op. @@ -4722,83 +4830,15 @@ struct FoldTensorCastProducerOp LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override { - // InsertSliceOp has its own logic about folding tensor.cast ops. - if (isa(op.getOperation())) - return failure(); - - // Exclude DPS ops that are also LoopLike from this interface as they - // might need special handling of attached regions. - if (isa(op.getOperation())) - return failure(); - // If no operand comes from a tensor::CastOp and can be folded then fail. - bool hasTensorCastOperand = - llvm::any_of(op->getOpOperands(), [&](OpOperand &opOperand) { - if (llvm::isa(opOperand.get())) - return false; - auto castOp = opOperand.get().getDefiningOp(); - return castOp && canFoldIntoConsumerOp(castOp); - }); - if (!hasTensorCastOperand) + if (!foldTensorCastPrecondition(op) || dyn_cast(*op)) return failure(); - SmallVector newResultTypes(op->getResultTypes()); - SmallVector newOperands; - newOperands.reserve(op->getNumOperands()); - // Assumes that the result has dpsInits followed by nonDpsInits. - int64_t dpsInitIdx = 0; - for (OpOperand &opOperand : op->getOpOperands()) { - auto tensorCastOp = opOperand.get().getDefiningOp(); - bool fold = canFoldIntoConsumerOp(tensorCastOp); - newOperands.push_back(fold ? tensorCastOp.getOperand() : opOperand.get()); - if (op.isDpsInit(&opOperand) && - !llvm::isa(newOperands.back().getType())) - newResultTypes[dpsInitIdx++] = newOperands.back().getType(); - } - - // For ops that have sizes-like attribute, update these accordingly. - // For now, only `tensor.pack` is supported. - // TODO: Generalize to make it work with other ops as well (e.g. - // `tensor.unpack`) - SmallVector newMixedTileSizes; - if (auto pack = dyn_cast_or_null(*op)) { - for (auto it : llvm::zip(cast(newResultTypes[0]) - .getShape() - .take_back(pack.getMixedTiles().size()), - pack.getMixedTiles())) { - - int64_t shape = std::get<0>(it); - if (shape == ShapedType::kDynamic) { - newMixedTileSizes.push_back(std::get<1>(it)); - continue; - } + SmallVector newResultTypes(op->getResultTypes()); + SmallVector newOperands = getNewOperands(op, newResultTypes); - if (Attribute attr = - llvm::dyn_cast_if_present(std::get<1>(it))) { - // Already a constant - newMixedTileSizes.push_back(std::get<1>(it)); - } else { - auto tileSize = getConstantIntValue(std::get<1>(it)); - assert(tileSize == shape && "tile size and dim size don't match!"); - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); - } - } - } - - // Clone op. For ops that have sizes-like attribute, make sure to udpate - // those as well. For now, only `tensor.pack` is supported. - // TODO: Generalize to make it work with other ops as well (e.g. - // `tensor.unpack`) - // Operation *newOp; - Operation *newOp; - if (auto pack = dyn_cast_or_null(*op)) { - newOp = rewriter.create( - pack.getLoc(), newOperands[0], newOperands[1], pack.getInnerDimsPos(), - newMixedTileSizes, pack.getPaddingValue(), pack.getOuterDimsPerm()); - } else { - newOp = clone(rewriter, op, newResultTypes, newOperands); - } + // Clone op + auto newOp = clone(rewriter, op, newResultTypes, newOperands); SmallVector replacements; replacements.reserve(newOp->getNumResults()); @@ -4823,6 +4863,7 @@ struct FoldTensorCastProducerOp void TensorDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { + results.add(getContext()); results.add(getContext()); } From 94316c7adedf626b13d6a028ad47f0d4373bf0a5 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 5 Nov 2024 17:59:00 +0000 Subject: [PATCH 3/3] fixup! fixup! [mlir][tensor] Improve `FoldTensorCastProducerOp` (dynamic shapes) Final tweaks --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 32 +++++++++++------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 2f0d7d441e19c..1847066b2d1e3 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4699,13 +4699,11 @@ OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) { // Common Canonicalizers and Folders. //===----------------------------------------------------------------------===// bool foldTensorCastPrecondition(DestinationStyleOpInterface op) { - // InsertSliceOp has its own logic about folding tensor.cast ops. - if (isa(op.getOperation())) - return false; - - // Exclude DPS ops that are also LoopLike from this interface as they + // 1. InsertSliceOp has its own logic about folding tensor.cast ops. + // 2. Exclude DPS ops that are also LoopLike from this interface as they // might need special handling of attached regions. - if (isa(op.getOperation())) + if (isa(op.getOperation()) || + isa(op.getOperation())) return false; // If no operand comes from a tensor::CastOp and can be folded then fail. @@ -4780,7 +4778,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern { // Already a constant newMixedTileSizes.push_back(std::get<1>(it)); } else { - auto tileSize = getConstantIntValue(std::get<1>(it)); + int64_t tileSize = getConstantIntValue(std::get<1>(it)).value(); assert(tileSize == shape && "tile size and dim size don't match!"); newMixedTileSizes.push_back( (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); @@ -4792,16 +4790,15 @@ struct FoldTensorCastPackOp : public OpRewritePattern { op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); - SmallVector replacements; - replacements.reserve(newOp->getNumResults()); - for (auto [oldResult, newResult] : - llvm::zip(op->getResults(), newOp->getResults())) { - newResult.getType() != oldResult.getType() - ? replacements.push_back(rewriter.create( - op->getLoc(), oldResult.getType(), newResult)) - : replacements.push_back(newResult); - } - rewriter.replaceOp(op, replacements); + // Replace op. + Value oldResult = op.getResult(); + Value newResult = newOp.getResult(); + Value replacement = (newResult.getType() != oldResult.getType()) + ? rewriter.create( + op->getLoc(), oldResult.getType(), newResult) + : newResult; + + rewriter.replaceOp(op, {replacement}); return success(); } @@ -4831,6 +4828,7 @@ struct FoldTensorCastProducerOp LogicalResult matchAndRewrite(DestinationStyleOpInterface op, PatternRewriter &rewriter) const override { + // Reject tensor::PackOp - there's dedicated pattern for that instead. if (!foldTensorCastPrecondition(op) || dyn_cast(*op)) return failure();