Skip to content

Commit 61d5fdf

Browse files
authored
[MLIR] Add bufferization state class to OneShotBufferization pass (#141019)
Follow-up on #138143, which was reverted due to a missing update a method signature (more specifically, the bufferization interface for `tensor::ConcatOp`) that was not catched before merging. The old PR description is reported in the next lines. This PR is a follow-up on #138125, and adds a bufferization state class providing information about the IR. The information currently consists of a cached list of symbol tables, which aims to solve the quadratic scaling of the bufferization task with respect to the number of symbols. The PR breaks API compatibility: the bufferize method of the BufferizableOpInterface has been enriched with a reference to a BufferizationState object. The bufferization state must be kept in a valid state by the interface implementations. For example, if an operation with the Symbol trait is inserted or replaced, its parent SymbolTable must be updated accordingly (see, for example, the bufferization of arith::ConstantOp, where the symbol table of the module gets the new global symbol inserted). Similarly, the invalidation of a symbol table must be performed if an operation with the SymbolTable trait is removed (this can be performed using the invalidateSymbolTable method, introduced in #138014).
1 parent 3d02834 commit 61d5fdf

27 files changed

+215
-87
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,20 @@ class AnalysisState {
578578
insideMutuallyExclusiveRegionsCache;
579579
};
580580

581+
/// BufferizationState provides information about the state of the IR during the
582+
/// bufferization process.
583+
class BufferizationState {
584+
public:
585+
/// Get a reference to the collection of cached symbol tables.
586+
SymbolTableCollection &getSymbolTables();
587+
588+
private:
589+
/// The cached symbol tables.
590+
/// The user is expected to update / invalidate the cached symbol tables if
591+
/// the bufferized operation has the Symbol or SymbolTable traits.
592+
SymbolTableCollection symbolTables;
593+
};
594+
581595
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
582596
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
583597
/// undefined contents is allocated.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options,
430+
"::mlir::bufferization::BufferizationState &":$state),
430431
/*methodBody=*/"",
431432
/*defaultImplementation=*/[{
432433
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options);
96+
const BufferizationOptions &options,
97+
BufferizationState &state);
9798

9899
bool resultBufferizesToMemoryWrite(OpResult opResult,
99100
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
282283

283284
let extraClassDeclaration = [{
284285
LogicalResult bufferize(RewriterBase &rewriter,
285-
const BufferizationOptions &options);
286+
const BufferizationOptions &options,
287+
BufferizationState &state);
286288

287289
bool bufferizesToMemoryRead(OpOperand &opOperand,
288290
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
375377
}
376378

377379
LogicalResult bufferize(RewriterBase &rewriter,
378-
const BufferizationOptions &options);
380+
const BufferizationOptions &options,
381+
BufferizationState &state);
379382
}];
380383
}
381384

@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
458461
//===------------------------------------------------------------------===//
459462

460463
LogicalResult bufferize(RewriterBase &rewriter,
461-
const BufferizationOptions &options) const {
464+
const BufferizationOptions &options,
465+
BufferizationState &state) const {
462466
// to_tensor/to_buffer pairs fold away after bufferization.
463467
return success();
464468
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
550554
}
551555

552556
LogicalResult bufferize(RewriterBase &rewriter,
553-
const BufferizationOptions &options);
557+
const BufferizationOptions &options,
558+
BufferizationState &state);
554559
}];
555560

556561
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32+
class BufferizationState;
3233

3334
/// A simple analysis that detects allocation operations.
3435
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
122123
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
123124
// names. Duplicates are avoided.
124125
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126+
SymbolTableCollection &symbolTables,
125127
uint64_t alignment,
126128
Attribute memorySpace = {});
127129

130+
void removeSymbol(Operation *op, BufferizationState &state);
131+
132+
void insertSymbol(Operation *op, BufferizationState &state);
133+
128134
} // namespace bufferization
129135
} // namespace mlir
130136

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48+
BufferizationState &bufferizationState,
4849
BufferizationStatistics *statistics = nullptr);
4950

5051
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273+
BufferizationState &state,
273274
BufferizationStatistics *statistics = nullptr);
274275

275276
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23+
class BufferizationState;
2324

2425
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2526
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3839
/// will be inserted only to these FuncOps.
3940
llvm::LogicalResult
4041
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
4143
BufferizationStatistics *statistics = nullptr);
4244

4345
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5052
llvm::LogicalResult runOneShotModuleBufferize(
5153
ModuleOp moduleOp,
5254
const bufferization::OneShotBufferizationOptions &options,
53-
BufferizationStatistics *statistics = nullptr);
55+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5456

5557
} // namespace bufferization
5658
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33+
class BufferizationState;
3334
} // namespace bufferization
3435

3536
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options) const {
27+
const BufferizationOptions &options,
28+
BufferizationState &state) const {
2829
auto constantOp = cast<arith::ConstantOp>(op);
2930
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3031

@@ -46,7 +47,8 @@ struct ConstantOpInterface
4647
// Create global memory segment and replace tensor with memref pointing to
4748
// that memory segment.
4849
FailureOr<memref::GlobalOp> globalOp =
49-
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5052
if (failed(globalOp))
5153
return failure();
5254
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
8385
}
8486

8587
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86-
const BufferizationOptions &options) const {
88+
const BufferizationOptions &options,
89+
BufferizationState &state) const {
8790
auto castOp = cast<arith::IndexCastOp>(op);
8891
auto resultTensorType = cast<TensorType>(castOp.getType());
8992

@@ -131,7 +134,8 @@ struct SelectOpInterface
131134
}
132135

133136
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
134-
const BufferizationOptions &options) const {
137+
const BufferizationOptions &options,
138+
BufferizationState &state) const {
135139
auto selectOp = cast<arith::SelectOp>(op);
136140
Location loc = selectOp.getLoc();
137141

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128+
SymbolTableCollection &BufferizationState::getSymbolTables() {
129+
return symbolTables;
130+
}
131+
128132
Region *bufferization::getNextEnclosingRepetitiveRegion(
129133
Region *region, const BufferizationOptions &options) {
130134
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

0 commit comments

Comments
 (0)