[MLIR][Linalg] Modify rewriteAsPaddedOp to not remove pre-padded op #163467
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: James Newling (newling)

Changes

Refactor/redesign. I previously found it difficult to work with this API (in IREE), as the original (pre-padded) operation was still useful for a while after its replacement was created. I believe @Groverkss also has a use case where he wants the pre-padded value to stick around.

Full diff: https://github.com/llvm/llvm-project/pull/163467.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..db75379cc21c0 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
namespace mlir {
namespace bufferization {
@@ -621,35 +620,39 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);
using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
- RewriterBase &, OpOperand &, ArrayRef<Range>,
+ OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;
/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
- const PadSizeComputationFunction &computePaddingSizeFun =
+struct PadTilingInterfaceResult {
+ /// The operands of the padded op.
+ SmallVector<tensor::PadOp> padOps;
+ /// The padded op, a clone of `toPad` with padded operands.
+ TilingInterface paddedOp;
+ /// Slices of the padded op's results, same types as `toPad`.
+ SmallVector<Value> replacements;
+};
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
+ const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);
namespace detail {
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index d8f983f98ae77..8ac882024e58d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,27 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}
// Set options.
- TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());
- // Apply padding.
- SmallVector<tensor::PadOp> newPadOps;
- FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
- rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
- newPadOps);
- if (failed(maybePaddedOp)) {
+ auto maybePadOps = rewriteAsPaddedOp(
+ rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+ if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
+ const auto &[paddedOperands, paddedOp, slicedResults] = *maybePadOps;
+
// Set transform results.
- paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
- padOps.append(newPadOps.begin(), newPadOps.end());
+ paddedOps.push_back(paddedOp);
+ padOps.append(paddedOperands.begin(), paddedOperands.end());
+
+ // erase targetOp:
+ rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}
results.set(cast<OpResult>(getPadded()), paddedOps);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..513ce2c52ec87 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
/// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
- RewriterBase &rewriter, TypedValue<RankedTensorType> v,
- AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
- const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &rewriter, TypedValue<RankedTensorType> v,
+ AffineMap indexingMap,
+ ArrayRef<OpFoldResult> indexingSizes,
+ const PadTilingInterfaceOptions &options) {
Location loc = v.getLoc();
SmallVector<OpFoldResult> paddedShape;
auto tensorType = cast<RankedTensorType>(v.getType());
@@ -198,7 +199,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
FailureOr<SmallVector<OpFoldResult>>
linalg::computeIndexingMapOpInterfacePaddedShape(
- RewriterBase &rewriter, OpOperand &operandToPad,
+ OpBuilder &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
auto transferOp =
llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -224,7 +225,7 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
/// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
/// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &rewriter, TilingInterface opToPad,
TypedValue<RankedTensorType> v,
ArrayRef<OpFoldResult> paddedShape,
Attribute paddingValueAttr) {
@@ -263,45 +264,44 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
paddingValue, /*nofold=*/false, dynDims);
}
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
- RewriterBase &rewriter, TilingInterface opToPad,
- const PadTilingInterfaceOptions &constOptions,
- SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+ OpBuilder &builder, TilingInterface toPad,
+ PadTilingInterfaceOptions options,
const PadSizeComputationFunction &computePaddingSizeFun) {
- LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+ LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+ SmallVector<tensor::PadOp> padOps;
+ Location loc = toPad.getLoc();
- Location loc = opToPad.getLoc();
- PadTilingInterfaceOptions options(constOptions);
// Allow inference of pad values if they are not explicitly specified.
// TODO: be mindful about the value depending on the actual operation.
if (options.paddingValues.empty()) {
- SmallVector<Type> types(opToPad->getOperandTypes());
- llvm::append_range(types, opToPad->getResultTypes());
+ SmallVector<Type> types(toPad->getOperandTypes());
+ llvm::append_range(types, toPad->getResultTypes());
for (Type t : types) {
options.paddingValues.push_back(
- rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+ builder.getZeroAttr(getElementTypeOrSelf(t)));
}
}
- if (llvm::any_of(opToPad->getOperands(),
+ if (llvm::any_of(toPad->getOperands(),
[](Value v) { return isa<MemRefType>(v.getType()); })) {
- return rewriter.notifyMatchFailure(opToPad,
- "expected operation on tensors");
+ LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+ return failure();
}
- OpBuilder::InsertionGuard g(rewriter);
- // Set IP after opToPad because we also take the dims of opToPad's output.
- rewriter.setInsertionPointAfter(opToPad);
+ OpBuilder::InsertionGuard g(builder);
+ // Set IP after toPad because we also take the dims of toPad's output.
+ builder.setInsertionPointAfter(toPad);
// 1. Get the loopUpperBounds from the TilingInterface.
- SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+ SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
// 2. For each operand.
SmallVector<Value> newOperands;
- newOperands.reserve(opToPad->getNumOperands());
- for (OpOperand &opOperand : opToPad->getOpOperands()) {
+ newOperands.reserve(toPad->getNumOperands());
+ for (OpOperand &opOperand : toPad->getOpOperands()) {
Value operand = opOperand.get();
- LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+ LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
// 2.a. Skip scalar-like operands.
Type operandType = operand.getType();
@@ -311,30 +311,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
newOperands.push_back(operand);
continue;
}
+
// 2.a. Compute padded shape.
FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
- computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+ computePaddingSizeFun(builder, opOperand, iterationDomain, options);
if (failed(maybePaddedShape)) {
- return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+ LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+ return failure();
}
// 2.b. Expect proper `paddingValues`.
// TODO: we may want to allow garbage padding in the future, in which case
// we would just not assert.
if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
- return rewriter.notifyMatchFailure(opToPad,
- "--no padding value specified");
+ LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+ return failure();
}
Attribute paddingValueAttr =
options.paddingValues[opOperand.getOperandNumber()];
// 2.c. Perform actual padding.
- Value paddedOperand = padOperand(
- rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
- *maybePaddedShape, paddingValueAttr);
+ Value paddedOperand =
+ padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+ *maybePaddedShape, paddingValueAttr);
LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
- // 2.d. Perform actual padding.
newOperands.push_back(paddedOperand);
if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
padOps.push_back(padOp);
@@ -342,38 +343,34 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
// 3. Form the resulting tensor::ExtractSliceOp.
ReifiedRankedShapedTypeDims reifiedResultShapes;
- if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
- LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
- return rewriter.notifyMatchFailure(opToPad,
- "failed to reify result shapes");
+ if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+ LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+ return failure();
}
- assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+ assert(reifiedResultShapes.size() == toPad->getNumResults() &&
"expected same number of results");
- // Clone `opToPad` to operate on the statically padded shapes.
+ // Clone `toPad` to operate on the statically padded shapes.
auto resultTensorTypes =
- ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
- // clone **should** properly notify the rewriter.
+ ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+ // clone **should** properly notify the builder.
TilingInterface paddedOp =
- clone(rewriter, opToPad, resultTensorTypes, newOperands);
+ clone(builder, toPad, resultTensorTypes, newOperands);
LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
- // Recover the slice out of the new static results. This keeps the original
- // opToPad around because it uses the dims of the original results.
+ // Recover the slice out of the new static results.
SmallVector<Value> paddedSubtensorResults;
- paddedSubtensorResults.reserve(opToPad->getNumResults());
+ paddedSubtensorResults.reserve(toPad->getNumResults());
for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
Value paddedResult = en.value();
int64_t resultNumber = en.index();
int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
- SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
- SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+ SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+ SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
- rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+ builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
strides));
}
- rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
- return paddedOp;
+ return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
}
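To illustrate the new contract, here is a minimal sketch (not part of the patch) of how a caller might drive `rewriteAsPaddedOp` after this change. The helper name `padThenReplace` is hypothetical; the option and struct field names follow the header changes above. The key point is that `toPad` survives the call, so the caller chooses if and when to replace it:

```cpp
// Sketch of a hypothetical caller using the post-change API: rewriteAsPaddedOp
// no longer erases `toPad`, so the replacement decision belongs to the caller.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

LogicalResult padThenReplace(RewriterBase &rewriter, TilingInterface toPad,
                             linalg::PadTilingInterfaceOptions options) {
  FailureOr<linalg::PadTilingInterfaceResult> maybePadded =
      linalg::rewriteAsPaddedOp(rewriter, toPad, options);
  if (failed(maybePadded))
    return failure();

  // `toPad` is still live in the IR here, so it can keep being queried
  // (e.g. for the dims of its original results) before any replacement.
  auto &[padOps, paddedOp, replacements] = *maybePadded;
  (void)padOps;
  (void)paddedOp;

  // Only now does the caller erase the pre-padded op, mirroring how tiling
  // transformations leave replacement to the user.
  rewriter.replaceOp(toPad, replacements);
  return success();
}
```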
Refactor/redesign FailureOr<TilingInterface> rewriteAsPaddedOp(...) to not remove the unpadded operation. This is more in line with how other transformations, such as tiling, work, where the user of the transformation decides when to replace the original operation. Instead of performing the replacement, all of the resulting info (pad ops, padded op, and replacement values) is returned in a struct.

Signed-off-by: James Newling <[email protected]>
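A side effect of relaxing `RewriterBase &` to `OpBuilder &` in `PadSizeComputationFunction` is that custom pad-size callbacks no longer need a rewriter. Below is a minimal sketch of such a callback, under illustrative assumptions (static shapes only, padding every dimension to the next multiple of 16; the helper name and the constant are not from the patch):

```cpp
// Illustrative PadSizeComputationFunction: pad each dimension of the operand's
// tensor type up to the next multiple of 16. Only an OpBuilder is required
// after this change.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "llvm/Support/MathExtras.h"

using namespace mlir;

static FailureOr<SmallVector<OpFoldResult>>
padToMultipleOf16(OpBuilder &builder, OpOperand &operandToPad,
                  ArrayRef<Range> iterationDomain,
                  const linalg::PadTilingInterfaceOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(operandToPad.get().getType());
  if (!tensorType)
    return failure();
  SmallVector<OpFoldResult> paddedShape;
  for (int64_t dim : tensorType.getShape()) {
    // Keep the sketch simple: bail out on dynamic dimensions.
    if (ShapedType::isDynamic(dim))
      return failure();
    paddedShape.push_back(builder.getIndexAttr(llvm::alignTo(dim, 16)));
  }
  return paddedShape;
}
```

Such a callback could then be passed as the final argument of `rewriteAsPaddedOp`, in place of the default `computeIndexingMapOpInterfacePaddedShape`.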