Commit cb393f4

[MLIR][Shape] Canonicalize casted extent tensor operands
Both `shape.broadcast` and `shape.cstr_broadcastable` accept dynamic and static extent tensors. If their operands are casted, we can use the original value instead.

Differential Revision: https://reviews.llvm.org/D101376
1 parent: d5c2492
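In IR terms, the new pattern replaces an extent tensor operand that is produced by an information-losing tensor.cast with the cast's source operand. A minimal before/after sketch (hypothetical IR, mirroring the new test case at the end of this commit):

    // Before: the operand is cast to a dynamic extent tensor.
    %0 = tensor.cast %a : tensor<3xindex> to tensor<?xindex>
    %w = shape.cstr_broadcastable %0, %b : tensor<?xindex>, tensor<3xindex>

    // After canonicalization: the cast is bypassed and the static type is kept.
    %w = shape.cstr_broadcastable %a, %b : tensor<3xindex>, tensor<3xindex>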

2 files changed: +67 −15 lines
mlir/lib/Dialect/Shape/IR/Shape.cpp

Lines changed: 37 additions & 3 deletions
@@ -628,12 +628,45 @@ struct BroadcastFoldConstantOperandsPattern
     return success();
   }
 };
+
+template <typename OpTy>
+struct CanonicalizeCastExtentTensorOperandsPattern
+    : public OpRewritePattern<OpTy> {
+  using OpRewritePattern<OpTy>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OpTy op,
+                                PatternRewriter &rewriter) const override {
+    // Canonicalize operands.
+    bool anyChange = false;
+    auto canonicalizeOperand = [&](Value operand) {
+      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
+        // Only eliminate the cast if it holds no shape information.
+        bool isInformationLoosingCast =
+            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
+        if (isInformationLoosingCast) {
+          anyChange = true;
+          return castOp.source();
+        }
+      }
+      return operand;
+    };
+    auto newOperands = llvm::to_vector<8>(
+        llvm::map_range(op.getOperands(), canonicalizeOperand));
+
+    // Rewrite op if any change required.
+    if (!anyChange)
+      return failure();
+    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
+    return success();
+  }
+};
 } // namespace
 
 void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
   patterns.add<BroadcastFoldConstantOperandsPattern,
                BroadcastForwardSingleOperandPattern,
+               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
                RemoveDuplicateOperandsPattern<BroadcastOp>,
                RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
 }
@@ -716,7 +749,8 @@ void CstrBroadcastableOp::getCanonicalizationPatterns(
   // Canonicalization patterns have overlap with the considerations during
   // folding in case additional shape information is inferred at some point that
   // does not result in folding.
-  patterns.add<CstrBroadcastableEqOps,
+  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
+               CstrBroadcastableEqOps,
                RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
                RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
 }
@@ -1188,7 +1222,7 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
 // ```
 // %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
 // ```
-struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
   using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(tensor::CastOp op,
@@ -1214,7 +1248,7 @@ struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//
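Note the guard in CanonicalizeCastExtentTensorOperandsPattern above: a cast is only bypassed when its result type has a dynamic extent (isDynamicDim(0)), i.e. when it discards static shape information. A cast in the refining direction carries information and is left in place, as in this hypothetical example:

    // Not rewritten: the cast refines tensor<?xindex> to tensor<3xindex>,
    // so bypassing it would drop the static extent.
    %0 = tensor.cast %arg : tensor<?xindex> to tensor<3xindex>
    %w = shape.cstr_broadcastable %0, %b : tensor<3xindex>, tensor<3xindex>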

mlir/test/Dialect/Shape/canonicalize.mlir

Lines changed: 30 additions & 12 deletions
@@ -1115,8 +1115,8 @@ func @fold_div_mixed() -> !shape.size {
 // CHECK-LABEL: @fold_index_cast_on_index
 func @fold_index_cast_on_index(%arg: index) -> index {
   // CHECK-NOT: size_to_index
-  %casted = shape.size_to_index %arg : index
-  return %casted : index
+  %0 = shape.size_to_index %arg : index
+  return %0 : index
 }
 
 // -----
@@ -1125,8 +1125,8 @@ func @fold_index_cast_on_index(%arg: index) -> index {
 // CHECK-LABEL: @fold_to_extent_tensor_on_tensor
 func @fold_to_extent_tensor_on_tensor(%arg: tensor<?xindex>) -> tensor<?xindex> {
   // CHECK-NOT: to_extent_tensor
-  %casted = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
-  return %casted : tensor<?xindex>
+  %0 = shape.to_extent_tensor %arg : tensor<?xindex> -> tensor<?xindex>
+  return %0 : tensor<?xindex>
 }
 
 // -----
@@ -1264,9 +1264,9 @@ func @broadcast_as_from_extent_tensor(%a : tensor<?xindex>) -> !shape.shape {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
   // CHECK: return %[[RESULT]] : tensor<?xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
@@ -1276,9 +1276,9 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
+// CHECK-LABEL: @cast_extent_tensor
 // CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
-func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+func @cast_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
   // CHECK: return %[[RESULT]] : tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
@@ -1288,8 +1288,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1298,8 +1298,8 @@ func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
 
 // -----
 
-// CHECK-LABEL: @casted_extent_tensor
-func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+// CHECK-LABEL: @cast_extent_tensor
+func @cast_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
   // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
   %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
@@ -1335,3 +1335,21 @@ func @cstr_broadcastable_folding(%arg : tensor<?x4xf32>) {
   %2 = shape.cstr_broadcastable %0, %1: tensor<2xindex>, tensor<1xindex>
   "use"(%2) : (!shape.witness) -> ()
 }
+
+// -----
+
+// CHECK-LABEL: @cast_extent_tensor_operands
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<?xindex>, %[[ARG1:.*]]: tensor<3xindex>)
+func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
+    %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
+  // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
+  // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
+  // CHECK: return %[[WIT]], %[[RES]]
+  %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
+  %1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
+  %2 = shape.cstr_broadcastable %0, %1 : tensor<3xindex>, tensor<?xindex>
+  %3 = shape.broadcast %0, %1 : tensor<3xindex>, tensor<?xindex>
+      -> tensor<?xindex>
+  return %2, %3 : !shape.witness, tensor<?xindex>
+}
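For reference, the canonical form that the CHECK lines above expect looks roughly as follows (a sketch reconstructed from the test expectations, not verbatim canonicalizer output):

    func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
        %arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
      %0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
      %1 = shape.cstr_broadcastable %0, %arg1 : tensor<3xindex>, tensor<3xindex>
      %2 = shape.broadcast %0, %arg1 : tensor<3xindex>, tensor<3xindex>
          -> tensor<?xindex>
      return %1, %2 : !shape.witness, tensor<?xindex>
    }

The information-losing cast of %arg1 is bypassed in both ops, while the refining cast of %arg0 is kept.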
