diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 807be7e1003c0..ba448e46913ac 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -312,6 +312,19 @@ struct ForallLowering : public OpRewritePattern { } // namespace +static void propagateLoopAttrs(Operation *scfOp, Operation *brOp) { + // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the + // llvm.loop_annotation attribute. + // LLVM requires the loop metadata to be attached on the "latch" block. Which + // is the back-edge to the header block (conditionBlock) + SmallVector llvmAttrs; + llvm::copy_if(scfOp->getAttrs(), std::back_inserter(llvmAttrs), + [](auto attr) { + return isa(attr.getValue().getDialect()); + }); + brOp->setDiscardableAttrs(llvmAttrs); +} + LogicalResult ForLowering::matchAndRewrite(ForOp forOp, PatternRewriter &rewriter) const { Location loc = forOp.getLoc(); @@ -350,17 +363,7 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, auto branchOp = cf::BranchOp::create(rewriter, loc, conditionBlock, loopCarried); - // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the - // llvm.loop_annotation attribute. - // LLVM requires the loop metadata to be attached on the "latch" block. Which - // is the back-edge to the header block (conditionBlock) - SmallVector llvmAttrs; - llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs), - [](auto attr) { - return isa(attr.getValue().getDialect()); - }); - branchOp->setDiscardableAttrs(llvmAttrs); - + propagateLoopAttrs(forOp, branchOp); rewriter.eraseOp(terminator); // Compute loop bounds before branching to the condition. @@ -589,9 +592,10 @@ LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp, rewriter.setInsertionPointToEnd(after); auto yieldOp = cast(after->getTerminator()); - rewriter.replaceOpWithNewOp(yieldOp, before, - yieldOp.getResults()); + auto latch = rewriter.replaceOpWithNewOp(yieldOp, before, + yieldOp.getResults()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, args); @@ -631,10 +635,11 @@ DoWhileLowering::matchAndRewrite(WhileOp whileOp, // Loop around the "before" region based on condition. rewriter.setInsertionPointToEnd(before); auto condOp = cast(before->getTerminator()); - cf::CondBranchOp::create(rewriter, condOp.getLoc(), condOp.getCondition(), - before, condOp.getArgs(), continuation, - ValueRange()); + auto latch = cf::CondBranchOp::create( + rewriter, condOp.getLoc(), condOp.getCondition(), before, + condOp.getArgs(), continuation, ValueRange()); + propagateLoopAttrs(whileOp, latch); // Replace the op with values "yielded" from the "before" region, which are // visible by dominance. rewriter.replaceOp(whileOp, condOp.getArgs()); diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index e6fdb7ab5ecd8..ef0fa083a021a 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -708,4 +708,45 @@ func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 : } {llvm.loop_annotation = #full_unroll} } {llvm.loop_annotation = #no_unroll} return -} \ No newline at end of file +} + +// ----- + +// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll +// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation +// CHECK: func @simple_while_loops_annotation +// CHECK: cf.br +// CHECK: cf.cond_br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]} +// CHECK: return +#no_unroll = #llvm.loop_annotation> +func.func @simple_while_loops_annotation(%arg0 : i1) { + scf.while : () -> () { + scf.condition(%arg0) + } do { + scf.yield + } attributes {llvm.loop_annotation = #no_unroll} + return +} + +// ----- + +// CHECK: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll +// CHECK: #[[NO_UNROLL:.*]] = #llvm.loop_annotation +// CHECK: func @do_while_loops_annotation +// CHECK: cf.br +// CHECK: cf.cond_br +// CHECK: cf.br {{.*}} {llvm.loop_annotation = #[[NO_UNROLL]]} +// CHECK: return +#no_unroll = #llvm.loop_annotation> +func.func @do_while_loops_annotation() { + %c0_i32 = arith.constant 0 : i32 + scf.while (%arg2 = %c0_i32) : (i32) -> (i32) { + %0 = "test.make_condition"() : () -> i1 + scf.condition(%0) %c0_i32 : i32 + } do { + ^bb0(%arg2: i32): + scf.yield %c0_i32: i32 + } attributes {llvm.loop_annotation = #no_unroll} + return +} +