1818#include " mlir/Transforms/Passes.h"
1919#include " mlir/Transforms/RegionUtils.h"
2020#include " llvm/ADT/STLExtras.h"
21- #include " llvm/Support/Casting.h"
2221#include " llvm/Support/GenericIteratedDominanceFrontier.h"
2322
2423namespace mlir {
@@ -158,6 +157,8 @@ class MemorySlotPromotionAnalyzer {
158157 const DataLayout &dataLayout;
159158};
160159
/// Cache of block indices keyed by region: maps a region to a map from blocks
/// to their index.
160+ using BlockIndexCache = DenseMap<Region *, DenseMap<Block *, size_t >>;
161+
161162/// The MemorySlotPromoter handles the state of promoting a memory slot. It
162163/// wraps a slot and its associated allocator. This will perform the mutation of
163164/// IR.
@@ -166,7 +167,8 @@ class MemorySlotPromoter {
166167 MemorySlotPromoter (MemorySlot slot, PromotableAllocationOpInterface allocator,
167168 OpBuilder &builder, DominanceInfo &dominance,
168169 const DataLayout &dataLayout, MemorySlotPromotionInfo info,
169- const Mem2RegStatistics &statistics);
170+ const Mem2RegStatistics &statistics,
171+ BlockIndexCache &blockIndexCache);
170172
171173 // / Actually promotes the slot by mutating IR. Promoting a slot DOES
172174 // / invalidate the MemorySlotPromotionInfo of other slots. Preparation of
@@ -207,16 +209,21 @@ class MemorySlotPromoter {
207209 const DataLayout &dataLayout;
208210 MemorySlotPromotionInfo info;
209211 const Mem2RegStatistics &statistics;
212+
213+ // / Shared cache of block indices of specific regions.
214+ BlockIndexCache &blockIndexCache;
210215};
211216
212217} // namespace
213218
214219MemorySlotPromoter::MemorySlotPromoter (
215220 MemorySlot slot, PromotableAllocationOpInterface allocator,
216221 OpBuilder &builder, DominanceInfo &dominance, const DataLayout &dataLayout,
217- MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics)
222+ MemorySlotPromotionInfo info, const Mem2RegStatistics &statistics,
223+ BlockIndexCache &blockIndexCache)
218224 : slot(slot), allocator(allocator), builder(builder), dominance(dominance),
219- dataLayout(dataLayout), info(std::move(info)), statistics(statistics) {
225+ dataLayout(dataLayout), info(std::move(info)), statistics(statistics),
226+ blockIndexCache(blockIndexCache) {
220227#ifndef NDEBUG
221228 auto isResultOrNewBlockArgument = [&]() {
222229 if (BlockArgument arg = dyn_cast<BlockArgument>(slot.ptr ))
@@ -500,15 +507,29 @@ void MemorySlotPromoter::computeReachingDefInRegion(Region *region,
500507 }
501508}
502509
510+ // / Gets or creates a block index mapping for `region`.
511+ static const DenseMap<Block *, size_t > &
512+ getOrCreateBlockIndices (BlockIndexCache &blockIndexCache, Region *region) {
513+ auto [it, inserted] = blockIndexCache.try_emplace (region);
514+ if (!inserted)
515+ return it->second ;
516+
517+ DenseMap<Block *, size_t > &blockIndices = it->second ;
518+ SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (*region);
519+ for (auto [index, block] : llvm::enumerate (topologicalOrder))
520+ blockIndices[block] = index;
521+ return blockIndices;
522+ }
523+
503524// / Sorts `ops` according to dominance. Relies on the topological order of basic
504- // / blocks to get a deterministic ordering.
505- static void dominanceSort (SmallVector<Operation *> &ops, Region ®ion) {
525+ // / blocks to get a deterministic ordering. Uses `blockIndexCache` to avoid the
526+ // / potentially expensive recomputation of a block index map.
527+ static void dominanceSort (SmallVector<Operation *> &ops, Region ®ion,
528+ BlockIndexCache &blockIndexCache) {
506529 // Produce a topological block order and construct a map to lookup the indices
507530 // of blocks.
508- DenseMap<Block *, size_t > topoBlockIndices;
509- SetVector<Block *> topologicalOrder = getTopologicallySortedBlocks (region);
510- for (auto [index, block] : llvm::enumerate (topologicalOrder))
511- topoBlockIndices[block] = index;
531+ const DenseMap<Block *, size_t > &topoBlockIndices =
532+ getOrCreateBlockIndices (blockIndexCache, ®ion);
512533
513534 // Combining the topological order of the basic blocks together with block
514535 // internal operation order guarantees a deterministic, dominance respecting
@@ -527,7 +548,8 @@ void MemorySlotPromoter::removeBlockingUses() {
527548 llvm::make_first_range (info.userToBlockingUses ));
528549
529550 // Sort according to dominance.
530- dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent ());
551+ dominanceSort (usersToRemoveUses, *slot.ptr .getParentBlock ()->getParent (),
552+ blockIndexCache);
531553
532554 llvm::SmallVector<Operation *> toErase;
533555 // List of all replaced values in the slot.
@@ -605,20 +627,25 @@ void MemorySlotPromoter::promoteSlot() {
605627
606628LogicalResult mlir::tryToPromoteMemorySlots (
607629 ArrayRef<PromotableAllocationOpInterface> allocators, OpBuilder &builder,
608- const DataLayout &dataLayout, Mem2RegStatistics statistics) {
630+ const DataLayout &dataLayout, DominanceInfo &dominance,
631+ Mem2RegStatistics statistics) {
609632 bool promotedAny = false ;
610633
634+ // A cache that stores deterministic block indices which are used to determine
635+ // a valid operation modification order. The block index maps are computed
636+ // lazily and cached to avoid expensive recomputation.
637+ BlockIndexCache blockIndexCache;
638+
611639 for (PromotableAllocationOpInterface allocator : allocators) {
612640 for (MemorySlot slot : allocator.getPromotableSlots ()) {
613641 if (slot.ptr .use_empty ())
614642 continue ;
615643
616- DominanceInfo dominance;
617644 MemorySlotPromotionAnalyzer analyzer (slot, dominance, dataLayout);
618645 std::optional<MemorySlotPromotionInfo> info = analyzer.computeInfo ();
619646 if (info) {
620647 MemorySlotPromoter (slot, allocator, builder, dominance, dataLayout,
621- std::move (*info), statistics)
648+ std::move (*info), statistics, blockIndexCache )
622649 .promoteSlot ();
623650 promotedAny = true ;
624651 }
@@ -640,6 +667,10 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
640667
641668 bool changed = false ;
642669
670+ auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
671+ const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
672+ auto &dominance = getAnalysis<DominanceInfo>();
673+
643674 for (Region ®ion : scopeOp->getRegions ()) {
644675 if (region.getBlocks ().empty ())
645676 continue ;
@@ -655,16 +686,12 @@ struct Mem2Reg : impl::Mem2RegBase<Mem2Reg> {
655686 allocators.emplace_back (allocator);
656687 });
657688
658- auto &dataLayoutAnalysis = getAnalysis<DataLayoutAnalysis>();
659- const DataLayout &dataLayout = dataLayoutAnalysis.getAtOrAbove (scopeOp);
660-
661689 // Attempt promoting until no promotion succeeds.
662690 if (failed (tryToPromoteMemorySlots (allocators, builder, dataLayout,
663- statistics)))
691+ dominance, statistics)))
664692 break ;
665693
666694 changed = true ;
667- getAnalysisManager ().invalidate ({});
668695 }
669696 }
670697 if (!changed)
0 commit comments