diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp index 01be972711b1..b7afcb9019f1 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp @@ -25,7 +25,6 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/LowerToMLIR.h" #include "llvm/ADT/TypeSwitch.h" -#include "llvm/IR/Module.h" using namespace cir; using namespace llvm; @@ -483,6 +482,19 @@ class CIRWhileOpLowering : public mlir::OpConversionPattern { return; for (auto continueOp : continues) { + bool nested = false; + // When there is another loop between this WhileOp and the ContinueOp, + // we shouldn't change that loop instead. + for (mlir::Operation *parent = continueOp->getParentOp(); + parent != whileOp; parent = parent->getParentOp()) { + if (isa(parent)) { + nested = true; + break; + } + } + if (nested) + continue; + // When the ContinueOp is under an IfOp, a direct replacement of // `scf.yield` won't work: the yield would jump out of that IfOp instead. // We might need to change the WhileOp itself to achieve the same effect. diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 71752bb19c25..5800773e715b 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -1552,7 +1552,11 @@ void ConvertCIRToMLIRPass::runOnOperation() { mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect, mlir::math::MathDialect, mlir::vector::VectorDialect, mlir::LLVM::LLVMDialect>(); - target.addIllegalDialect(); + // We cannot mark cir dialect as illegal before conversion. + // The conversion of WhileOp relies on partially preserving operations from + // cir dialect, for example the `cir.continue`. If we marked cir as illegal + // here, then MLIR would think any remaining `cir.continue` indicates a + // failure, which is not what we want. if (failed(applyPartialConversion(module, target, std::move(patterns)))) signalPassFailure(); @@ -1616,8 +1620,9 @@ mlir::ModuleOp lowerFromCIRToMLIR(mlir::ModuleOp theModule, auto result = !mlir::failed(pm.run(theModule)); if (!result) - report_fatal_error( - "The pass manager failed to lower CIR to MLIR standard dialects!"); + theModule.dump(), + report_fatal_error( + "The pass manager failed to lower CIR to MLIR standard dialects!"); // Now that we ran all the lowering passes, verify the final output. if (theModule.verify().failed()) report_fatal_error( diff --git a/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp b/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp index 07ebcfab8ac9..fa13e11b80fb 100644 --- a/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp +++ b/clang/test/CIR/Lowering/ThroughMLIR/while-with-continue.cpp @@ -68,3 +68,38 @@ void while_continue_2() { // CHECK: scf.yield // CHECK: } } + +void while_continue_nested() { + int i = 0; + while (i < 10) { + while (true) { + continue; + i--; + } + i++; + } + // The continue will only work on the inner while. + + // CHECK: scf.while : () -> () { + // CHECK: %[[IV:.+]] = memref.load %alloca[] + // CHECK: %[[TEN:.+]] = arith.constant 10 + // CHECK: %[[LT:.+]] = arith.cmpi slt, %[[IV]], %[[TEN]] + // CHECK: scf.condition(%[[LT]]) + // CHECK: } do { + // CHECK: memref.alloca_scope { + // CHECK: memref.alloca_scope { + // CHECK: scf.while : () -> () { + // CHECK: %[[TRUE:.+]] = arith.constant true + // CHECK: scf.condition(%[[TRUE]]) + // CHECK: } do { + // CHECK: scf.yield + // CHECK: } + // CHECK: } + // CHECK: %[[IV2:.+]] = memref.load %alloca[] + // CHECK: %[[ONE:.+]] = arith.constant 1 + // CHECK: %[[ADD:.+]] = arith.addi %[[IV2]], %[[ONE]] + // CHECK: memref.store %[[ADD]], %alloca[] + // CHECK: } + // CHECK: scf.yield + // CHECK: } +}