-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][Tensor] Use output_shape for ExpandShapeOp type inference #118202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesWe already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the shape inference directly use the available output shape. Full diff: https://github.com/llvm/llvm-project/pull/118202.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
index 7ff435a033985c..ebb88bf695d4c2 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
@@ -16,24 +16,6 @@
using namespace mlir;
using namespace mlir::tensor;
-/// Compute a map that for a given dimension of the expanded type gives the
-/// dimension in the collapsed type it maps to. Essentially its the inverse of
-/// the `reassocation` maps.
-static llvm::DenseMap<int64_t, int64_t>
-getExpandedDimToCollapsedDimMap(ArrayRef<AffineMap> reassociation) {
- llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim;
- for (const auto &map : enumerate(reassociation)) {
- unsigned startPos =
- cast<AffineDimExpr>(map.value().getResults().front()).getPosition();
- unsigned endPos =
- cast<AffineDimExpr>(map.value().getResults().back()).getPosition();
- for (auto dim : llvm::seq_inclusive(startPos, endPos)) {
- expandedDimToCollapsedDim[dim] = map.index();
- }
- }
- return expandedDimToCollapsedDim;
-}
-
/// For reshape op compute the shape at dimension `dimIndex` of the output in
/// terms of shape of the `src`, when the reshape op is a collapsing
/// operation. It is the product of the shape of the collapsed dimensions of the
@@ -76,86 +58,33 @@ static SmallVector<OpFoldResult, 4> getCollapsedOutputShapeFromInputShape(
}));
}
-/// For an expanding reshape op, compute the value for a dimension of the output
-/// from the shape of the input.
-static OpFoldResult getExpandedOutputDimFromInputShape(
- OpBuilder &builder, Location loc, int64_t dimIndex, Value src,
- ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation,
- llvm::DenseMap<int64_t, int64_t> &expandedDimToCollapsedDim) {
- if (!ShapedType::isDynamic(dstStaticShape[dimIndex])) {
- // Static dimension: return Attribute.
- return builder.getIndexAttr(dstStaticShape[dimIndex]);
- }
- unsigned sourceDimPos = expandedDimToCollapsedDim[dimIndex];
- unsigned startPos =
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().front())
- .getPosition();
- unsigned endPos =
- cast<AffineDimExpr>(reassociation[sourceDimPos].getResults().back())
- .getPosition();
- int64_t linearizedStaticDim = 1;
- for (auto d :
- llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
- if (d.index() + startPos == static_cast<unsigned>(dimIndex))
- continue;
- assert(!ShapedType::isDynamic(d.value()) &&
- "single dimension cannot be expanded into multiple dynamic "
- "dimensions");
- linearizedStaticDim *= d.value();
+struct ReifyCollapseShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<
+ ReifyCollapseShapeOp, CollapseShapeOp> {
+ LogicalResult
+ reifyResultShapes(Operation *op, OpBuilder &b,
+ ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
+ auto loc = op->getLoc();
+ auto collapseShape = cast<CollapseShapeOp>(op);
+ reifiedReturnShapes.push_back(getCollapsedOutputShapeFromInputShape(
+ b, loc, collapseShape.getSrc(),
+ collapseShape.getResultType().getShape(),
+ collapseShape.getReassociationMaps()));
+ return success();
}
- OpFoldResult sourceDim =
- builder.create<tensor::DimOp>(loc, src, sourceDimPos).getResult();
-
- // Dynamic dimension: return Value.
- return affine::makeComposedAffineApply(
- builder, loc,
- AffineMap::get(
- 0, 1,
- builder.getAffineSymbolExpr(0).floorDiv(linearizedStaticDim)),
- sourceDim)
- ->getResult(0);
-}
-
-/// Given the `src` of an expanding reshape op, the reassociation maps and the
-/// result type, compute the shape of the result of the reshape.
-static SmallVector<OpFoldResult, 4> getExpandedOutputShapeFromInputShape(
- OpBuilder &builder, Location loc, Value src,
- ArrayRef<int64_t> dstStaticShape, ArrayRef<AffineMap> reassociation) {
- llvm::DenseMap<int64_t, int64_t> expandedDimToCollapsedDim =
- getExpandedDimToCollapsedDimMap(reassociation);
- return llvm::to_vector<4>(llvm::map_range(
- llvm::seq<int64_t>(0, dstStaticShape.size()), [&](int64_t dim) {
- return getExpandedOutputDimFromInputShape(builder, loc, dim, src,
- dstStaticShape, reassociation,
- expandedDimToCollapsedDim);
- }));
-}
-
-static SmallVector<OpFoldResult, 4>
-getReshapeOutputShapeFromInputShape(OpBuilder &builder, Location loc, Value src,
- ArrayRef<int64_t> dstStaticShape,
- ArrayRef<AffineMap> reassocation) {
- return dstStaticShape.size() >
- static_cast<size_t>(
- llvm::cast<ShapedType>(src.getType()).getRank())
- ? getExpandedOutputShapeFromInputShape(
- builder, loc, src, dstStaticShape, reassocation)
- : getCollapsedOutputShapeFromInputShape(
- builder, loc, src, dstStaticShape, reassocation);
-}
+};
-template <typename OpTy>
-struct ReifyExpandOrCollapseShapeOp
- : public ReifyRankedShapedTypeOpInterface::ExternalModel<
- ReifyExpandOrCollapseShapeOp<OpTy>, OpTy> {
+struct ReifyExpandShapeOp
+ : public ReifyRankedShapedTypeOpInterface::ExternalModel<ReifyExpandShapeOp,
+ ExpandShapeOp> {
LogicalResult
reifyResultShapes(Operation *op, OpBuilder &b,
ReifiedRankedShapedTypeDims &reifiedReturnShapes) const {
auto loc = op->getLoc();
- auto reshapeOp = cast<OpTy>(op);
- reifiedReturnShapes.push_back(getReshapeOutputShapeFromInputShape(
- b, loc, reshapeOp.getSrc(), reshapeOp.getResultType().getShape(),
- reshapeOp.getReassociationMaps()));
+ auto expandShape = cast<ExpandShapeOp>(op);
+ SmallVector<OpFoldResult> outputShape = getMixedValues(
+ expandShape.getStaticOutputShape(), expandShape.getOutputShape(), b);
+ reifiedReturnShapes.push_back(outputShape);
return success();
}
};
@@ -202,10 +131,8 @@ struct ReifyPadOp
void mlir::tensor::registerInferTypeOpInterfaceExternalModels(
DialectRegistry ®istry) {
registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
- ExpandShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::ExpandShapeOp>>(*ctx);
- CollapseShapeOp::attachInterface<
- ReifyExpandOrCollapseShapeOp<tensor::CollapseShapeOp>>(*ctx);
+ ExpandShapeOp::attachInterface<ReifyExpandShapeOp>(*ctx);
+ CollapseShapeOp::attachInterface<ReifyCollapseShapeOp>(*ctx);
PadOp::attachInterface<ReifyPadOp>(*ctx);
});
}
diff --git a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
index 8fb84248c9613b..0595ac2492c97a 100644
--- a/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
+++ b/mlir/test/Dialect/Linalg/resolve-shaped-type-result-dims.mlir
@@ -210,15 +210,12 @@ func.func @dim_reshape_expansion(%arg0 : tensor<6x5x?xf32>, %sz0: index) -> (ind
%3 = tensor.dim %0, %c4 : tensor<2x3x5x4x?x7xf32>
return %1, %2, %3 : index, index, index
}
-// CHECK: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: func @dim_reshape_expansion
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor<6x5x?xf32>
-// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index
// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index
-// CHECK: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C2]]
-// CHECK: %[[D1:.+]] = affine.apply #[[MAP]]()[%[[D0]]]
-// CHECK: return %[[C3]], %[[C4]], %[[D1]]
+// CHECK: return %[[C3]], %[[C4]], %[[ARG1]]
// -----
diff --git a/mlir/test/Dialect/Tensor/fold-empty-op.mlir b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
index 65ceb4ff3e3df4..d3889f23e7d742 100644
--- a/mlir/test/Dialect/Tensor/fold-empty-op.mlir
+++ b/mlir/test/Dialect/Tensor/fold-empty-op.mlir
@@ -10,7 +10,6 @@ module attributes {transform.with_named_sequence} {
}
}
-// CHECK: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 28)>
// CHECK: #[[$MAP2:.+]] = affine_map<()[s0] -> (s0 * 28)>
func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4x?x7xf32> {
@@ -19,11 +18,8 @@ func.func @empty_reshape_expansion(%arg0 : index, %sz0: index) -> tensor<2x3x5x4
return %1 : tensor<2x3x5x4x?x7xf32>
}
// CHECK-LABEL: func @empty_reshape_expansion
-// CHECK-SAME: %[[ARG0:.+]]: index
-// CHECK: %[[OLD_INIT:.+]] = tensor.empty(%{{.*}}) : tensor<6x5x?xf32>
-// CHECK-NEXT: %[[DIM:.*]] = tensor.dim %[[OLD_INIT]]
-// CHECK-NEXT: %[[D:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
-// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[D]])
+// CHECK-SAME: %[[ARG0:.+]]: index, %[[ARG1:.+]]: index
+// CHECK-NEXT: %[[INIT:.+]] = tensor.empty(%[[ARG1]])
// CHECK-NEXT: return %[[INIT]]
func.func @empty_reshape_collapse(%arg0 : index) -> tensor<6x5x?xf32> {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice cleanup, LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have a few follow ups here. Please wait for me to get back to this.
Landed as #113501 |
We already have the output shape available in the operation, so there is no need to do any arithmetic to figure it out. This PR makes the shape inference directly use the available output shape.