diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h index 43c97d57e1834..adccbef754ec5 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h @@ -598,13 +598,14 @@ class BufferizationState { FailureOr allocateTensorForShapedValue(OpBuilder &b, Location loc, Value shapedValue, const BufferizationOptions &options, - bool copy = true); + const BufferizationState &state, bool copy = true); /// Lookup the buffer for the given value. If the value was not bufferized /// yet, wrap it in a ToBufferOp. Otherwise, it is the result of a ToTensorOp, /// from which the memref operand is returned. FailureOr getBuffer(RewriterBase &rewriter, Value value, - const BufferizationOptions &options); + const BufferizationOptions &options, + const BufferizationState &state); /// Return the buffer type for a given Value (tensor) after bufferization /// without bufferizing any IR. @@ -615,7 +616,8 @@ FailureOr getBuffer(RewriterBase &rewriter, Value value, /// /// This function is a wrapper around BufferizableOpInterface::getBufferType. FailureOr getBufferType(Value value, - const BufferizationOptions &options); + const BufferizationOptions &options, + const BufferizationState &state); /// Return the buffer type for a given Value (tensor) after bufferization /// without bufferizing any IR. This function (and not the other overload @@ -629,6 +631,7 @@ FailureOr getBufferType(Value value, /// This function is a wrapper around `BufferizableOpInterface::getBufferType`. FailureOr getBufferType(Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack); /// Return "true" if the given op has tensor semantics and should be bufferized. @@ -709,6 +712,7 @@ AliasingOpOperandList defaultGetAliasingOpOperands(Value value, /// places. 
FailureOr defaultGetBufferType(Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack); /// This is the default implementation of diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td index 72974a8c808fd..cafe05fe5f189 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -381,13 +381,14 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*retType=*/"::llvm::LogicalResult", /*methodName=*/"resolveConflicts", /*args=*/(ins "::mlir::RewriterBase &":$rewriter, - "const ::mlir::bufferization::AnalysisState &":$state), + "const ::mlir::bufferization::AnalysisState &":$analysisState, + "const ::mlir::bufferization::BufferizationState &":$bufferizationState), /*methodBody=*/"", /*defaultImplementation=*/[{ auto bufferizableOp = ::llvm::cast($_op.getOperation()); return bufferizableOp.resolveTensorOpOperandConflicts( - rewriter, state); + rewriter, analysisState, bufferizationState); }] >, InterfaceMethod< @@ -528,6 +529,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { /*methodName=*/"getBufferType", /*args=*/(ins "::mlir::Value":$value, "const ::mlir::bufferization::BufferizationOptions &":$options, + "const ::mlir::bufferization::BufferizationState &":$state, "::llvm::SmallVector<::mlir::Value> &":$invocationStack), /*methodBody=*/"", /*defaultImplementation=*/[{ @@ -536,7 +538,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> { assert(invocationStack.back() == value && "inconsistant invocation stack"); return ::mlir::bufferization::detail::defaultGetBufferType( - value, options, invocationStack); + value, options, state, invocationStack); }] >, InterfaceMethod< @@ -621,7 +623,8 @@ def BufferizableOpInterface : 
OpInterface<"BufferizableOpInterface"> { /// form of `bufferization.alloc_tensor` ops. ::llvm::LogicalResult resolveTensorOpOperandConflicts( ::mlir::RewriterBase &rewriter, - const ::mlir::bufferization::AnalysisState &state); + const ::mlir::bufferization::AnalysisState &analysisState, + const ::mlir::bufferization::BufferizationState &bufferizationState); /// Return `true` if the given OpOperand creates an alias but does neither /// read nor write. This implies that `bufferizesToMemoryRead` and diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td index dafa4b9b183f2..3d4dcdee2663b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td @@ -112,6 +112,7 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor", FailureOr getBufferType( Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack); RankedTensorType getType() { @@ -471,7 +472,7 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [ FailureOr getBufferType( Value value, const BufferizationOptions &options, - SmallVector &invocationStack) { + const BufferizationState &state, SmallVector &invocationStack) { return ::llvm::cast(getMemref().getType()); } }]; diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h index cf86b9a23f59e..a441b8b66659e 100644 --- a/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h +++ b/mlir/include/mlir/Dialect/Bufferization/IR/UnstructuredControlFlow.h @@ -34,12 +34,13 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { // Note: 
The user may want to override this function for OpResults in // case the bufferized result type is different from the bufferized type of // the aliasing OpOperand (if any). if (isa(value)) - return bufferization::detail::defaultGetBufferType(value, options, + return bufferization::detail::defaultGetBufferType(value, options, state, invocationStack); // Compute the buffer type of the block argument by computing the bufferized @@ -65,7 +66,7 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel callerType = memrefType; } else { FailureOr maybeCallerType = - bufferization::getBufferType(opOperand->get(), options, + bufferization::getBufferType(opOperand->get(), options, state, invocationStack); if (failed(maybeCallerType)) return failure(); @@ -81,9 +82,9 @@ struct OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel if (bufferType == callerType) continue; - // If the computed buffer type does not match the computed buffer type - // of the earlier forwarded operands, fall back to a buffer type with a - // fully dynamic layout map. + // If the computed buffer type does not match the computed buffer type + // of the earlier forwarded operands, fall back to a buffer type with a + // fully dynamic layout map. #ifndef NDEBUG if (auto rankedTensorType = dyn_cast(tensorType)) { assert(bufferType.hasRank() && callerType.hasRank() && diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index 70e3defee0867..c1f5654abbf9b 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -62,7 +62,8 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, /// `BufferizableOpInterface`. The buffer types of tensor block arguments are /// computed with `BufferizableOpInterface::getBufferType`. 
LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, - const BufferizationOptions &options); + const BufferizationOptions &options, + BufferizationState &state); } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h index a4ee893ca5341..e17d5264a1a45 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Transforms.h @@ -75,12 +75,15 @@ void hoistBuffersFromLoops(Operation *op); /// additional buffer allocations. LogicalResult insertTensorCopies(Operation *op, const OneShotBufferizationOptions &options, + const BufferizationState &bufferizationState, BufferizationStatistics *statistics = nullptr); /// Resolve RaW and other conflicts by inserting bufferization.alloc_tensor ops. /// After applying this transform, the IR can be bufferized without inserting /// additional buffer allocations. -LogicalResult insertTensorCopies(Operation *op, const AnalysisState &state); +LogicalResult insertTensorCopies(Operation *op, + const AnalysisState &analysisState, + const BufferizationState &bufferizationState); /// Populate patterns to lower tensor.empty ops to bufferization.alloc_tensor /// ops. 
diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp index f646326ffc58f..a57d58ab28d28 100644 --- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp @@ -90,7 +90,8 @@ struct IndexCastOpInterface auto castOp = cast(op); auto resultTensorType = cast(castOp.getType()); - FailureOr source = getBuffer(rewriter, castOp.getIn(), options); + FailureOr source = + getBuffer(rewriter, castOp.getIn(), options, state); if (failed(source)) return failure(); auto sourceType = cast(source->getType()); @@ -151,9 +152,9 @@ struct SelectOpInterface // the moment (one for each tensor). When copying the op result, only one // copy would be needed. FailureOr maybeTrueBuffer = - getBuffer(rewriter, selectOp.getTrueValue(), options); + getBuffer(rewriter, selectOp.getTrueValue(), options, state); FailureOr maybeFalseBuffer = - getBuffer(rewriter, selectOp.getFalseValue(), options); + getBuffer(rewriter, selectOp.getFalseValue(), options, state); if (failed(maybeTrueBuffer) || failed(maybeFalseBuffer)) return failure(); Value trueBuffer = *maybeTrueBuffer; @@ -164,7 +165,7 @@ struct SelectOpInterface // both of them to the most dynamic MemRef type. 
if (trueBuffer.getType() != falseBuffer.getType()) { auto targetType = - bufferization::getBufferType(selectOp.getResult(), options); + bufferization::getBufferType(selectOp.getResult(), options, state); if (failed(targetType)) return failure(); if (trueBuffer.getType() != *targetType) @@ -182,13 +183,14 @@ struct SelectOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto selectOp = cast(op); assert(value == selectOp.getResult() && "invalid value"); - auto trueType = bufferization::getBufferType(selectOp.getTrueValue(), - options, invocationStack); - auto falseType = bufferization::getBufferType(selectOp.getFalseValue(), - options, invocationStack); + auto trueType = bufferization::getBufferType( + selectOp.getTrueValue(), options, state, invocationStack); + auto falseType = bufferization::getBufferType( + selectOp.getFalseValue(), options, state, invocationStack); if (failed(trueType) || failed(falseType)) return failure(); if (*trueType == *falseType) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp index 14fa4c1ed8159..1d6e1bdaf80f5 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -165,7 +165,8 @@ Operation *bufferization::getOwnerOfValue(Value value) { /// allocated. FailureOr bufferization::allocateTensorForShapedValue( OpBuilder &b, Location loc, Value shapedValue, - const BufferizationOptions &options, bool copy) { + const BufferizationOptions &options, const BufferizationState &state, + bool copy) { Value tensor; if (llvm::isa(shapedValue.getType())) { tensor = shapedValue; @@ -210,7 +211,8 @@ FailureOr bufferization::allocateTensorForShapedValue( // Add 'memory_space' attribute. Not needed if 'copy' operand is specified. 
if (copy) return allocTensorOp.getResult(); - FailureOr copyBufferType = getBufferType(tensor, options); + FailureOr copyBufferType = + getBufferType(tensor, options, state); if (failed(copyBufferType)) return failure(); std::optional memorySpace = copyBufferType->getMemorySpace(); @@ -222,7 +224,8 @@ FailureOr bufferization::allocateTensorForShapedValue( } LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( - RewriterBase &rewriter, const AnalysisState &state) { + RewriterBase &rewriter, const AnalysisState &analysisState, + const BufferizationState &bufferizationState) { OpBuilder::InsertionGuard g(rewriter); Operation *op = getOperation(); SmallVector outOfPlaceOpOperands; @@ -235,16 +238,18 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( Type operandType = opOperand.get().getType(); if (!llvm::isa(operandType)) continue; - if (state.isInPlace(opOperand)) + if (analysisState.isInPlace(opOperand)) continue; if (llvm::isa(operandType)) return op->emitError("copying of unranked tensors is not implemented"); - AliasingValueList aliasingValues = state.getAliasingValues(opOperand); + AliasingValueList aliasingValues = + analysisState.getAliasingValues(opOperand); if (aliasingValues.getNumAliases() == 1 && isa(aliasingValues.getAliases()[0].value) && - !state.bufferizesToMemoryWrite(opOperand) && - state.getAliasingOpOperands(aliasingValues.getAliases()[0].value) + !analysisState.bufferizesToMemoryWrite(opOperand) && + analysisState + .getAliasingOpOperands(aliasingValues.getAliases()[0].value) .getNumAliases() == 1 && !isa( aliasingValues.getAliases()[0].value.getType())) { @@ -256,12 +261,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( // cannot be copied at the moment). 
Value value = aliasingValues.getAliases()[0].value; outOfPlaceValues.push_back(value); - if (!state.canOmitTensorCopy(opOperand)) + if (!analysisState.canOmitTensorCopy(opOperand)) copiedOpValues.insert(value); } else { // In all other cases, make a copy of the OpOperand. outOfPlaceOpOperands.push_back(&opOperand); - if (!state.canOmitTensorCopy(opOperand)) + if (!analysisState.canOmitTensorCopy(opOperand)) copiedOpOperands.insert(&opOperand); } } @@ -270,8 +275,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( rewriter.setInsertionPoint(op); for (OpOperand *opOperand : outOfPlaceOpOperands) { FailureOr copy = allocateTensorForShapedValue( - rewriter, op->getLoc(), opOperand->get(), state.getOptions(), - copiedOpOperands.contains(opOperand)); + rewriter, op->getLoc(), opOperand->get(), analysisState.getOptions(), + bufferizationState, copiedOpOperands.contains(opOperand)); if (failed(copy)) return failure(); rewriter.modifyOpInPlace(op, [&]() { opOperand->set(*copy); }); @@ -281,8 +286,8 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( rewriter.setInsertionPointAfter(op); for (Value value : outOfPlaceValues) { FailureOr copy = allocateTensorForShapedValue( - rewriter, op->getLoc(), value, state.getOptions(), - copiedOpValues.count(value)); + rewriter, op->getLoc(), value, analysisState.getOptions(), + bufferizationState, copiedOpValues.count(value)); if (failed(copy)) return failure(); SmallVector uses = llvm::to_vector( @@ -665,7 +670,8 @@ static void ensureToBufferOpIsValid(Value tensor, Type memrefType) { } FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, - const BufferizationOptions &options) { + const BufferizationOptions &options, + const BufferizationState &state) { #ifndef NDEBUG auto tensorType = llvm::dyn_cast(value.getType()); assert(tensorType && "unexpected non-tensor type"); @@ -678,7 +684,7 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, // Insert 
to_buffer op. OpBuilder::InsertionGuard g(rewriter); setInsertionPointAfter(rewriter, value); - FailureOr memrefType = getBufferType(value, options); + FailureOr memrefType = getBufferType(value, options, state); if (failed(memrefType)) return failure(); ensureToBufferOpIsValid(value, *memrefType); @@ -689,14 +695,16 @@ FailureOr bufferization::getBuffer(RewriterBase &rewriter, Value value, /// Return the buffer type for a given Value (tensor) after bufferization. FailureOr -bufferization::getBufferType(Value value, const BufferizationOptions &options) { +bufferization::getBufferType(Value value, const BufferizationOptions &options, + const BufferizationState &state) { SmallVector invocationStack; - return getBufferType(value, options, invocationStack); + return getBufferType(value, options, state, invocationStack); } /// Return the buffer type for a given Value (tensor) after bufferization. FailureOr bufferization::getBufferType(Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) { assert(llvm::isa(value.getType()) && "unexpected non-tensor type"); @@ -708,7 +716,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options, Operation *op = getOwnerOfValue(value); auto bufferizableOp = options.dynCastBufferizableOp(op); if (bufferizableOp) - return bufferizableOp.getBufferType(value, options, invocationStack); + return bufferizableOp.getBufferType(value, options, state, invocationStack); // Op is not bufferizable. auto memSpace = @@ -944,6 +952,7 @@ AliasingOpOperandList bufferization::detail::defaultGetAliasingOpOperands( FailureOr bufferization::detail::defaultGetBufferType( Value value, const BufferizationOptions &options, + const BufferizationState &bufferizationState, SmallVector &invocationStack) { assert(llvm::isa(value.getType()) && "expected tensor type"); @@ -954,14 +963,15 @@ FailureOr bufferization::detail::defaultGetBufferType( // Value is an OpResult. 
Operation *op = getOwnerOfValue(value); auto opResult = llvm::cast(value); - AnalysisState state(options); - AliasingOpOperandList aliases = state.getAliasingOpOperands(opResult); + AnalysisState analysisState(options); + AliasingOpOperandList aliases = analysisState.getAliasingOpOperands(opResult); if (aliases.getNumAliases() > 0 && aliases.getAliases()[0].relation == BufferRelation::Equivalent) { // If the OpResult has an equivalent OpOperand, both OpResult and // OpOperand bufferize to the exact same buffer type. Value equivalentOperand = aliases.getAliases().front().opOperand->get(); - return getBufferType(equivalentOperand, options, invocationStack); + return getBufferType(equivalentOperand, options, bufferizationState, + invocationStack); } // If we do not know the memory space and there is no default memory space, diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 91eccb0ab7430..dc54ac94aed32 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -163,14 +163,15 @@ LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter, // Get "copy" buffer. Value copyBuffer; if (getCopy()) { - FailureOr maybeCopyBuffer = getBuffer(rewriter, getCopy(), options); + FailureOr maybeCopyBuffer = + getBuffer(rewriter, getCopy(), options, state); if (failed(maybeCopyBuffer)) return failure(); copyBuffer = *maybeCopyBuffer; } // Create memory allocation. 
- auto allocType = bufferization::getBufferType(getResult(), options); + auto allocType = bufferization::getBufferType(getResult(), options, state); if (failed(allocType)) return failure(); SmallVector dynamicDims = getDynamicSizes(); @@ -223,6 +224,7 @@ AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand, FailureOr AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) { assert(value == getResult() && "invalid value"); @@ -231,8 +233,8 @@ AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options, if (getMemorySpace().has_value()) { memorySpace = *getMemorySpace(); } else if (getCopy()) { - auto copyBufferType = - bufferization::getBufferType(getCopy(), options, invocationStack); + auto copyBufferType = bufferization::getBufferType(getCopy(), options, + state, invocationStack); if (failed(copyBufferType)) return failure(); memorySpace = copyBufferType->getMemorySpace(); @@ -532,7 +534,7 @@ void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter, const BufferizationOptions &options, BufferizationState &state) { - FailureOr buffer = getBuffer(rewriter, getTensor(), options); + FailureOr buffer = getBuffer(rewriter, getTensor(), options, state); if (failed(buffer)) return failure(); rewriter.create(getLoc(), *buffer); @@ -583,7 +585,8 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, bool tensorDest = isa(getDest().getType()); Value buffer; if (tensorDest) { - FailureOr maybeBuffer = getBuffer(rewriter, getDest(), options); + FailureOr maybeBuffer = + getBuffer(rewriter, getDest(), options, state); if (failed(maybeBuffer)) return failure(); buffer = *maybeBuffer; @@ -591,7 +594,7 @@ MaterializeInDestinationOp::bufferize(RewriterBase &rewriter, assert(isa(getDest().getType()) && "expected memref type"); buffer = getDest(); } - auto srcBuffer = 
getBuffer(rewriter, getSource(), options); + auto srcBuffer = getBuffer(rewriter, getSource(), options, state); if (failed(srcBuffer)) return failure(); if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer))) diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 67f373d912dd4..c7681d309a4af 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -280,8 +280,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op, BufferizationState &bufferizationState, BufferizationStatistics *statistics) { if (options.copyBeforeWrite) { - AnalysisState state(options); - if (failed(insertTensorCopies(op, state))) + AnalysisState analysisState(options); + if (failed(insertTensorCopies(op, analysisState, bufferizationState))) return failure(); } @@ -396,7 +396,8 @@ LogicalResult bufferization::bufferizeOp(Operation *op, LogicalResult bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, - const BufferizationOptions &options) { + const BufferizationOptions &options, + BufferizationState &state) { OpBuilder::InsertionGuard g(rewriter); auto bufferizableOp = options.dynCastBufferizableOp(block->getParentOp()); if (!bufferizableOp) @@ -412,7 +413,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, } FailureOr memrefType = - bufferization::getBufferType(bbArg, options); + bufferization::getBufferType(bbArg, options, state); if (failed(memrefType)) return failure(); newTypes.push_back(*memrefType); @@ -463,7 +464,7 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, continue; } FailureOr operandBufferType = - bufferization::getBufferType(operand, options); + bufferization::getBufferType(operand, options, state); if (failed(operandBufferType)) return failure(); rewriter.setInsertionPointAfterValue(operand); diff --git 
a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp index 6210f1d787bf4..a0168da44b7b3 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp @@ -213,6 +213,7 @@ struct CallOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto callOp = cast(op); @@ -255,7 +256,7 @@ struct CallOpInterface // Returning a memref. FailureOr resultType = - bufferization::getBufferType(result, options); + bufferization::getBufferType(result, options, state); if (failed(resultType)) return failure(); resultTypes.push_back(*resultType); @@ -278,7 +279,7 @@ struct CallOpInterface // Retrieve buffers for tensor operands. FailureOr maybeBuffer = - getBuffer(rewriter, opOperand.get(), options); + getBuffer(rewriter, opOperand.get(), options, state); if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; @@ -291,7 +292,8 @@ struct CallOpInterface // result type. 
FailureOr maybeMemRefType = bufferization::getBufferType( - funcOp.getArgument(opOperand.getOperandNumber()), options); + funcOp.getArgument(opOperand.getOperandNumber()), options, + state); if (failed(maybeMemRefType)) return failure(); memRefType = *maybeMemRefType; @@ -396,6 +398,7 @@ struct FuncOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto funcOp = cast(op); auto bbArg = cast(value); @@ -406,7 +409,7 @@ struct FuncOpInterface options); return OpWithUnstructuredControlFlowBufferizableOpInterfaceExternalModel:: - getBufferType(op, value, options, invocationStack); + getBufferType(op, value, options, state, invocationStack); } /// Rewrite function bbArgs and return values into buffer form. This function @@ -459,7 +462,7 @@ struct FuncOpInterface // 1. Bufferize every block. for (Block &block : funcOp.getBody()) if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, - options))) + options, state))) return failure(); // 2. Bufferize the operands of the all return op. diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index de820e9c8f8af..33a922d59224b 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -1379,7 +1379,7 @@ LogicalResult bufferization::runOneShotBufferize( // Run One-Shot Analysis and insert buffer copies (on the tensor level) // only where needed. This is the default and much more efficient than // copy-before-write. 
- if (failed(insertTensorCopies(op, options, statistics))) + if (failed(insertTensorCopies(op, options, state, statistics))) return failure(); // If test-analysis-only is set, the IR was annotated with RaW conflict diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp index 90ceea4d69680..dee2af8271ce8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp @@ -584,7 +584,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( "invalid combination of bufferization flags"); if (!options.copyBeforeWrite) { if (options.noAnalysisFuncFilter.empty()) { - if (failed(insertTensorCopies(moduleOp, options, statistics))) + if (failed(insertTensorCopies(moduleOp, options, state, statistics))) return failure(); } else { // FuncOps whose names are specified in options.noAnalysisFuncFilter will @@ -600,7 +600,8 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( }; OneShotBufferizationOptions updatedOptions(options); updatedOptions.opFilter.denyOperation(analysisFilterFn); - if (failed(insertTensorCopies(moduleOp, updatedOptions, statistics))) + if (failed( + insertTensorCopies(moduleOp, updatedOptions, state, statistics))) return failure(); } } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index 4326b19f3104d..784d95a5dd22a 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -28,28 +28,29 @@ using namespace mlir::bufferization; LogicalResult mlir::bufferization::insertTensorCopies( Operation *op, const OneShotBufferizationOptions &options, + const BufferizationState &bufferizationState, BufferizationStatistics *statistics) { - OneShotAnalysisState state(op, 
options); + OneShotAnalysisState analysisState(op, options); // Run normal One-Shot Bufferize analysis or One-Shot Module Bufferize // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast(op), state, statistics))) + if (failed(analyzeModuleOp(cast(op), analysisState, statistics))) return failure(); } else { - if (failed(analyzeOp(op, state, statistics))) + if (failed(analyzeOp(op, analysisState, statistics))) return failure(); } if (options.testAnalysisOnly) return success(); - return insertTensorCopies(op, state); + return insertTensorCopies(op, analysisState, bufferizationState); } -LogicalResult -mlir::bufferization::insertTensorCopies(Operation *op, - const AnalysisState &state) { +LogicalResult mlir::bufferization::insertTensorCopies( + Operation *op, const AnalysisState &analysisState, + const BufferizationState &bufferizationState) { IRRewriter rewriter(op->getContext()); // It may be more efficient to walk in pre-order here, but the current @@ -62,14 +63,16 @@ mlir::bufferization::insertTensorCopies(Operation *op, nestedOp->getParentWithTrait() != op) return WalkResult::skip(); - auto bufferizableOp = state.getOptions().dynCastBufferizableOp(nestedOp); + auto bufferizableOp = + analysisState.getOptions().dynCastBufferizableOp(nestedOp); if (!bufferizableOp) return WalkResult::skip(); // Find inplacability conflicts and resolve them. (Typically with explicit // tensor copies in the form of AllocTensorOps.) 
rewriter.setInsertionPoint(nestedOp); - if (failed(bufferizableOp.resolveConflicts(rewriter, state))) + if (failed(bufferizableOp.resolveConflicts(rewriter, analysisState, + bufferizationState))) return WalkResult::interrupt(); return WalkResult::advance(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp index b6a498a57c036..9044d89c80bd6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.cpp @@ -24,10 +24,9 @@ using namespace mlir::bufferization; namespace { /// Generic conversion for any DestinationStyleOpInterface on tensors. -static LogicalResult -bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, - DestinationStyleOpInterface op, - const BufferizationOptions &options) { +static LogicalResult bufferizeDestinationStyleOpInterface( + RewriterBase &rewriter, DestinationStyleOpInterface op, + const BufferizationOptions &options, const BufferizationState &state) { // Take a guard before anything else. 
OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(op); @@ -49,7 +48,8 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, newInputBuffers.push_back(opOperand->get()); continue; } - FailureOr buffer = getBuffer(rewriter, opOperand->get(), options); + FailureOr buffer = + getBuffer(rewriter, opOperand->get(), options, state); if (failed(buffer)) return failure(); newInputBuffers.push_back(*buffer); @@ -60,7 +60,7 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, for (OpResult opResult : op->getOpResults()) { OpOperand *opOperand = op.getDpsInitOperand(opResult.getResultNumber()); FailureOr resultBuffer = - getBuffer(rewriter, opOperand->get(), options); + getBuffer(rewriter, opOperand->get(), options, state); if (failed(resultBuffer)) return failure(); newOutputBuffers.push_back(*resultBuffer); @@ -76,10 +76,10 @@ bufferizeDestinationStyleOpInterface(RewriterBase &rewriter, // new op. Since the new op does not have any tensor results, it does not // return anything. 
assert(op->getNumRegions() == 1 && "expected that op has 1 region"); - OperationState state(op->getLoc(), op->getName(), newOperands, TypeRange{}, - op->getAttrs()); - state.addRegion(); - Operation *newOp = Operation::create(state); + OperationState opState(op->getLoc(), op->getName(), newOperands, TypeRange{}, + op->getAttrs()); + opState.addRegion(); + Operation *newOp = Operation::create(opState); newOp->getRegion(0).getBlocks().splice(newOp->getRegion(0).begin(), op->getRegion(0).getBlocks()); @@ -151,7 +151,7 @@ struct LinalgOpInterface const BufferizationOptions &options, BufferizationState &state) const { return bufferizeDestinationStyleOpInterface( - rewriter, cast(op), options); + rewriter, cast(op), options, state); } }; @@ -179,11 +179,11 @@ struct SoftmaxOpInterface BufferizationState &state) const { auto softmaxOp = cast(op); FailureOr inputBuffer = - getBuffer(rewriter, softmaxOp.getInput(), options); + getBuffer(rewriter, softmaxOp.getInput(), options, state); if (failed(inputBuffer)) return failure(); FailureOr outputBuffer = - getBuffer(rewriter, softmaxOp.getOutput(), options); + getBuffer(rewriter, softmaxOp.getOutput(), options, state); if (failed(outputBuffer)) return failure(); rewriter.create(softmaxOp.getLoc(), diff --git a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp index a69bc9e5088ae..ff6af63eee531 100644 --- a/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/MLProgram/Transforms/BufferizableOpInterfaceImpl.cpp @@ -138,7 +138,8 @@ struct GlobalStoreOpInterface auto targetMemref = rewriter.create( loc, memrefType, globalStoreOp.getGlobalAttr().getLeafReference()); - auto sourceMemref = getBuffer(rewriter, globalStoreOp.getValue(), options); + auto sourceMemref = + getBuffer(rewriter, globalStoreOp.getValue(), options, state); if (failed(sourceMemref)) { return failure(); } diff --git 
a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp index 3ff1f5c49aece..46fa77a7dc4e6 100644 --- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp @@ -104,11 +104,12 @@ struct ConditionOpInterface for (const auto &it : llvm::enumerate(conditionOp.getArgs())) { Value value = it.value(); if (isa(value.getType())) { - FailureOr maybeBuffer = getBuffer(rewriter, value, options); + FailureOr maybeBuffer = + getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) return failure(); FailureOr resultType = bufferization::getBufferType( - whileOp.getAfterArguments()[it.index()], options); + whileOp.getAfterArguments()[it.index()], options, state); if (failed(resultType)) return failure(); Value buffer = castBuffer(rewriter, *maybeBuffer, *resultType); @@ -196,7 +197,7 @@ struct ExecuteRegionOpInterface // Bufferize every block. for (Block &block : newOp.getRegion()) if (failed(bufferization::bufferizeBlockSignature(&block, rewriter, - options))) + options, state))) return failure(); // Update all uses of the old op. @@ -251,7 +252,7 @@ struct IfOpInterface newTypes.push_back(result.getType()); continue; } - auto bufferType = bufferization::getBufferType(result, options); + auto bufferType = bufferization::getBufferType(result, options, state); if (failed(bufferType)) return failure(); newTypes.push_back(*bufferType); @@ -275,6 +276,7 @@ struct IfOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto ifOp = cast(op); auto thenYieldOp = cast(ifOp.thenBlock()->getTerminator()); @@ -290,8 +292,8 @@ struct IfOpInterface // True branch was already bufferized. 
thenBufferType = cast(thenValue.getType()); } else { - auto maybeBufferType = - bufferization::getBufferType(thenValue, options, invocationStack); + auto maybeBufferType = bufferization::getBufferType( + thenValue, options, state, invocationStack); if (failed(maybeBufferType)) return failure(); thenBufferType = *maybeBufferType; @@ -300,8 +302,8 @@ struct IfOpInterface // False branch was already bufferized. elseBufferType = cast(elseValue.getType()); } else { - auto maybeBufferType = - bufferization::getBufferType(elseValue, options, invocationStack); + auto maybeBufferType = bufferization::getBufferType( + elseValue, options, state, invocationStack); if (failed(maybeBufferType)) return failure(); elseBufferType = *maybeBufferType; @@ -362,7 +364,7 @@ struct IndexSwitchOpInterface newTypes.push_back(result.getType()); continue; } - auto bufferType = bufferization::getBufferType(result, options); + auto bufferType = bufferization::getBufferType(result, options, state); if (failed(bufferType)) return failure(); newTypes.push_back(*bufferType); @@ -390,6 +392,7 @@ struct IndexSwitchOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto switchOp = cast(op); assert(value.getDefiningOp() == op && "invalid value"); @@ -401,8 +404,8 @@ struct IndexSwitchOpInterface Value yieldedValue = yieldOp->getOperand(resultNum); if (auto bufferType = dyn_cast(yieldedValue.getType())) return bufferType; - auto maybeBufferType = - bufferization::getBufferType(yieldedValue, options, invocationStack); + auto maybeBufferType = bufferization::getBufferType( + yieldedValue, options, state, invocationStack); if (failed(maybeBufferType)) return failure(); return maybeBufferType; @@ -468,12 +471,12 @@ DenseSet getEquivalentBuffers(Block::BlockArgListType bbArgs, /// given OpOperands. If an operand is not a tensor, return the original value. 
static FailureOr> getBuffers(RewriterBase &rewriter, const MutableOperandRange &operands, - const BufferizationOptions &options) { + const BufferizationOptions &options, BufferizationState &state) { SmallVector result; for (OpOperand &opOperand : operands) { if (isa(opOperand.get().getType())) { FailureOr resultBuffer = - getBuffer(rewriter, opOperand.get(), options); + getBuffer(rewriter, opOperand.get(), options, state); if (failed(resultBuffer)) return failure(); result.push_back(*resultBuffer); @@ -521,10 +524,11 @@ getBbArgReplacements(RewriterBase &rewriter, Block::BlockArgListType bbArgs, /// layout map and a cast must be inserted. static FailureOr computeLoopRegionIterArgBufferType( Operation *loopOp, BlockArgument iterArg, Value initArg, Value yieldedValue, - const BufferizationOptions &options, SmallVector &invocationStack) { + const BufferizationOptions &options, const BufferizationState &state, + SmallVector &invocationStack) { // Determine the buffer type of the init_arg. auto initArgBufferType = - bufferization::getBufferType(initArg, options, invocationStack); + bufferization::getBufferType(initArg, options, state, invocationStack); if (failed(initArgBufferType)) return failure(); @@ -550,8 +554,8 @@ static FailureOr computeLoopRegionIterArgBufferType( } else { // Note: This typically triggers a recursive call for the buffer type of // the iter_arg. 
- auto maybeBufferType = - bufferization::getBufferType(yieldedValue, options, invocationStack); + auto maybeBufferType = bufferization::getBufferType(yieldedValue, options, + state, invocationStack); if (failed(maybeBufferType)) return failure(); yieldedValueBufferType = *maybeBufferType; @@ -649,13 +653,16 @@ struct ForOpInterface return true; } - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { + LogicalResult + resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &analysisState, + const BufferizationState &bufferizationState) const { auto bufferizableOp = cast(op); - if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) + if (failed(bufferizableOp.resolveTensorOpOperandConflicts( + rewriter, analysisState, bufferizationState))) return failure(); - if (state.getOptions().copyBeforeWrite) + if (analysisState.getOptions().copyBeforeWrite) return success(); // According to the `getAliasing...` implementations, a bufferized OpResult @@ -683,12 +690,13 @@ struct ForOpInterface doesNotAliasExternalValue( it.value(), &forOp.getRegion(), /*exceptions=*/forOp.getRegionIterArg(it.index()), - static_cast(state))) { + static_cast(analysisState))) { yieldValues.push_back(it.value()); continue; } FailureOr alloc = allocateTensorForShapedValue( - rewriter, yieldOp.getLoc(), it.value(), state.getOptions()); + rewriter, yieldOp.getLoc(), it.value(), analysisState.getOptions(), + bufferizationState); if (failed(alloc)) return failure(); yieldValues.push_back(*alloc); @@ -701,6 +709,7 @@ struct ForOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto forOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); @@ -709,7 +718,8 @@ struct ForOpInterface if (auto opResult = dyn_cast(value)) { // The type of an OpResult must match the 
corresponding iter_arg type. BlockArgument bbArg = forOp.getTiedLoopRegionIterArg(opResult); - return bufferization::getBufferType(bbArg, options, invocationStack); + return bufferization::getBufferType(bbArg, options, state, + invocationStack); } // Compute result/argument number. @@ -722,7 +732,7 @@ struct ForOpInterface BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum]; Value initArg = forOp.getInitArgs()[resultNum]; return computeLoopRegionIterArgBufferType( - op, iterArg, initArg, yieldedValue, options, invocationStack); + op, iterArg, initArg, yieldedValue, options, state, invocationStack); } LogicalResult bufferize(Operation *op, RewriterBase &rewriter, @@ -737,7 +747,7 @@ struct ForOpInterface // The new memref init_args of the loop. FailureOr> maybeInitArgs = - getBuffers(rewriter, forOp.getInitArgsMutable(), options); + getBuffers(rewriter, forOp.getInitArgsMutable(), options, state); if (failed(maybeInitArgs)) return failure(); SmallVector initArgs = *maybeInitArgs; @@ -752,7 +762,7 @@ struct ForOpInterface castedInitArgs.push_back(initArg); continue; } - auto targetType = bufferization::getBufferType(result, options); + auto targetType = bufferization::getBufferType(result, options, state); if (failed(targetType)) return failure(); castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); @@ -891,13 +901,16 @@ struct WhileOpInterface return true; } - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { + LogicalResult + resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &analysisState, + const BufferizationState &bufferizationState) const { auto bufferizableOp = cast(op); - if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) + if (failed(bufferizableOp.resolveTensorOpOperandConflicts( + rewriter, analysisState, bufferizationState))) return failure(); - if (state.getOptions().copyBeforeWrite) + if 
(analysisState.getOptions().copyBeforeWrite) return success(); // According to the `getAliasing...` implementations, a bufferized OpResult @@ -914,9 +927,10 @@ struct WhileOpInterface // For every yielded value, is the value equivalent to its corresponding // bbArg? DenseSet equivalentYieldsBefore = getEquivalentBuffers( - whileOp.getBeforeArguments(), conditionOp.getArgs(), state); - DenseSet equivalentYieldsAfter = getEquivalentBuffers( - whileOp.getAfterArguments(), whileOp.getYieldOp().getResults(), state); + whileOp.getBeforeArguments(), conditionOp.getArgs(), analysisState); + DenseSet equivalentYieldsAfter = + getEquivalentBuffers(whileOp.getAfterArguments(), + whileOp.getYieldOp().getResults(), analysisState); // Update "before" region. rewriter.setInsertionPoint(conditionOp); @@ -931,7 +945,8 @@ struct WhileOpInterface continue; } FailureOr alloc = allocateTensorForShapedValue( - rewriter, conditionOp.getLoc(), value, state.getOptions()); + rewriter, conditionOp.getLoc(), value, analysisState.getOptions(), + bufferizationState); if (failed(alloc)) return failure(); beforeYieldValues.push_back(*alloc); @@ -956,7 +971,7 @@ struct WhileOpInterface // The new memref init_args of the loop. 
FailureOr> maybeInitArgs = - getBuffers(rewriter, whileOp.getInitsMutable(), options); + getBuffers(rewriter, whileOp.getInitsMutable(), options, state); if (failed(maybeInitArgs)) return failure(); SmallVector initArgs = *maybeInitArgs; @@ -971,7 +986,7 @@ struct WhileOpInterface castedInitArgs.push_back(initArg); continue; } - auto targetType = bufferization::getBufferType(beforeArg, options); + auto targetType = bufferization::getBufferType(beforeArg, options, state); if (failed(targetType)) return failure(); castedInitArgs.push_back(castBuffer(rewriter, initArg, *targetType)); @@ -984,7 +999,7 @@ struct WhileOpInterface return bbArg.getType(); // TODO: error handling return llvm::cast( - *bufferization::getBufferType(bbArg, options)); + *bufferization::getBufferType(bbArg, options, state)); })); // Construct a new scf.while op with memref instead of tensor values. @@ -1029,6 +1044,7 @@ struct WhileOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto whileOp = cast(op); assert(getOwnerOfValue(value) == op && "invalid value"); @@ -1041,7 +1057,7 @@ struct WhileOpInterface auto yieldOp = whileOp.getYieldOp(); Value yieldedValue = yieldOp.getOperand(bbArg.getArgNumber()); return computeLoopRegionIterArgBufferType( - op, bbArg, initArg, yieldedValue, options, invocationStack); + op, bbArg, initArg, yieldedValue, options, state, invocationStack); } } @@ -1062,7 +1078,7 @@ struct WhileOpInterface // scf.condition was already bufferized. 
return cast(conditionYieldedVal.getType()); } - return bufferization::getBufferType(conditionYieldedVal, options, + return bufferization::getBufferType(conditionYieldedVal, options, state, invocationStack); } @@ -1161,7 +1177,8 @@ struct YieldOpInterface for (const auto &it : llvm::enumerate(yieldOp.getResults())) { Value value = it.value(); if (isa(value.getType())) { - FailureOr maybeBuffer = getBuffer(rewriter, value, options); + FailureOr maybeBuffer = + getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; @@ -1169,14 +1186,14 @@ struct YieldOpInterface if (isa( yieldOp->getParentOp())) { FailureOr resultType = bufferization::getBufferType( - yieldOp->getParentOp()->getResult(it.index()), options); + yieldOp->getParentOp()->getResult(it.index()), options, state); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); } else if (auto whileOp = dyn_cast(yieldOp->getParentOp())) { FailureOr resultType = bufferization::getBufferType( - whileOp.getBeforeArguments()[it.index()], options); + whileOp.getBeforeArguments()[it.index()], options, state); if (failed(resultType)) return failure(); buffer = castBuffer(rewriter, buffer, *resultType); @@ -1236,7 +1253,7 @@ struct ForallOpInterface // Get buffers for all output operands. SmallVector buffers; for (Value out : forallOp.getOutputs()) { - FailureOr buffer = getBuffer(rewriter, out, options); + FailureOr buffer = getBuffer(rewriter, out, options, state); if (failed(buffer)) return failure(); buffers.push_back(*buffer); @@ -1283,6 +1300,7 @@ struct ForallOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto forallOp = cast(op); @@ -1290,13 +1308,14 @@ struct ForallOpInterface // A tensor block argument has the same bufferized type as the // corresponding output operand. 
return bufferization::getBufferType( - forallOp.getTiedOpOperand(bbArg)->get(), options, invocationStack); + forallOp.getTiedOpOperand(bbArg)->get(), options, state, + invocationStack); // The bufferized result type is the same as the bufferized type of the // corresponding output operand. return bufferization::getBufferType( forallOp.getOutputs()[cast(value).getResultNumber()], options, - invocationStack); + state, invocationStack); } bool isRepetitiveRegion(Operation *op, unsigned index) const { diff --git a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp index e8cab76d3c753..dc91117a51936 100644 --- a/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.cpp @@ -119,7 +119,7 @@ struct AssumingYieldOpInterface SmallVector newResults; for (Value value : yieldOp.getOperands()) { if (isa(value.getType())) { - FailureOr buffer = getBuffer(rewriter, value, options); + FailureOr buffer = getBuffer(rewriter, value, options, state); if (failed(buffer)) return failure(); newResults.push_back(*buffer); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 7c7c64f2aef01..a3ab53d818115 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -152,8 +152,10 @@ class SparsificationAndBufferizationPass // invalidate the results of the analysis. From now on, only small and // localized rewrites are allowed, such as replacing a tensor op with its // memref equivalent. 
- if (failed(bufferization::insertTensorCopies(getOperation(), - bufferizationOptions))) + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::insertTensorCopies( + getOperation(), bufferizationOptions, bufferizationState))) return signalPassFailure(); // Option `testAnalysisOnly` is a debug/testing flag. If set, the results of diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp index 630e970cd4b19..4b778b768d136 100644 --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -51,10 +51,11 @@ struct CastOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto castOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - castOp.getSource(), options, invocationStack); + castOp.getSource(), options, state, invocationStack); if (failed(maybeSrcBufferType)) return failure(); Attribute memorySpace = maybeSrcBufferType->getMemorySpace(); @@ -89,13 +90,13 @@ struct CastOpInterface // The result buffer still has the old (pre-cast) type. FailureOr resultBuffer = - getBuffer(rewriter, castOp.getSource(), options); + getBuffer(rewriter, castOp.getSource(), options, state); if (failed(resultBuffer)) return failure(); // Compute the new type. 
auto resultMemRefType = - bufferization::getBufferType(castOp.getResult(), options); + bufferization::getBufferType(castOp.getResult(), options, state); if (failed(resultMemRefType)) return failure(); if (resultBuffer->getType() == *resultMemRefType) { @@ -141,10 +142,11 @@ struct CollapseShapeOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto collapseShapeOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - collapseShapeOp.getSrc(), options, invocationStack); + collapseShapeOp.getSrc(), options, state, invocationStack); if (failed(maybeSrcBufferType)) return failure(); auto srcBufferType = llvm::cast(*maybeSrcBufferType); @@ -168,7 +170,7 @@ struct CollapseShapeOpInterface auto collapseShapeOp = cast(op); RankedTensorType tensorResultType = collapseShapeOp.getResultType(); FailureOr maybeBuffer = - getBuffer(rewriter, collapseShapeOp.getSrc(), options); + getBuffer(rewriter, collapseShapeOp.getSrc(), options, state); if (failed(maybeBuffer)) return failure(); Value buffer = *maybeBuffer; @@ -210,7 +212,7 @@ struct CollapseShapeOpInterface // TODO: Create alloc_tensor ops during TensorCopyInsertion. AnalysisState analysisState(options); FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, op->getLoc(), collapseShapeOp.getSrc(), options); + rewriter, op->getLoc(), collapseShapeOp.getSrc(), options, state); if (failed(tensorAlloc)) return failure(); auto memrefType = @@ -252,7 +254,7 @@ struct DimOpInterface const BufferizationOptions &options, BufferizationState &state) const { auto dimOp = cast(op); - FailureOr v = getBuffer(rewriter, dimOp.getSource(), options); + FailureOr v = getBuffer(rewriter, dimOp.getSource(), options, state); if (failed(v)) return failure(); replaceOpWithNewBufferizedOp(rewriter, op, *v, @@ -286,7 +288,8 @@ struct EmptyOpInterface // Allocate a tensor. 
This emits a "bufferization.alloc_tensor" op. FailureOr allocTensor = allocateTensorForShapedValue( - rewriter, op->getLoc(), emptyOp.getResult(), options, /*copy=*/false); + rewriter, op->getLoc(), emptyOp.getResult(), options, state, + /*copy=*/false); if (failed(allocTensor)) return failure(); rewriter.replaceOp(op, *allocTensor); @@ -317,10 +320,11 @@ struct ExpandShapeOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto expandShapeOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - expandShapeOp.getSrc(), options, invocationStack); + expandShapeOp.getSrc(), options, state, invocationStack); if (failed(maybeSrcBufferType)) return failure(); auto srcBufferType = llvm::cast(*maybeSrcBufferType); @@ -338,7 +342,7 @@ struct ExpandShapeOpInterface auto expandShapeOp = cast(op); auto tensorResultType = expandShapeOp.getResultType(); FailureOr buffer = - getBuffer(rewriter, expandShapeOp.getSrc(), options); + getBuffer(rewriter, expandShapeOp.getSrc(), options, state); if (failed(buffer)) return failure(); @@ -382,13 +386,13 @@ struct ExtractSliceOpInterface // Get source buffer. FailureOr srcMemref = - getBuffer(rewriter, extractSliceOp.getSource(), options); + getBuffer(rewriter, extractSliceOp.getSource(), options, state); if (failed(srcMemref)) return failure(); // Take a subview of the source buffer. 
- auto resultMemrefType = - bufferization::getBufferType(extractSliceOp.getResult(), options); + auto resultMemrefType = bufferization::getBufferType( + extractSliceOp.getResult(), options, state); if (failed(resultMemrefType)) return failure(); Value subView = rewriter.create( @@ -401,11 +405,12 @@ struct ExtractSliceOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto extractSliceOp = cast(op); assert(value == extractSliceOp.getResult() && "invalid value"); auto srcMemrefType = bufferization::getBufferType( - extractSliceOp.getSource(), options, invocationStack); + extractSliceOp.getSource(), options, state, invocationStack); if (failed(srcMemrefType)) return failure(); SmallVector mixedOffsets = extractSliceOp.getMixedOffsets(); @@ -442,7 +447,7 @@ struct ExtractOpInterface BufferizationState &state) const { auto extractOp = cast(op); FailureOr srcMemref = - getBuffer(rewriter, extractOp.getTensor(), options); + getBuffer(rewriter, extractOp.getTensor(), options, state); if (failed(srcMemref)) return failure(); replaceOpWithNewBufferizedOp(rewriter, op, *srcMemref, @@ -491,12 +496,12 @@ struct FromElementsOpInterface auto shape = tensorType.getShape(); // TODO: Create alloc_tensor ops during TensorCopyInsertion. FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, fromElementsOp.getResult(), options, + rewriter, loc, fromElementsOp.getResult(), options, state, /*copy=*/false); if (failed(tensorAlloc)) return failure(); FailureOr memrefType = - bufferization::getBufferType(*tensorAlloc, options); + bufferization::getBufferType(*tensorAlloc, options, state); if (failed(memrefType)) return failure(); Value buffer = rewriter.create( @@ -607,7 +612,7 @@ struct GenerateOpInterface // Allocate memory. 
Location loc = op->getLoc(); FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, generateOp.getResult(), options, + rewriter, loc, generateOp.getResult(), options, state, /*copy=*/false); if (failed(tensorAlloc)) return failure(); @@ -633,7 +638,7 @@ struct InsertOpInterface BufferizationState &state) const { auto insertOp = cast(op); FailureOr destMemref = - getBuffer(rewriter, insertOp.getDest(), options); + getBuffer(rewriter, insertOp.getDest(), options, state); if (failed(destMemref)) return failure(); rewriter.create(insertOp.getLoc(), insertOp.getScalar(), @@ -695,7 +700,7 @@ struct InsertSliceOpInterface // Get destination buffer. FailureOr dstMemref = - getBuffer(rewriter, insertSliceOp.getDest(), options); + getBuffer(rewriter, insertSliceOp.getDest(), options, state); if (failed(dstMemref)) return failure(); @@ -712,7 +717,7 @@ struct InsertSliceOpInterface // Copy tensor. If this tensor.insert_slice has a matching // tensor.extract_slice, the copy operation will eventually fold away. FailureOr srcMemref = - getBuffer(rewriter, insertSliceOp.getSource(), options); + getBuffer(rewriter, insertSliceOp.getSource(), options, state); if (failed(srcMemref)) return failure(); if (failed(options.createMemCpy(rewriter, loc, *srcMemref, subView))) @@ -749,11 +754,12 @@ struct PadOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { // Infer memory space from the source tensor. auto padOp = cast(op); auto maybeSrcBufferType = bufferization::getBufferType( - padOp.getSource(), options, invocationStack); + padOp.getSource(), options, state, invocationStack); if (failed(maybeSrcBufferType)) return failure(); MemRefLayoutAttrInterface layout; @@ -797,9 +803,9 @@ struct PadOpInterface } // Allocate a buffer for the padded result. 
- FailureOr tensorAlloc = - allocateTensorForShapedValue(rewriter, loc, padOp.getResult(), options, - /*copy=*/false); + FailureOr tensorAlloc = allocateTensorForShapedValue( + rewriter, loc, padOp.getResult(), options, state, + /*copy=*/false); if (failed(tensorAlloc)) return failure(); @@ -846,7 +852,8 @@ struct RankOpInterface const BufferizationOptions &options, BufferizationState &state) const { auto rankOp = cast(op); - FailureOr v = getBuffer(rewriter, rankOp.getTensor(), options); + FailureOr v = + getBuffer(rewriter, rankOp.getTensor(), options, state); if (failed(v)) return failure(); replaceOpWithNewBufferizedOp(rewriter, op, rankOp.getType(), @@ -885,13 +892,13 @@ struct ReshapeOpInterface BufferizationState &state) const { auto reshapeOp = cast(op); FailureOr srcBuffer = - getBuffer(rewriter, reshapeOp.getSource(), options); + getBuffer(rewriter, reshapeOp.getSource(), options, state); FailureOr shapeBuffer = - getBuffer(rewriter, reshapeOp.getShape(), options); + getBuffer(rewriter, reshapeOp.getShape(), options, state); if (failed(srcBuffer) || failed(shapeBuffer)) return failure(); auto maybeResultMemRefType = - bufferization::getBufferType(reshapeOp.getResult(), options); + bufferization::getBufferType(reshapeOp.getResult(), options, state); if (failed(maybeResultMemRefType)) return failure(); @@ -901,7 +908,7 @@ struct ReshapeOpInterface auto srcType = llvm::dyn_cast(srcBuffer->getType()); if (srcType && !srcType.getLayout().isIdentity()) { FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, op->getLoc(), reshapeOp.getSource(), options); + rewriter, op->getLoc(), reshapeOp.getSource(), options, state); if (failed(tensorAlloc)) return failure(); auto memrefType = MemRefType::get( @@ -920,11 +927,12 @@ struct ReshapeOpInterface FailureOr getBufferType(Operation *op, Value value, const BufferizationOptions &options, + const BufferizationState &state, SmallVector &invocationStack) const { auto reshapeOp = cast(op); assert(value == 
reshapeOp.getResult() && "unexpected value provided"); auto maybeSourceBufferType = bufferization::getBufferType( - reshapeOp.getSource(), options, invocationStack); + reshapeOp.getSource(), options, state, invocationStack); if (failed(maybeSourceBufferType)) return failure(); return getMemRefTypeWithStaticIdentityLayout( @@ -966,11 +974,11 @@ struct ParallelInsertSliceOpInterface // Get source and destination buffers. FailureOr destBuffer = - getBuffer(rewriter, parallelInsertSliceOp.getDest(), options); + getBuffer(rewriter, parallelInsertSliceOp.getDest(), options, state); if (failed(destBuffer)) return failure(); FailureOr srcBuffer = - getBuffer(rewriter, parallelInsertSliceOp.getSource(), options); + getBuffer(rewriter, parallelInsertSliceOp.getSource(), options, state); if (failed(srcBuffer)) return failure(); @@ -1015,8 +1023,10 @@ struct ParallelInsertSliceOpInterface /// tensor.parallel_insert_slice op has implicit inplace behavior. We /// shouldn't create copy to resolve conflict. - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { + LogicalResult + resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &analysisState, + const BufferizationState &bufferizationState) const { return success(); } }; @@ -1038,7 +1048,7 @@ struct SplatOpInterface // Allocate memory. Location loc = op->getLoc(); FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, splatOp.getResult(), options, + rewriter, loc, splatOp.getResult(), options, state, /*copy=*/false); if (failed(tensorAlloc)) return failure(); @@ -1097,7 +1107,7 @@ struct ConcatOpInterface // Allocate memory. 
Location loc = op->getLoc(); FailureOr tensorAlloc = allocateTensorForShapedValue( - rewriter, loc, concatOp.getResult(), options, + rewriter, loc, concatOp.getResult(), options, state, /*copy=*/false); if (failed(tensorAlloc)) return failure(); @@ -1147,7 +1157,7 @@ struct ConcatOpInterface for (auto operand : concatOp.getInputs()) { // Get the buffer for the operand. - FailureOr srcBuffer = getBuffer(rewriter, operand, options); + FailureOr srcBuffer = getBuffer(rewriter, operand, options, state); if (failed(srcBuffer)) return failure(); diff --git a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp index 45b6e7c512947..9da051150e409 100644 --- a/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp @@ -53,7 +53,8 @@ struct TransferReadOpInterface auto readOp = cast(op); assert(isa(readOp.getShapedType()) && "only tensor types expected"); - FailureOr buffer = getBuffer(rewriter, readOp.getBase(), options); + FailureOr buffer = + getBuffer(rewriter, readOp.getBase(), options, state); if (failed(buffer)) return failure(); replaceOpWithNewBufferizedOp( @@ -112,7 +113,7 @@ struct TransferWriteOpInterface // Create a new transfer_write on buffer that doesn't have a return value. 
FailureOr resultBuffer = - getBuffer(rewriter, writeOp.getBase(), options); + getBuffer(rewriter, writeOp.getBase(), options, state); if (failed(resultBuffer)) return failure(); rewriter.create( @@ -155,7 +156,8 @@ struct GatherOpInterface auto gatherOp = cast(op); assert(isa(gatherOp.getBaseType()) && "only tensor types expected"); - FailureOr buffer = getBuffer(rewriter, gatherOp.getBase(), options); + FailureOr buffer = + getBuffer(rewriter, gatherOp.getBase(), options, state); if (failed(buffer)) return failure(); replaceOpWithNewBufferizedOp( @@ -184,10 +186,13 @@ struct MaskOpInterface return {{&yieldOp->getOpOperand(resultNum), BufferRelation::Equivalent}}; } - LogicalResult resolveConflicts(Operation *op, RewriterBase &rewriter, - const AnalysisState &state) const { + LogicalResult + resolveConflicts(Operation *op, RewriterBase &rewriter, + const AnalysisState &analysisState, + const BufferizationState &bufferizationState) const { auto bufferizableOp = cast(op); - if (failed(bufferizableOp.resolveTensorOpOperandConflicts(rewriter, state))) + if (failed(bufferizableOp.resolveTensorOpOperandConflicts( + rewriter, analysisState, bufferizationState))) return failure(); // TODO: Remove this function when vector.mask bodies can bufferize @@ -302,7 +307,8 @@ struct YieldOpInterface SmallVector newResults; for (Value value : yieldOp.getOperands()) { if (isa(value.getType())) { - FailureOr maybeBuffer = getBuffer(rewriter, value, options); + FailureOr maybeBuffer = + getBuffer(rewriter, value, options, state); if (failed(maybeBuffer)) return failure(); newResults.push_back(*maybeBuffer); diff --git a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp index 2991a3c165ee2..dfaebccde7dcc 100644 --- a/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp +++ b/mlir/test/lib/Dialect/Bufferization/TestTensorCopyInsertion.cpp @@ -48,7 +48,11 @@ struct TestTensorCopyInsertionPass 
options.defaultMemorySpaceFn = [](TensorType t) -> std::optional { return std::nullopt; }; } - if (failed(bufferization::insertTensorCopies(getOperation(), options))) + + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::insertTensorCopies(getOperation(), options, + bufferizationState))) signalPassFailure(); }