@@ -896,6 +896,16 @@ struct GreedyFusion {
896896 if (fusedLoopInsPoint == nullptr )
897897 continue ;
898898
899+ // It's possible this fusion is at an inner depth (i.e., there are
900+ // affine loops that commonly surround both the source and destination
901+ // for ops). We need this count because the call to canFuseLoops must
902+ // be passed the absolute depth. The max legal depth and the depths we
903+ // try below are, however, *relative* and as such don't include the
904+ // common depth.
905+ SmallVector<AffineForOp, 4 > surroundingLoops;
906+ getAffineForIVs (*dstAffineForOp, &surroundingLoops);
907+ unsigned numSurroundingLoops = surroundingLoops.size ();
908+
899909 // Compute the innermost common loop depth for dstNode
900910 // producer-consumer loads/stores.
901911 SmallVector<Operation *, 2 > dstMemrefOps;
@@ -907,7 +917,8 @@ struct GreedyFusion {
907917 if (producerConsumerMemrefs.count (
908918 cast<AffineWriteOpInterface>(op).getMemRef ()))
909919 dstMemrefOps.push_back (op);
910- unsigned dstLoopDepthTest = getInnermostCommonLoopDepth (dstMemrefOps);
920+ unsigned dstLoopDepthTest =
921+ getInnermostCommonLoopDepth (dstMemrefOps) - numSurroundingLoops;
911922
912923 // Check the feasibility of fusing src loop nest into dst loop nest
913924 // at loop depths in range [1, dstLoopDepthTest].
@@ -916,9 +927,10 @@ struct GreedyFusion {
916927 depthSliceUnions.resize (dstLoopDepthTest);
917928 FusionStrategy strategy (FusionStrategy::ProducerConsumer);
918929 for (unsigned i = 1 ; i <= dstLoopDepthTest; ++i) {
919- FusionResult result = affine::canFuseLoops (
920- srcAffineForOp, dstAffineForOp,
921- /* dstLoopDepth=*/ i, &depthSliceUnions[i - 1 ], strategy);
930+ FusionResult result =
931+ affine::canFuseLoops (srcAffineForOp, dstAffineForOp,
932+ /* dstLoopDepth=*/ i + numSurroundingLoops,
933+ &depthSliceUnions[i - 1 ], strategy);
922934
923935 if (result.value == FusionResult::Success)
924936 maxLegalFusionDepth = i;
@@ -1125,9 +1137,18 @@ struct GreedyFusion {
11251137 SmallVector<Operation *, 2 > dstLoadOpInsts;
11261138 dstNode->getLoadOpsForMemref (memref, &dstLoadOpInsts);
11271139
1140+ // It's possible this fusion is at an inner depth (i.e., there are affine
1141+ // loops that commonly surround both the source and destination for
1142+ // ops). We need this count because the call to canFuseLoops must be
1143+ // passed the absolute depth. The max legal depth and the depths we try
1144+ // below are, however, *relative* and as such don't include the common
1145+ // depth.
1146+ SmallVector<AffineForOp, 4 > surroundingLoops;
1147+ getAffineForIVs (*dstAffineForOp, &surroundingLoops);
1148+ unsigned numSurroundingLoops = surroundingLoops.size ();
11281149 SmallVector<AffineForOp, 4 > dstLoopIVs;
11291150 getAffineForIVs (*dstLoadOpInsts[0 ], &dstLoopIVs);
1130- unsigned dstLoopDepthTest = dstLoopIVs.size ();
1151+ unsigned dstLoopDepthTest = dstLoopIVs.size () - numSurroundingLoops ;
11311152 auto sibAffineForOp = cast<AffineForOp>(sibNode->op );
11321153
11331154 // Compute loop depth and slice union for fusion.
@@ -1136,14 +1157,18 @@ struct GreedyFusion {
11361157 unsigned maxLegalFusionDepth = 0 ;
11371158 FusionStrategy strategy (memref);
11381159 for (unsigned i = 1 ; i <= dstLoopDepthTest; ++i) {
1139- FusionResult result = affine::canFuseLoops (
1140- sibAffineForOp, dstAffineForOp,
1141- /* dstLoopDepth=*/ i, &depthSliceUnions[i - 1 ], strategy);
1160+ FusionResult result =
1161+ affine::canFuseLoops (sibAffineForOp, dstAffineForOp,
1162+ /* dstLoopDepth=*/ i + numSurroundingLoops,
1163+ &depthSliceUnions[i - 1 ], strategy);
11421164
11431165 if (result.value == FusionResult::Success)
11441166 maxLegalFusionDepth = i;
11451167 }
11461168
1169+ LLVM_DEBUG (llvm::dbgs () << " Max legal depth for fusion: "
1170+ << maxLegalFusionDepth << ' \n ' );
1171+
11471172 // Skip if fusion is not feasible at any loop depths.
11481173 if (maxLegalFusionDepth == 0 )
11491174 continue ;
@@ -1238,9 +1263,15 @@ struct GreedyFusion {
12381263 SmallVector<AffineForOp, 4 > loops;
12391264 getAffineForIVs (*user, &loops);
12401265 // Skip 'use' if it is not within a loop nest.
1241- if (loops.empty ())
1266+ // Find the surrounding affine.for nested immediately within the
1267+ // block.
1268+ auto *it = llvm::find_if (loops, [&](AffineForOp loop) {
1269+ return loop->getBlock () == &mdg->block ;
1270+ });
1271+ // Skip 'use' if it is not within a loop nest in `block`.
1272+ if (it == loops.end ())
12421273 continue ;
1243- Node *sibNode = mdg->getForOpNode (loops[ 0 ] );
1274+ Node *sibNode = mdg->getForOpNode (*it );
12441275 assert (sibNode != nullptr );
12461277 // Skip 'use' if it is not a sibling to 'dstNode'.
12461277 if (sibNode->id == dstNode->id )
@@ -1373,9 +1404,17 @@ void LoopFusion::runOnBlock(Block *block) {
13731404}
13741405
13751406void LoopFusion::runOnOperation () {
1376- for (Region &region : getOperation ()->getRegions ())
1377- for (Block &block : region.getBlocks ())
1378- runOnBlock (&block);
1407+ // Call fusion on every block that contains at least two affine.for ops
1408+ // (visiting ops in post order).
1409+ getOperation ()->walk ([&](Operation *op) {
1410+ for (Region &region : op->getRegions ()) {
1411+ for (Block &block : region.getBlocks ()) {
1412+ auto affineFors = block.getOps <AffineForOp>();
1413+ if (!affineFors.empty () && !llvm::hasSingleElement (affineFors))
1414+ runOnBlock (&block);
1415+ }
1416+ }
1417+ });
13791418}
13801419
13811420std::unique_ptr<Pass> mlir::affine::createLoopFusionPass (
0 commit comments