diff --git a/llvm/lib/Transforms/Scalar/LoopFuse.cpp b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
index d6bd92d520e28..b5eb647a042b9 100644
--- a/llvm/lib/Transforms/Scalar/LoopFuse.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopFuse.cpp
@@ -1176,6 +1176,28 @@ struct LoopFuser {
     return true;
   }
 
+  /// Retarget the PHI nodes in \p SafeToSink once \p FC0 and \p FC1 are fused.
+  /// \p SafeToSink holds the instructions that are moved past the fused loop.
+  /// Any PHI among them that still receives a value from the latch of the
+  /// first loop (\p FC0) is rewritten to receive it from \p FC1's latch, i.e.
+  /// from the fused loop. Non-PHI instructions are left untouched.
+  void fixPHINodes(ArrayRef<Instruction *> SafeToSink,
+                   const FusionCandidate &FC0,
+                   const FusionCandidate &FC1) const {
+    for (Instruction *Inst : SafeToSink) {
+      // Only PHI nodes carry incoming-block operands that need retargeting.
+      auto *Phi = dyn_cast<PHINode>(Inst);
+      if (!Phi)
+        continue;
+      for (unsigned I = 0, E = Phi->getNumIncomingValues(); I != E; ++I) {
+        if (Phi->getIncomingBlock(I) != FC0.Latch)
+          continue;
+        assert(FC1.Latch && "FC1 latch is not set");
+        Phi->setIncomingBlock(I, FC1.Latch);
+      }
+    }
+  }
+
   /// Collect instructions in the \p FC1 Preheader that can be hoisted
   /// to the \p FC0 Preheader or sunk into the \p FC1 Body
   bool collectMovablePreheaderInsts(
@@ -1481,6 +1503,9 @@ struct LoopFuser {
       assert(I->getParent() == FC1.Preheader);
       I->moveBefore(*FC1.ExitBlock, FC1.ExitBlock->getFirstInsertionPt());
     }
+    // PHI nodes in SinkInsts need to be updated to receive values from the
+    // fused loop.
+    fixPHINodes(SinkInsts, FC0, FC1);
   }
 
   /// Determine if two fusion candidates have identical guards
diff --git a/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll b/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll
new file mode 100644
index 0000000000000..36c6bdde781c7
--- /dev/null
+++ b/llvm/test/Transforms/LoopFusion/sunk-phi-nodes.ll
@@ -0,0 +1,65 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -passes=loop-fusion -S < %s 2>&1 | FileCheck %s
+define dso_local i32 @check_sunk_phi_nodes() {
+; CHECK-LABEL: define dso_local i32 @check_sunk_phi_nodes() {
+; CHECK-NEXT:  [[ENTRY:.*]]:
+; CHECK-NEXT:    br label %[[FOR_BODY:.*]]
+; CHECK:       [[FOR_BODY]]:
+; CHECK-NEXT:    [[SUM1_02:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD:%.*]], %[[FOR_INC6:.*]] ]
+; CHECK-NEXT:    [[I_01:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[I1_04:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[INC7:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[SUM2_03:%.*]] = phi i32 [ 0, %[[ENTRY]] ], [ [[ADD5:%.*]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[ADD]] = add nsw i32 [[SUM1_02]], [[I_01]]
+; CHECK-NEXT:    br label %[[FOR_INC:.*]]
+; CHECK:       [[FOR_INC]]:
+; CHECK-NEXT:    [[MUL:%.*]] = mul nsw i32 [[I1_04]], [[I1_04]]
+; CHECK-NEXT:    [[ADD5]] = add nsw i32 [[SUM2_03]], [[MUL]]
+; CHECK-NEXT:    br label %[[FOR_INC6]]
+; CHECK:       [[FOR_INC6]]:
+; CHECK-NEXT:    [[INC]] = add nsw i32 [[I_01]], 1
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[INC]], 10
+; CHECK-NEXT:    [[INC7]] = add nsw i32 [[I1_04]], 1
+; CHECK-NEXT:    [[CMP3:%.*]] = icmp slt i32 [[INC7]], 10
+; CHECK-NEXT:    br i1 [[CMP3]], label %[[FOR_BODY]], label %[[FOR_END8:.*]]
+; CHECK:       [[FOR_END8]]:
+; CHECK-NEXT:    [[SUM2_0_LCSSA:%.*]] = phi i32 [ [[ADD5]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[SUM1_0_LCSSA:%.*]] = phi i32 [ [[ADD]], %[[FOR_INC6]] ]
+; CHECK-NEXT:    [[TMP0:%.*]] = add i32 [[SUM1_0_LCSSA]], [[SUM2_0_LCSSA]]
+; CHECK-NEXT:    ret i32 [[TMP0]]
+;
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %entry, %for.inc
+  %sum1.02 = phi i32 [ 0, %entry ], [ %add, %for.inc ] ; running sum of the first loop
+  %i.01 = phi i32 [ 0, %entry ], [ %inc, %for.inc ] ; induction variable of the first loop
+  %add = add nsw i32 %sum1.02, %i.01
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body
+  %inc = add nsw i32 %i.01, 1
+  %cmp = icmp slt i32 %inc, 10
+  br i1 %cmp, label %for.body, label %for.end ; first-loop latch: back edge to %for.body
+
+for.end:                                          ; preds = %for.inc
+  %sum1.0.lcssa = phi i32 [ %add, %for.inc ] ; phi between the loops, fed from the first loop's latch; expected after fusion in the final exit block with the fused latch as incoming block
+  br label %for.body4 ; %for.end leads straight into the second loop's header
+
+for.body4:                                        ; preds = %for.end, %for.inc6
+  %i1.04 = phi i32 [ 0, %for.end ], [ %inc7, %for.inc6 ] ; induction variable of the second loop
+  %sum2.03 = phi i32 [ 0, %for.end ], [ %add5, %for.inc6 ] ; running sum of the second loop
+  %mul = mul nsw i32 %i1.04, %i1.04
+  %add5 = add nsw i32 %sum2.03, %mul
+  br label %for.inc6
+
+for.inc6:                                         ; preds = %for.body4
+  %inc7 = add nsw i32 %i1.04, 1
+  %cmp3 = icmp slt i32 %inc7, 10
+  br i1 %cmp3, label %for.body4, label %for.end8 ; second-loop latch: back edge to %for.body4
+
+for.end8:                                         ; preds = %for.inc6
+  %sum2.0.lcssa = phi i32 [ %add5, %for.inc6 ]
+  %0 = add i32 %sum1.0.lcssa, %sum2.0.lcssa ; combines the results of both loops
+  ret i32 %0
+}
+