diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e9..24a1d55315319 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
   return newOperands;
 }
 
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+//  * a dim from newPackedTy is static, and
+//  * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+                     SmallVector<OpFoldResult> mixedTiles) {
+  SmallVector<OpFoldResult> newMixedTileSizes;
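+  // Note: the inner tile sizes always map 1:1 to the trailing dims of the
+  // packed type, hence the `take_back` when walking the packed shape below.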
+  for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+                               .getShape()
+                               .take_back(mixedTiles.size()),
+                           mixedTiles)) {
+    int64_t shape = std::get<0>(it);
+    if (shape == ShapedType::kDynamic) {
+      newMixedTileSizes.push_back(std::get<1>(it));
+      continue;
+    }
+
+    // If the current result dim is static, update the dynamic mixed-size
+    // (provided the original value is dynamic).
+    OpFoldResult tile = std::get<1>(it);
+    if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+      // Already a constant
+      newMixedTileSizes.push_back(tile);
+    } else {
+      assert(getConstantIntValue(tile).value() == shape &&
+             "tile size and dim size don't match!");
+      newMixedTileSizes.push_back(
+          (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+    }
+  }
+
+  return newMixedTileSizes;
+}
+
 /// Folds a tensor.cast op into a consuming tensor::PackOp op if the
 /// `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
 
     // Get the updated mixed-tile-sizes attribute.
-    SmallVector<OpFoldResult> newMixedTileSizes;
-    for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
-                                 .getShape()
-                                 .take_back(op.getMixedTiles().size()),
-                             op.getMixedTiles())) {
-      int64_t shape = std::get<0>(it);
-      if (shape == ShapedType::kDynamic) {
-        newMixedTileSizes.push_back(std::get<1>(it));
-        continue;
-      }
-
-      if (Attribute attr =
-              llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
-        // Already a constant
-        newMixedTileSizes.push_back(std::get<1>(it));
-      } else {
-        int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
-        assert(tileSize == shape && "tile size and dim size don't match!");
-        (void)tileSize;
-        newMixedTileSizes.push_back(
-            (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
-      }
-    }
+    SmallVector<OpFoldResult> newMixedTileSizes =
+        getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
 
     // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
     PackOp newOp = rewriter.create<PackOp>(
         op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
         newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
   }
 };
 
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+///   %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+///   %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+///   %2 = tensor.unpack %0 ... : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+  using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(UnPackOp op,
+                                PatternRewriter &rewriter) const override {
+    if (!foldTensorCastPrecondition(op))
+      return failure();
+
+    SmallVector<Type> newResultTypes(op->getResultTypes());
+    SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+    Value sourceTensor = newOperands[0];
+
+    // Get the updated mixed-tile-sizes attribute.
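+    // Note: unlike tensor.pack, where the packed (tiled) type is the result,
+    // for tensor.unpack it is the source, so the updated tile sizes are
+    // derived from the new source type rather than from newResultTypes[0].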
+    SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+        rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+    // Clone op.
+    // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+    // this point. However, in practice, we use them for things that we'd like
+    // to preserve. Implement a better abstraction.
+    UnPackOp newOp = rewriter.create<UnPackOp>(
+        op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+        newMixedTileSizes, op.getOuterDimsPerm());
+    newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+    // Replace op.
+    Value oldResult = op.getResult();
+    Value newResult = newOp.getResult();
+    Value replacement = (newResult.getType() != oldResult.getType())
+                            ? rewriter.create<tensor::CastOp>(
+                                  op->getLoc(), oldResult.getType(), newResult)
+                            : newResult;
+
+    rewriter.replaceOp(op, {replacement});
+
+    return success();
+  }
+};
+
 /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
 /// the `tensor.cast` has source that is more static than the consuming op.
 ///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
                                 PatternRewriter &rewriter) const override {
 
     // Reject tensor::PackOp - there's dedicated pattern for that instead.
-    if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+    if (!foldTensorCastPrecondition(op) ||
+        isa<tensor::PackOp, tensor::UnPackOp>(*op))
       return failure();
 
     SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
 void TensorDialect::getCanonicalizationPatterns(
     RewritePatternSet &results) const {
   results.add<FoldTensorCastPackOp>(getContext());
+  results.add<FoldTensorCastUnPackOp>(getContext());
   results.add<FoldTensorCastProducerOp>(getContext());
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
index e8fc4ce834e18..01d14871072cd 100644
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -2786,6 +2786,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x
   %0:2 = test.destination_style_op ins(%cast : tensor<?x2xf32>) outs(%cast_0 : tensor<?x2xf32>) -> tensor<?x2xf32>, index
   return %0#1 : index
 }
+
 // -----
 
 // CHECK-LABEL: func.func @fold_cast_pack_dynamic_tile_size
@@ -2794,7 +2795,7 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2x
 // CHECK-SAME:    %[[DEST:.*]]: tensor<1x1x8x1xi32>,
 // CHECK-SAME:    %[[SRC:.*]]: tensor<7x?xi32>,
 // CHECK-SAME:    %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> {
 // CHECK:         %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32)
 // CHECK-SAME:      inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]]
-// CHECK-SAME:      some_attr
+// CHECK-SAME:      test_attr
 // CHECK-SAME:      : tensor<7x?xi32> -> tensor<1x1x8x1xi32>
 // CHECK:         return %[[PACK]] : tensor<1x1x8x1xi32>
 func.func @fold_cast_pack_dynamic_tile_size(
@@ -2807,13 +2808,33 @@ func.func @fold_cast_pack_dynamic_tile_size(
     %dest: tensor<1x1x8x1xi32>,
     %src: tensor<7x?xi32>,
     %pad: i32) -> tensor<1x1x8x1xi32> {
 
   %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
   %c8 = arith.constant 8 : index
   %pack = tensor.pack %src padding_value(%pad : i32)
     inner_dims_pos = [0, 1]
     inner_tiles = [%c8, 1]
-    into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
+    into %cast {test_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32>
   %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32>
   return %res : tensor<1x1x8x1xi32>
 }
 
 // -----
 
+// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size(
+// CHECK-SAME:      %[[SRC:.*]]: tensor<1x1x8x1xi32>,
+// CHECK-SAME:      %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> {
+// CHECK:           %[[RES:.*]] = tensor.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+// CHECK:           return %[[RES]] : tensor<7x?xi32>
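+// The dynamic inner tile size %c8 matches the static tile dim (8) of the cast
+// source, so the tensor.cast below can be folded into tensor.unpack.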
+func.func @fold_cast_unpack_dynamic_tile_size(
+    %src: tensor<1x1x8x1xi32>,
+    %res: tensor<7x?xi32>) -> tensor<7x?xi32> {
+
+  %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+  %c8 = arith.constant 8 : index
+  %unpack = tensor.unpack %cast
+    inner_dims_pos = [0, 1]
+    inner_tiles = [%c8, 1]
+    into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+  return %unpack : tensor<7x?xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @pack_dont_drop_attributes(
 // CHECK: tensor.pack {{.*}} {test_attr}
 func.func @pack_dont_drop_attributes(%arg0: tensor<?x?x?xf16>, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> {
diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir
index 83cb4b9d4ab24..1de3e281bc462 100644
--- a/mlir/test/Dialect/Tensor/invalid.mlir
+++ b/mlir/test/Dialect/Tensor/invalid.mlir
@@ -699,7 +699,7 @@ func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor
 
 // -----
 
-func.func @pack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
+func.func @unpack_invalid_output_rank(%input: tensor<256x128xf32>, %output: tensor<64x32x16xf32>) -> tensor<256x128xf32> {
   // expected-error@+1 {{packed rank != (unpacked rank + num tiling factors), got 3 != 4}}
   %0 = tensor.unpack %output inner_dims_pos = [0, 1] inner_tiles = [32, 16] into %input : tensor<64x32x16xf32> -> tensor<256x128xf32>
   return %0 : tensor<256x128xf32>