diff --git a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp index 9eb8a289d7d65..82bbd373ce45a 100644 --- a/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp +++ b/mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -369,9 +370,18 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp, auto comparison = rewriter.create( loc, arith::CmpIPredicate::slt, iv, upperBound); - rewriter.create(loc, comparison, firstBodyBlock, - ArrayRef(), endBlock, - ArrayRef()); + auto condBranchOp = rewriter.create( + loc, comparison, firstBodyBlock, ArrayRef(), endBlock, + ArrayRef()); + + // Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the + // llvm.loop_annotation attribute. + SmallVector llvmAttrs; + llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs), + [](auto attr) { + return isa(attr.getValue().getDialect()); + }); + condBranchOp->setDiscardableAttrs(llvmAttrs); // The result of the loop operation is the values of the condition block // arguments except the induction variable on the last iteration. rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front()); diff --git a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir index caf17bc91ced2..9ea0093eff786 100644 --- a/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir +++ b/mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -convert-scf-to-cf -split-input-file %s | FileCheck %s // CHECK-LABEL: func @simple_std_for_loop(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { // CHECK-NEXT: cf.br ^bb1(%{{.*}} : index) @@ -675,3 +675,25 @@ func.func @forall(%num_threads: index) { } return } + +// ----- + +// CHECK: #loop_unroll = #llvm.loop_unroll +// CHECK-NEXT: #loop_unroll1 = #llvm.loop_unroll +// CHECK-NEXT: #[[NO_UNROLL:.*]] = #llvm.loop_annotation +// CHECK-NEXT: #[[FULL_UNROLL:.*]] = #llvm.loop_annotation +// CHECK: cf.cond_br %{{.*}}, ^bb2, ^bb6 {llvm.loop_annotation = #[[NO_UNROLL]]} +// CHECK: cf.cond_br %{{.*}}, ^bb4, ^bb5 {llvm.loop_annotation = #[[FULL_UNROLL]]} +#no_unroll = #llvm.loop_annotation> +#full_unroll = #llvm.loop_annotation> +func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 : index) { + scf.for %i0 = %arg0 to %arg1 step %arg2 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + scf.for %i1 = %c0 to %c4 step %c1 { + %c1_0 = arith.constant 1 : index + } {llvm.loop_annotation = #full_unroll} + } {llvm.loop_annotation = #no_unroll} + return +} \ No newline at end of file