From 219ec7ff3ed43d5c1696827a06395172f0e9918a Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 7 Nov 2024 17:17:15 -0500 Subject: [PATCH] [mlir][Tensor] Retain discardable attrs back in cast(pack) folder --- mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 1 + mlir/test/Dialect/Tensor/canonicalize.mlir | 10 ++++++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 8e0d010439746..147120e0e3420 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4798,6 +4798,7 @@ struct FoldTensorCastPackOp : public OpRewritePattern { PackOp newOp = rewriter.create( op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); // Replace op. Value oldResult = op.getResult(); diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir index 2c826d7ae008d..0b54c207dea84 100644 --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -2750,7 +2750,10 @@ func.func @fold_cast_multiple_results(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2x // CHECK-SAME: %[[DEST:.*]]: tensor<1x1x8x1xi32>, // CHECK-SAME: %[[SRC:.*]]: tensor<7x?xi32>, // CHECK-SAME: %[[PAD:.*]]: i32) -> tensor<1x1x8x1xi32> { -// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] : tensor<7x?xi32> -> tensor<1x1x8x1xi32> +// CHECK: %[[PACK:.*]] = tensor.pack %[[SRC]] padding_value(%[[PAD]] : i32) +// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] +// CHECK-SAME: some_attr +// CHECK-SAME: : tensor<7x?xi32> -> tensor<1x1x8x1xi32> // CHECK: return %[[PACK]] : tensor<1x1x8x1xi32> func.func @fold_cast_pack_dynamic_tile_size( %dest: tensor<1x1x8x1xi32>, @@ -2759,7 +2762,10 @@ func.func @fold_cast_pack_dynamic_tile_size( %cast = tensor.cast %dest : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> %c8 = arith.constant 8 : index - %pack = tensor.pack %src padding_value(%pad : i32) inner_dims_pos = [0, 1] inner_tiles = [%c8, 1] into %cast : tensor<7x?xi32> -> tensor<1x1x?x1xi32> + %pack = tensor.pack %src padding_value(%pad : i32) + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %cast {some_attr} : tensor<7x?xi32> -> tensor<1x1x?x1xi32> %res = tensor.cast %pack : tensor<1x1x?x1xi32> to tensor<1x1x8x1xi32> return %res : tensor<1x1x8x1xi32> }