Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 22 additions & 15 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"

namespace mlir {
namespace bufferization {
Expand Down Expand Up @@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
/// In the future, more general interfaces can be devised to encode similar
/// shape evolutions and map between an op and its operands.
SmallVector<OpFoldResult>
computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
const PadTilingInterfaceOptions &options);

using PadSizeComputationFunction =
std::function<FailureOr<SmallVector<OpFoldResult>>(
RewriterBase &, OpOperand &, ArrayRef<Range>,
OpBuilder &, OpOperand &, ArrayRef<Range>,
const PadTilingInterfaceOptions &)>;

/// Specific helper for Linalg ops.
FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
RewriterBase &rewriter, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
FailureOr<SmallVector<OpFoldResult>>
computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
ArrayRef<Range> iterationDomain,
const PadTilingInterfaceOptions &);

/// Operations and values created in the process of padding a TilingInterface
/// operation.
struct PadTilingInterfaceResult {
/// The operands of the padded op.
SmallVector<tensor::PadOp> padOps;
/// The padded op, a clone of `toPad` with padded operands.
TilingInterface paddedOp;
/// Slices of the padded op's results, same types as `toPad`.
SmallVector<Value> replacements;
};

/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
///
/// Pad the iterator dimensions of `toPad`.
/// * "options.paddingSizes" indicates that each padding dimension should be
/// padded to the specified padding size.
/// * "options.padToMultipleOf" indicates that the paddingSizes should be
// interpreted as the bounding box (dynamic) value to pad to.
/// * Use "options.paddingValues" to set the padding value of the created
// tensor::PadOp.
/// * The tensor::PadOp is returned on success.

FailureOr<TilingInterface>
rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
const PadTilingInterfaceOptions &constOptions,
SmallVector<tensor::PadOp> &padOps,
const PadSizeComputationFunction &computePaddingSizeFun =
FailureOr<PadTilingInterfaceResult>
rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
PadTilingInterfaceOptions options,
const PadSizeComputationFunction & =
&computeIndexingMapOpInterfacePaddedShape);

namespace detail {
Expand Down
16 changes: 7 additions & 9 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2457,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
}

// Set options.
TilingInterface paddedOp;
PadTilingInterfaceOptions options;
options.setPaddingValues(paddingValues)
.setPaddingSizes(getMixedPaddingSizes())
.setPadToMultipleOf(getPadToMultipleOf());

// Apply padding.
SmallVector<tensor::PadOp> newPadOps;
FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
newPadOps);
if (failed(maybePaddedOp)) {
auto maybePadOps = rewriteAsPaddedOp(
rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
if (failed(maybePadOps)) {
auto diag = emitSilenceableError() << "failed to pad op";
diag.attachNote(target->getLoc()) << "target op";
return diag;
}
const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();

// Set transform results.
paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
padOps.append(newPadOps.begin(), newPadOps.end());
paddedOps.push_back(paddedOp);
padOps.append(paddedOperands.begin(), paddedOperands.end());
rewriter.replaceOp(targetOp.getOperation(), slicedResults);
}

results.set(cast<OpResult>(getPadded()), paddedOps);
Expand Down
Loading