diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index ae7a085a1f7a8..c89fc59c91830 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -25,7 +25,6 @@
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/ADT/SmallBitVector.h"
-#include "llvm/ADT/SmallSet.h"
 
 namespace mlir {
 namespace bufferization {
@@ -621,35 +620,43 @@ LogicalResult rewriteAsPaddedOp(RewriterBase &rewriter, LinalgOp opToPad,
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
 SmallVector<OpFoldResult>
-computePaddedShape(RewriterBase &rewriter, TypedValue<RankedTensorType> v,
+computePaddedShape(OpBuilder &, TypedValue<RankedTensorType> v,
                    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
                    const PadTilingInterfaceOptions &options);
 
 using PadSizeComputationFunction =
     std::function<FailureOr<SmallVector<OpFoldResult>>(
-        RewriterBase &, OpOperand &, ArrayRef<Range>,
+        OpBuilder &, OpOperand &, ArrayRef<Range>,
         const PadTilingInterfaceOptions &)>;
 
 /// Specific helper for Linalg ops.
-FailureOr<SmallVector<OpFoldResult>> computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
-    ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options);
+FailureOr<SmallVector<OpFoldResult>>
+computeIndexingMapOpInterfacePaddedShape(OpBuilder &, OpOperand &operandToPad,
+                                         ArrayRef<Range> iterationDomain,
+                                         const PadTilingInterfaceOptions &);
+
+/// Operations and values created in the process of padding a TilingInterface
+/// operation.
+struct PadTilingInterfaceResult {
+  /// The operands of the padded op.
+  SmallVector<tensor::PadOp> padOps;
+  /// The padded op, a clone of `toPad` with padded operands.
+  TilingInterface paddedOp;
+  /// Slices of the padded op's results, same types as `toPad`.
+  SmallVector<Value> replacements;
+};
 
-/// Pad the iterator dimensions `options.paddingDimensions` of `opToPad`.
-///
+/// Pad the iterator dimensions of `toPad`.
 /// * "options.paddingSizes" indicates that each padding dimension should be
 ///   padded to the specified padding size.
 /// * "options.padToMultipleOf" indicates that the paddingSizes should be
 //    interpreted as the bounding box (dynamic) value to pad to.
 /// * Use "options.paddingValues" to set the padding value of the created
 //    tensor::PadOp.
-/// * The tensor::PadOp is returned on success.
-
-FailureOr<TilingInterface>
-rewriteAsPaddedOp(RewriterBase &rewriter, TilingInterface opToPad,
-                  const PadTilingInterfaceOptions &constOptions,
-                  SmallVector<tensor::PadOp> &padOps,
-                  const PadSizeComputationFunction &computePaddingSizeFun =
+FailureOr<PadTilingInterfaceResult>
+rewriteAsPaddedOp(OpBuilder &, TilingInterface toPad,
+                  PadTilingInterfaceOptions options,
+                  const PadSizeComputationFunction & =
                       &computeIndexingMapOpInterfacePaddedShape);
 
 namespace detail {
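To summarize the header change: `rewriteAsPaddedOp` now takes a plain `OpBuilder`, takes its options by value (the internal `constOptions` copy disappears), and returns everything it creates in one `PadTilingInterfaceResult` instead of threading `padOps` through an out-parameter. A minimal caller sketch, assuming `builder`, `op`, and `paddingSizes` exist in the surrounding pass (all three are placeholders, not names from this patch):

```cpp
// Hypothetical caller sketch; `builder`, `op`, and `paddingSizes` are
// placeholders for whatever the surrounding pass provides.
PadTilingInterfaceOptions options;
options.setPaddingSizes(paddingSizes) // OpFoldResult sizes per iterator dim
    .setPadToMultipleOf(true);        // interpret the sizes as multiples

FailureOr<PadTilingInterfaceResult> maybePadded =
    linalg::rewriteAsPaddedOp(builder, cast<TilingInterface>(op), options);
if (failed(maybePadded))
  return failure();

// All products of the rewrite arrive in one struct; no out-parameter.
const auto &[padOps, paddedOp, replacements] = *maybePadded;
```

Since the struct is a plain aggregate, structured bindings give callers direct names for the three products, which the transform op below uses as well.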
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 6192d791f87aa..9a8a63e54d02d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2457,26 +2457,24 @@ transform::PadTilingInterfaceOp::apply(transform::TransformRewriter &rewriter,
     }
 
     // Set options.
-    TilingInterface paddedOp;
     PadTilingInterfaceOptions options;
     options.setPaddingValues(paddingValues)
         .setPaddingSizes(getMixedPaddingSizes())
         .setPadToMultipleOf(getPadToMultipleOf());
 
-    // Apply padding.
-    SmallVector<tensor::PadOp> newPadOps;
-    FailureOr<TilingInterface> maybePaddedOp = rewriteAsPaddedOp(
-        rewriter, cast<TilingInterface>(targetOp.getOperation()), options,
-        newPadOps);
-    if (failed(maybePaddedOp)) {
+    auto maybePadOps = rewriteAsPaddedOp(
+        rewriter, cast<TilingInterface>(targetOp.getOperation()), options);
+    if (failed(maybePadOps)) {
       auto diag = emitSilenceableError() << "failed to pad op";
       diag.attachNote(target->getLoc()) << "target op";
       return diag;
     }
+    const auto &[paddedOperands, paddedOp, slicedResults] = maybePadOps.value();
 
     // Set transform results.
-    paddedOps.push_back(cast<TilingInterface>(maybePaddedOp->getOperation()));
-    padOps.append(newPadOps.begin(), newPadOps.end());
+    paddedOps.push_back(paddedOp);
+    padOps.append(paddedOperands.begin(), paddedOperands.end());
+    rewriter.replaceOp(targetOp.getOperation(), slicedResults);
   }
 
   results.set(cast<OpResult>(getPadded()), paddedOps);
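The functional change hidden in this hunk: replacing the target op now happens here in the transform op, not inside `rewriteAsPaddedOp`, which holds only an `OpBuilder` and can no longer erase or replace anything. Any other caller that relied on the implicit replacement must do the equivalent of this sketch, where `rewriter`, `target`, and `options` are placeholders for the caller's context:

```cpp
// Sketch of the caller-side contract after this change; `rewriter` (a
// RewriterBase, which derives from OpBuilder), `target` (a TilingInterface),
// and `options` are placeholders.
FailureOr<PadTilingInterfaceResult> maybePadded =
    linalg::rewriteAsPaddedOp(rewriter, target, options);
if (failed(maybePadded))
  return failure();
// `replacements` holds extract_slices with the same types as `target`'s
// results, so a 1:1 replacement is always type-correct.
rewriter.replaceOp(target, maybePadded->replacements);
```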
diff --git a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
index 0956c5d771394..3e787a2ad0ef5 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/PadTilingInterface.cpp
@@ -95,10 +95,11 @@ static int64_t extractConstantMultiplier(AffineExpr expr) {
 /// - affine_map<(d0, d1) -> (d0 * 3 + d1)>
 /// In the future, more general interfaces can be devised to encode similar
 /// shape evolutions and map between an op and its operands.
-SmallVector<OpFoldResult> linalg::computePaddedShape(
-    RewriterBase &rewriter, TypedValue<RankedTensorType> v,
-    AffineMap indexingMap, ArrayRef<OpFoldResult> indexingSizes,
-    const PadTilingInterfaceOptions &options) {
+SmallVector<OpFoldResult>
+linalg::computePaddedShape(OpBuilder &builder, TypedValue<RankedTensorType> v,
+                           AffineMap indexingMap,
+                           ArrayRef<OpFoldResult> indexingSizes,
+                           const PadTilingInterfaceOptions &options) {
   Location loc = v.getLoc();
   SmallVector<OpFoldResult> paddedShape;
   auto tensorType = cast<RankedTensorType>(v.getType());
@@ -109,7 +110,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
   // "Full-rank" padding specification.
   SmallVector<OpFoldResult> paddingSizes =
-      getFullRankPaddingSizes(rewriter, indexingSizes, options);
+      getFullRankPaddingSizes(builder, indexingSizes, options);
 
   // For each dimension in the operand's shape, iterate over indexingSizes and
   // add the various term contributions.
@@ -147,28 +148,27 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
       OpFoldResult paddingDimOfr;
       if (options.padToMultipleOf) {
         AffineExpr d0, s0;
-        bindDims(rewriter.getContext(), d0);
-        bindSymbols(rewriter.getContext(), s0);
+        bindDims(builder.getContext(), d0);
+        bindSymbols(builder.getContext(), s0);
         AffineMap ceilMap = AffineMap::get(1, 1, d0.ceilDiv(s0) * s0);
         AffineMap composedMap = projectedMap.compose(ceilMap);
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, composedMap,
-            {indexingSizes[paddingDim], paddingSize},
+            builder, loc, composedMap, {indexingSizes[paddingDim], paddingSize},
             /*composeAffineMin=*/true);
       } else {
         // Otherwise just set to paddingSize.
         paddingDimOfr = affine::makeComposedFoldedAffineApply(
-            rewriter, loc, projectedMap, paddingSize);
+            builder, loc, projectedMap, paddingSize);
       }
 
       // Adjust for the maximum accessed index, which is (paddingSize - 1) *
       // multiplier.
       AffineExpr d0;
-      bindDims(rewriter.getContext(), d0);
+      bindDims(builder.getContext(), d0);
       int64_t multiplier = extractConstantMultiplier(projectedMap.getResult(0));
       AffineMap subtractMap = AffineMap::get(1, 0, d0 - multiplier);
       OpFoldResult maxAccessIdx = affine::makeComposedFoldedAffineApply(
-          rewriter, loc, subtractMap, {paddingDimOfr});
+          builder, loc, subtractMap, {paddingDimOfr});
       terms.push_back(maxAccessIdx);
 
       LLVM_DEBUG(DBGS() << "------new term: " << terms.back() << "\n");
@@ -177,19 +177,19 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
     // If there are no terms, just return the dim.
     if (terms.empty()) {
       paddedShape[resultIndex] =
-          createFoldedDimOp(rewriter, loc, v, resultIndex);
+          createFoldedDimOp(builder, loc, v, resultIndex);
       continue;
     }
 
     // Sum individual terms' contributions.
     SmallVector<AffineExpr> dims(terms.size());
-    bindDimsList(rewriter.getContext(), MutableArrayRef{dims});
+    bindDimsList(builder.getContext(), MutableArrayRef{dims});
     AffineExpr sumExpr = dims.front();
     for (unsigned i = 1; i < dims.size(); ++i)
       sumExpr = sumExpr + dims[i];
 
     // Add 1 to the maximum accessed index and get the final padded size.
-    OpFoldResult paddedDimOfr = affine::makeComposedFoldedAffineApply(
-        rewriter, loc, sumExpr + 1, terms);
+    OpFoldResult paddedDimOfr =
+        affine::makeComposedFoldedAffineApply(builder, loc, sumExpr + 1, terms);
     paddedShape[resultIndex] = paddedDimOfr;
   }
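A hand-worked instance of the size computation in the hunks above may help; the numbers are illustrative, not taken from the patch. Assume a single-term projected indexing expression `d0 * 3` (so `extractConstantMultiplier` returns 3), an iteration-domain size of 10, and `padToMultipleOf` with a padding size of 4:

```cpp
// Self-contained scalar model of the folded affine computation above,
// for one operand dimension with projected expression d0 * 3.
#include <cstdint>

constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

constexpr int64_t paddedSize(int64_t domainSize, int64_t padMultiple,
                             int64_t multiplier) {
  // padToMultipleOf: round the iteration-domain size up to a multiple.
  int64_t rounded = ceilDiv(domainSize, padMultiple) * padMultiple; // 12
  // Apply the projected indexing expression to the rounded size.
  int64_t projected = rounded * multiplier;                         // 36
  // Adjust to the maximum accessed index (subtract the multiplier).
  int64_t maxAccessIdx = projected - multiplier;                    // 33
  // Sum of terms plus one yields the final padded size.
  return maxAccessIdx + 1;                                          // 34
}

static_assert(paddedSize(/*domainSize=*/10, /*padMultiple=*/4,
                         /*multiplier=*/3) == 34,
              "padded size for d0 * 3 over a domain of 10, multiple of 4");
```

For a multi-term map such as `(d0, d1) -> (d0 * 3 + d1)`, each result expression contributes one such `maxAccessIdx` term, and the final size is their sum plus one, exactly what the `sumExpr + 1` apply computes.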
@@ -198,7 +198,7 @@ SmallVector<OpFoldResult> linalg::computePaddedShape(
 
 FailureOr<SmallVector<OpFoldResult>>
 linalg::computeIndexingMapOpInterfacePaddedShape(
-    RewriterBase &rewriter, OpOperand &operandToPad,
+    OpBuilder &builder, OpOperand &operandToPad,
     ArrayRef<Range> iterationDomain, const PadTilingInterfaceOptions &options) {
   auto transferOp =
       llvm::dyn_cast<IndexingMapOpInterface>(operandToPad.getOwner());
@@ -206,9 +206,9 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
     return failure();
 
   // clang-format off
-  assert(llvm::all_of(iterationDomain, [&rewriter](Range r) {
-    return r.offset == OpFoldResult(rewriter.getIndexAttr(0)) &&
-           r.stride == OpFoldResult(rewriter.getIndexAttr(1));
+  assert(llvm::all_of(iterationDomain, [&builder](Range r) {
+    return r.offset == OpFoldResult(builder.getIndexAttr(0)) &&
+           r.stride == OpFoldResult(builder.getIndexAttr(1));
   }) && "expected 0-offset 1-stride loop ranges");
   // clang-format on
   SmallVector<OpFoldResult> loopUpperBounds;
@@ -218,13 +218,13 @@ linalg::computeIndexingMapOpInterfacePaddedShape(
 
   AffineMap indexingMap = transferOp.getMatchingIndexingMap(&operandToPad);
   return computePaddedShape(
-      rewriter, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
+      builder, cast<TypedValue<RankedTensorType>>(operandToPad.get()),
       indexingMap, loopUpperBounds, options);
 }
 
 /// Pad a single operand to `paddedShape` using `paddingValueAttr` as padding
 /// Value.
-static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
+static Value padOperand(OpBuilder &builder, TilingInterface opToPad,
                         TypedValue<RankedTensorType> v,
                         ArrayRef<OpFoldResult> paddedShape,
                         Attribute paddingValueAttr) {
@@ -232,15 +232,15 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
   if (auto complexTy =
           dyn_cast<ComplexType>(getElementTypeOrSelf(v.getType()))) {
     if (auto complexAttr = dyn_cast<ArrayAttr>(paddingValueAttr)) {
-      paddingValue = complex::ConstantOp::create(rewriter, opToPad.getLoc(),
+      paddingValue = complex::ConstantOp::create(builder, opToPad.getLoc(),
                                                  complexTy, complexAttr);
     }
   } else if (isa<ub::PoisonAttr>(paddingValueAttr)) {
-    paddingValue = ub::PoisonOp::create(rewriter, opToPad.getLoc(),
+    paddingValue = ub::PoisonOp::create(builder, opToPad.getLoc(),
                                         getElementTypeOrSelf(v.getType()));
   } else if (auto typedAttr = dyn_cast<TypedAttr>(paddingValueAttr)) {
     paddingValue =
-        arith::ConstantOp::create(rewriter, opToPad.getLoc(), typedAttr);
+        arith::ConstantOp::create(builder, opToPad.getLoc(), typedAttr);
   }
   assert(paddingValue && "failed to create value from padding attribute");
 
@@ -259,49 +259,48 @@ static Value padOperand(RewriterBase &rewriter, TilingInterface opToPad,
       RankedTensorType::get(tensorShape, getElementTypeOrSelf(v));
   LLVM_DEBUG(DBGS() << "--SUCCESS, makeComposedPadHighOp with type: "
                     << paddedTensorType);
-  return makeComposedPadHighOp(rewriter, opToPad.getLoc(), paddedTensorType, v,
+  return makeComposedPadHighOp(builder, opToPad.getLoc(), paddedTensorType, v,
                                paddingValue, /*nofold=*/false, dynDims);
 }
 
-FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
-    RewriterBase &rewriter, TilingInterface opToPad,
-    const PadTilingInterfaceOptions &constOptions,
-    SmallVector<tensor::PadOp> &padOps,
+FailureOr<PadTilingInterfaceResult> linalg::rewriteAsPaddedOp(
+    OpBuilder &builder, TilingInterface toPad,
+    PadTilingInterfaceOptions options,
     const PadSizeComputationFunction &computePaddingSizeFun) {
-  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << opToPad << "\n");
+  LLVM_DEBUG(DBGS() << "Start rewriteAsPaddedOp : " << toPad << "\n");
+  SmallVector<tensor::PadOp> padOps;
+  Location loc = toPad.getLoc();
 
-  Location loc = opToPad.getLoc();
-  PadTilingInterfaceOptions options(constOptions);
   // Allow inference of pad values if they are not explicitly specified.
   // TODO: be mindful about the value depending on the actual operation.
   if (options.paddingValues.empty()) {
-    SmallVector<Type> types(opToPad->getOperandTypes());
-    llvm::append_range(types, opToPad->getResultTypes());
+    SmallVector<Type> types(toPad->getOperandTypes());
+    llvm::append_range(types, toPad->getResultTypes());
    for (Type t : types) {
      options.paddingValues.push_back(
-          rewriter.getZeroAttr(getElementTypeOrSelf(t)));
+          builder.getZeroAttr(getElementTypeOrSelf(t)));
    }
  }
 
-  if (llvm::any_of(opToPad->getOperands(),
+  if (llvm::any_of(toPad->getOperands(),
                    [](Value v) { return isa<MemRefType>(v.getType()); })) {
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "expected operation on tensors");
+    LLVM_DEBUG(DBGS() << "Not an operation on tensors: FAIL\n");
+    return failure();
  }
 
-  OpBuilder::InsertionGuard g(rewriter);
-  // Set IP after opToPad because we also take the dims of opToPad's output.
-  rewriter.setInsertionPointAfter(opToPad);
+  OpBuilder::InsertionGuard g(builder);
+  // Set IP after toPad because we also take the dims of toPad's output.
+  builder.setInsertionPointAfter(toPad);
 
   // 1. Get the loopUpperBounds from the TilingInterface.
-  SmallVector<Range> iterationDomain = opToPad.getIterationDomain(rewriter);
+  SmallVector<Range> iterationDomain = toPad.getIterationDomain(builder);
 
   // 2. For each operand.
   SmallVector<Value> newOperands;
-  newOperands.reserve(opToPad->getNumOperands());
-  for (OpOperand &opOperand : opToPad->getOpOperands()) {
+  newOperands.reserve(toPad->getNumOperands());
+  for (OpOperand &opOperand : toPad->getOpOperands()) {
    Value operand = opOperand.get();
-    LLVM_DEBUG(DBGS() << "--start padding oprd: " << operand << "\n");
+    LLVM_DEBUG(DBGS() << "--start padding operand: " << operand << "\n");
 
    // 2.a. Skip scalar-like operands.
    Type operandType = operand.getType();
@@ -311,30 +310,31 @@ FailureOr<TilingInterface> linalg::rewriteAsPaddedOp(
      newOperands.push_back(operand);
      continue;
    }
+
    // 2.a. Compute padded shape.
    FailureOr<SmallVector<OpFoldResult>> maybePaddedShape =
-        computePaddingSizeFun(rewriter, opOperand, iterationDomain, options);
+        computePaddingSizeFun(builder, opOperand, iterationDomain, options);
    if (failed(maybePaddedShape)) {
-      return rewriter.notifyMatchFailure(opToPad, "could not pad op");
+      LLVM_DEBUG(DBGS() << "Could not get padded shape of operand: FAIL\n");
+      return failure();
    }
 
    // 2.b. Expect proper `paddingValues`.
    // TODO: we may want to allow garbage padding in the future, in which case
    // we would just not assert.
    if (opOperand.getOperandNumber() >= options.paddingValues.size()) {
-      return rewriter.notifyMatchFailure(opToPad,
-                                         "--no padding value specified");
+      LLVM_DEBUG(DBGS() << "Too few padding values specified: FAIL\n");
+      return failure();
    }
    Attribute paddingValueAttr =
        options.paddingValues[opOperand.getOperandNumber()];
 
    // 2.c. Perform actual padding.
-    Value paddedOperand = padOperand(
-        rewriter, opToPad, cast<TypedValue<RankedTensorType>>(operand),
-        *maybePaddedShape, paddingValueAttr);
+    Value paddedOperand =
+        padOperand(builder, toPad, cast<TypedValue<RankedTensorType>>(operand),
+                   *maybePaddedShape, paddingValueAttr);
    LLVM_DEBUG(DBGS() << "--done padding operand: " << paddedOperand << "\n");
 
-    // 2.d. Perform actual padding.
    newOperands.push_back(paddedOperand);
    if (auto padOp = paddedOperand.getDefiningOp<tensor::PadOp>())
      padOps.push_back(padOp);
@@ -342,38 +342,34 @@
   // 3. Form the resulting tensor::ExtractSliceOp.
   ReifiedRankedShapedTypeDims reifiedResultShapes;
-  if (failed(reifyResultShapes(rewriter, opToPad, reifiedResultShapes))) {
-    LLVM_DEBUG(DBGS() << "--failed to reify result shapes -> FAIL\n");
-    return rewriter.notifyMatchFailure(opToPad,
-                                       "failed to reify result shapes");
+  if (failed(reifyResultShapes(builder, toPad, reifiedResultShapes))) {
+    LLVM_DEBUG(DBGS() << "Failed to reify result shapes: FAIL\n");
+    return failure();
   }
-  assert(reifiedResultShapes.size() == opToPad->getNumResults() &&
+  assert(reifiedResultShapes.size() == toPad->getNumResults() &&
          "expected same number of results");
 
-  // Clone `opToPad` to operate on the statically padded shapes.
+  // Clone `toPad` to operate on the statically padded shapes.
   auto resultTensorTypes =
-      ValueRange(newOperands).take_back(opToPad->getNumResults()).getTypes();
-  // clone **should** properly notify the rewriter.
+      ValueRange(newOperands).take_back(toPad->getNumResults()).getTypes();
+  // clone **should** properly notify the builder.
   TilingInterface paddedOp =
-      clone(rewriter, opToPad, resultTensorTypes, newOperands);
+      clone(builder, toPad, resultTensorTypes, newOperands);
   LLVM_DEBUG(DBGS() << "--cloned padded op: " << paddedOp << "\n");
 
-  // Recover the slice out of the new static results. This keeps the original
-  // opToPad around because it uses the dims of the original results.
+  // Recover the slice out of the new static results.
   SmallVector<Value> paddedSubtensorResults;
-  paddedSubtensorResults.reserve(opToPad->getNumResults());
+  paddedSubtensorResults.reserve(toPad->getNumResults());
   for (const auto &en : llvm::enumerate(paddedOp->getResults())) {
     Value paddedResult = en.value();
     int64_t resultNumber = en.index();
     int64_t rank = cast<RankedTensorType>(paddedResult.getType()).getRank();
-    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    SmallVector<OpFoldResult> offsets(rank, builder.getIndexAttr(0));
+    SmallVector<OpFoldResult> strides(rank, builder.getIndexAttr(1));
     paddedSubtensorResults.push_back(tensor::ExtractSliceOp::create(
-        rewriter, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
+        builder, loc, paddedResult, offsets, reifiedResultShapes[resultNumber],
         strides));
   }
 
-  rewriter.replaceOp(opToPad, paddedSubtensorResults);
-
-  return paddedOp;
+  return PadTilingInterfaceResult{padOps, paddedOp, paddedSubtensorResults};
 }
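To visualize what the refactored entry point produces end to end, here is a hypothetical before/after IR sketch for a matmul padded to multiples of 8; the shapes, SSA names, and values are illustrative and not taken from this patch or its tests. The three pieces map directly onto `PadTilingInterfaceResult`: the `tensor.pad` ops are `padOps`, the new matmul is `paddedOp`, and the `tensor.extract_slice` is the single entry of `replacements`.

```mlir
// Before: the op handed to rewriteAsPaddedOp.
%0 = linalg.matmul ins(%a, %b : tensor<10x20xf32>, tensor<20x15xf32>)
                   outs(%c : tensor<10x15xf32>) -> tensor<10x15xf32>

// After, with padding sizes interpreted as multiples of 8.
// One pad op per operand ends up in `padOps`:
%cst = arith.constant 0.0 : f32
%a_pad = tensor.pad %a low[0, 0] high[6, 4] {
^bb0(%i: index, %j: index):
  tensor.yield %cst : f32
} : tensor<10x20xf32> to tensor<16x24xf32>
// ... %b_pad : tensor<20x15xf32> to tensor<24x16xf32> and
//     %c_pad : tensor<10x15xf32> to tensor<16x16xf32> are built the same way.

// The clone on padded operands becomes `paddedOp`:
%1 = linalg.matmul ins(%a_pad, %b_pad : tensor<16x24xf32>, tensor<24x16xf32>)
                   outs(%c_pad : tensor<16x16xf32>) -> tensor<16x16xf32>

// And `replacements` holds one slice per result, restoring the original type:
%2 = tensor.extract_slice %1[0, 0] [10, 15] [1, 1]
       : tensor<16x16xf32> to tensor<10x15xf32>
```

Note that with a plain `OpBuilder` nothing replaces `%0`; the caller decides whether `%2` substitutes for it, as the transform op hunk above does with `rewriter.replaceOp`.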