diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 6700b4e0c2cb6..8718c57b9e86c 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1017,9 +1017,22 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
                                sliceOp.getMixedSizes(), zeroSliceGuard);
   if (failed(tilingResult))
     return failure();
-  // All shapes are static and the data source is actually used. Rewrite into
-  // pad(extract_slice(x)).
-  rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+
+  RankedTensorType sourceType = sliceOp.getSourceType();
+  RankedTensorType resultType = sliceOp.getResultType();
+
+  // If the extract_slice is not rank-reduced, all shapes are static and the
+  // data source is actually used. Rewrite into pad(extract_slice(x)).
+  if (sourceType.getRank() == resultType.getRank()) {
+    rewriter.replaceOp(sliceOp, tilingResult->tiledValues);
+    return success();
+  }
+
+  // Handle rank-reduced slice by creating another extract_slice op.
+  Value rankReduced = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, sliceOp.getLoc(), tilingResult->tiledValues[0], resultType);
+
+  rewriter.replaceOp(sliceOp, rankReduced);
   return success();
 }
 
diff --git a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
index d43b9a7ac6c04..6a056bab98807 100644
--- a/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
+++ b/mlir/test/Dialect/Linalg/subtensor-of-padtensor.mlir
@@ -129,6 +129,26 @@ func.func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 
 // -----
 
+// CHECK-LABEL: @static_rank_reduce
+// CHECK-SAME:  %[[ARG0:.*]]: tensor<8x16x4xf32>, %[[PADVAL:.*]]: f32
+// CHECK:       %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, 0] [1, 14, 4] [1, 1, 1] : tensor<8x16x4xf32> to tensor<1x14x4xf32>
+// CHECK:       %[[PADDED:.*]] = tensor.pad %[[SLICE]] low[0, 2, 0] high[0, 0, 0] {
+// CHECK:       } : tensor<1x14x4xf32> to tensor<1x16x4xf32>
+// CHECK:       %[[RESULT:.*]] = tensor.extract_slice %[[PADDED]][0, 0, 0] [1, 16, 4] [1, 1, 1] : tensor<1x16x4xf32> to tensor<16x4xf32>
+// CHECK:       return %[[RESULT]]
+func.func @static_rank_reduce(%arg0: tensor<8x16x4xf32>, %pad: f32)
+    -> tensor<16x4xf32> {
+  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {
+    ^bb0(%i: index, %j: index, %k: index):
+      tensor.yield %pad : f32
+  } : tensor<8x16x4xf32> to tensor<8x18x4xf32>
+  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
+      : tensor<8x18x4xf32> to tensor<16x4xf32>
+  return %1 : tensor<16x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @dynamic_high_pad
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<?x5xf32>
 // CHECK-NOT:   tensor.pad
@@ -217,6 +237,27 @@ func.func @dynamic_zero_high_padding(%arg0 : tensor<?x?xf32>, %pad : f32,
   return %1 : tensor<?x?xf32>
 }
 
+// -----
+
+// CHECK-LABEL: @dynamic_rank_reduce
+// CHECK:       %[[TEMP:.*]] = scf.if %{{.*}} -> (tensor<1x4xf32>) {
+// CHECK:         tensor.generate
+// CHECK:       } else {
+// CHECK:         %[[SLICE:.*]] = tensor.extract_slice %{{.*}} : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK:         tensor.pad %[[SLICE]] low[0, 0] high[%{{.*}}, 3] {
+// CHECK:         } : tensor<?x?xf32> to tensor<1x4xf32>
+// CHECK:       }
+// CHECK:       %[[RESULT:.*]] = tensor.extract_slice %[[TEMP]]{{.*}} : tensor<1x4xf32> to tensor<4xf32>
+// CHECK:       return %[[RESULT]]
+func.func @dynamic_rank_reduce(%arg0 : tensor<?x?xf32>, %s1: index, %pad : f32) -> tensor<4xf32> {
+  %0 = tensor.pad %arg0 low[0, 0] high[7, 8] {
+    ^bb0(%arg1: index, %arg2: index):
+      tensor.yield %pad : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  %1 = tensor.extract_slice %0[2, 4] [1, 4] [1, 1] : tensor<?x?xf32> to tensor<4xf32>
+  return %1 : tensor<4xf32>
+}
+
 // -----
 // CHECK-LABEL: @nopaddim_with_dynamic_extract(
 // CHECK-SAME:  %[[ARG0:.*]]: tensor<3x4x5xf32>
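
Note on the rewrite above, using the @static_rank_reduce test as the worked
example (the IR is taken straight from its CHECK lines; the SSA names are
illustrative, not what the pattern emits):

  %0 = tensor.pad %arg0 low[0, 2, 0] high[0, 0, 0] {...}
      : tensor<8x16x4xf32> to tensor<8x18x4xf32>
  %1 = tensor.extract_slice %0[0, 0, 0] [1, 16, 4] [1, 1, 1]
      : tensor<8x18x4xf32> to tensor<16x4xf32>   // rank-reducing

is rewritten into

  %slice  = tensor.extract_slice %arg0[0, 0, 0] [1, 14, 4] [1, 1, 1]
      : tensor<8x16x4xf32> to tensor<1x14x4xf32>
  %padded = tensor.pad %slice low[0, 2, 0] high[0, 0, 0] {...}
      : tensor<1x14x4xf32> to tensor<1x16x4xf32>
  %result = tensor.extract_slice %padded[0, 0, 0] [1, 16, 4] [1, 1, 1]
      : tensor<1x16x4xf32> to tensor<16x4xf32>   // rank-reducing

i.e. the swapped pad/extract_slice pair stays at the source rank, and the unit
dimension is dropped afterwards by the canonical rank-reducing extract_slice
built by tensor::createCanonicalRankReducingExtractSliceOp.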