diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index dec678de6d1c2..f35a9cd4cb927 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -67,6 +67,23 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
   }
+
+  LogicalResult getIterationDomainTileFromResultTile(
+      Operation *op, OpBuilder &b, unsigned resultNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    iterDomainOffsets.assign(offsets.begin(), offsets.end());
+    iterDomainSizes.assign(sizes.begin(), sizes.end());
+    return success();
+  }
+
+  FailureOr<TilingResult>
+  generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
+                          ArrayRef<OpFoldResult> offsets,
+                          ArrayRef<OpFoldResult> sizes) const {
+    return getTiledImplementation(op, b, offsets, sizes);
+  }
 };
 
 template <typename OpTy>
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
index e02ab06a9d533..193fbe93e0f9e 100644
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -116,6 +116,47 @@ module attributes {transform.with_named_sequence} {
 
 // -----
 
+// CHECK-LABEL: func @fuse_static_pad_tensor_3_4(
+// CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32>
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
+// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
+// CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]]
+// CHECK:   scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] =
+// CHECK:     %[[SWAP_RESULT:.*]] = scf.if
+// CHECK:       tensor.generate
+// CHECK:     else
+// CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[IN]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// CHECK:       %[[PAD:.*]] = tensor.pad %[[SLICE]]
+// CHECK:     %[[COPY:.*]] = linalg.copy ins(%[[SWAP_RESULT:.*]]
+// CHECK:     tensor.insert_slice %[[COPY]] into %[[INNER_OUT]][{{.*}}, {{.*}}] [{{.*}}, {{.*}}] [1, 1]
+// CHECK: return %[[RESULT]]
+
+func.func @fuse_static_pad_tensor_3_4(%input_tensor: tensor<7x9xf32>,
+                                      %pad_value: f32) -> tensor<15x16xf32> {
+  %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad_value : f32
+    } : tensor<7x9xf32> to tensor<15x16xf32>
+  %empty = tensor.empty() : tensor<15x16xf32>
+  %1 = linalg.copy ins(%0 : tensor<15x16xf32>) outs(%empty : tensor<15x16xf32>) -> tensor<15x16xf32>
+  return %1 : tensor<15x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) {
+    %copy = transform.structured.match ops{["linalg.copy"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    %a, %b, %c = transform.structured.fuse %copy [2, 3]
+      : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
 // CHECK-LABEL: func @static_pad_tensor_0_3(
 // CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32>
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index