@@ -307,3 +307,53 @@ func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
307307// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
308308// CHECK: loop.parallel
309309// CHECK: loop.parallel
310+
311+ // -----
312+
313+ func @nested_fuse (%A: memref <2 x2 xf32 >, %B: memref <2 x2 xf32 >,
314+ %C: memref <2 x2 xf32 >, %result: memref <2 x2 xf32 >) {
315+ %c2 = constant 2 : index
316+ %c0 = constant 0 : index
317+ %c1 = constant 1 : index
318+ %sum = alloc () : memref <2 x2 xf32 >
319+ loop.parallel (%k ) = (%c0 ) to (%c2 ) step (%c1 ) {
320+ loop.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
321+ %B_elem = load %B [%i , %j ] : memref <2 x2 xf32 >
322+ %C_elem = load %C [%i , %j ] : memref <2 x2 xf32 >
323+ %sum_elem = addf %B_elem , %C_elem : f32
324+ store %sum_elem , %sum [%i , %j ] : memref <2 x2 xf32 >
325+ loop.yield
326+ }
327+ loop.parallel (%i , %j ) = (%c0 , %c0 ) to (%c2 , %c2 ) step (%c1 , %c1 ) {
328+ %sum_elem = load %sum [%i , %j ] : memref <2 x2 xf32 >
329+ %A_elem = load %A [%i , %j ] : memref <2 x2 xf32 >
330+ %product_elem = mulf %sum_elem , %A_elem : f32
331+ store %product_elem , %result [%i , %j ] : memref <2 x2 xf32 >
332+ loop.yield
333+ }
334+ }
335+ dealloc %sum : memref <2 x2 xf32 >
336+ return
337+ }
338+ // CHECK-LABEL: func @nested_fuse
339+ // CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
340+ // CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
341+ // CHECK: [[C2:%.*]] = constant 2 : index
342+ // CHECK: [[C0:%.*]] = constant 0 : index
343+ // CHECK: [[C1:%.*]] = constant 1 : index
344+ // CHECK: [[SUM:%.*]] = alloc()
345+ // CHECK: loop.parallel
346+ // CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
347+ // CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
348+ // CHECK: [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
349+ // CHECK: [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
350+ // CHECK: [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
351+ // CHECK: store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
352+ // CHECK: [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
353+ // CHECK: [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
354+ // CHECK: [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
355+ // CHECK: store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
356+ // CHECK: loop.yield
357+ // CHECK: }
358+ // CHECK: }
359+ // CHECK: dealloc [[SUM]]
0 commit comments