Skip to content

Commit a6b2877

Browse files
committed
[MLIR] Make ParallelLoopFusion pass scan through all nested regions.
Differential Revision: https://reviews.llvm.org/D79558
1 parent 1e413a8 commit a6b2877

File tree

2 files changed: 54 additions and 3 deletions.

mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -162,11 +162,12 @@ namespace {
162 | 162 |   struct ParallelLoopFusion
163 | 163 |       : public LoopParallelLoopFusionBase<ParallelLoopFusion> {
164 | 164 |     void runOnOperation() override {
165 |     | -     for (Region &region : getOperation()->getRegions())
166 |     | -       naivelyFuseParallelOps(region);
    | 165 | +     getOperation()->walk([&](Operation *child) {
    | 166 | +       for (Region &region : child->getRegions())
    | 167 | +         naivelyFuseParallelOps(region);
    | 168 | +     });
167 | 169 |     }
168 | 170 |   };
169 |     | -
170 | 171 |   } // namespace
171 | 172 |
172 | 173 |   std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {

mlir/test/Dialect/Loops/parallel-loop-fusion.mlir

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -307,3 +307,53 @@ func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
307 | 307 |   // CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
308 | 308 |   // CHECK: loop.parallel
309 | 309 |   // CHECK: loop.parallel
    | 310 | +
    | 311 | + // -----
    | 312 | +
    | 313 | + func @nested_fuse(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
    | 314 | +                   %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
    | 315 | +   %c2 = constant 2 : index
    | 316 | +   %c0 = constant 0 : index
    | 317 | +   %c1 = constant 1 : index
    | 318 | +   %sum = alloc() : memref<2x2xf32>
    | 319 | +   loop.parallel (%k) = (%c0) to (%c2) step (%c1) {
    | 320 | +     loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    | 321 | +       %B_elem = load %B[%i, %j] : memref<2x2xf32>
    | 322 | +       %C_elem = load %C[%i, %j] : memref<2x2xf32>
    | 323 | +       %sum_elem = addf %B_elem, %C_elem : f32
    | 324 | +       store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
    | 325 | +       loop.yield
    | 326 | +     }
    | 327 | +     loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
    | 328 | +       %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
    | 329 | +       %A_elem = load %A[%i, %j] : memref<2x2xf32>
    | 330 | +       %product_elem = mulf %sum_elem, %A_elem : f32
    | 331 | +       store %product_elem, %result[%i, %j] : memref<2x2xf32>
    | 332 | +       loop.yield
    | 333 | +     }
    | 334 | +   }
    | 335 | +   dealloc %sum : memref<2x2xf32>
    | 336 | +   return
    | 337 | + }
    | 338 | + // CHECK-LABEL: func @nested_fuse
    | 339 | + // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
    | 340 | + // CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
    | 341 | + // CHECK: [[C2:%.*]] = constant 2 : index
    | 342 | + // CHECK: [[C0:%.*]] = constant 0 : index
    | 343 | + // CHECK: [[C1:%.*]] = constant 1 : index
    | 344 | + // CHECK: [[SUM:%.*]] = alloc()
    | 345 | + // CHECK: loop.parallel
    | 346 | + // CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
    | 347 | + // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
    | 348 | + // CHECK: [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
    | 349 | + // CHECK: [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
    | 350 | + // CHECK: [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
    | 351 | + // CHECK: store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
    | 352 | + // CHECK: [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
    | 353 | + // CHECK: [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
    | 354 | + // CHECK: [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
    | 355 | + // CHECK: store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
    | 356 | + // CHECK: loop.yield
    | 357 | + // CHECK: }
    | 358 | + // CHECK: }
    | 359 | + // CHECK: dealloc [[SUM]]

0 commit comments

Comments
 (0)