@@ -303,8 +303,7 @@ mlir::LogicalResult CIRGenFunction::buildSimpleStmt(const Stmt *S,
303
303
304
304
case Stmt::CaseStmtClass:
305
305
case Stmt::DefaultStmtClass:
306
- assert (0 &&
307
- " Should not get here, currently handled directly from SwitchStmt" );
306
+ return buildSwitchCase (cast<SwitchCase>(*S));
308
307
break ;
309
308
310
309
case Stmt::BreakStmtClass:
@@ -715,14 +714,19 @@ CIRGenFunction::buildDefaultStmt(const DefaultStmt &S, mlir::Type condType,
715
714
return buildCaseDefaultCascade (&S, condType, caseAttrs);
716
715
}
717
716
718
- mlir::LogicalResult
719
- CIRGenFunction::buildSwitchCase (const SwitchCase &S, mlir::Type condType,
720
- SmallVector<mlir::Attribute, 4 > &caseAttrs) {
717
+ mlir::LogicalResult CIRGenFunction::buildSwitchCase (const SwitchCase &S) {
718
+ assert (!caseAttrsStack.empty () &&
719
+ " build switch case without seeting case attrs" );
720
+ assert (!condTypeStack.empty () &&
721
+ " build switch case without specifying the type of the condition" );
722
+
721
723
if (S.getStmtClass () == Stmt::CaseStmtClass)
722
- return buildCaseStmt (cast<CaseStmt>(S), condType, caseAttrs);
724
+ return buildCaseStmt (cast<CaseStmt>(S), condTypeStack.back (),
725
+ caseAttrsStack.back ());
723
726
724
727
if (S.getStmtClass () == Stmt::DefaultStmtClass)
725
- return buildDefaultStmt (cast<DefaultStmt>(S), condType, caseAttrs);
728
+ return buildDefaultStmt (cast<DefaultStmt>(S), condTypeStack.back (),
729
+ caseAttrsStack.back ());
726
730
727
731
llvm_unreachable (" expect case or default stmt" );
728
732
}
@@ -987,15 +991,13 @@ mlir::LogicalResult CIRGenFunction::buildWhileStmt(const WhileStmt &S) {
987
991
return mlir::success ();
988
992
}
989
993
990
- mlir::LogicalResult CIRGenFunction::buildSwitchBody (
991
- const Stmt *S, mlir::Type condType,
992
- llvm::SmallVector<mlir::Attribute, 4 > &caseAttrs) {
994
+ mlir::LogicalResult CIRGenFunction::buildSwitchBody (const Stmt *S) {
993
995
if (auto *compoundStmt = dyn_cast<CompoundStmt>(S)) {
994
996
mlir::Block *lastCaseBlock = nullptr ;
995
997
auto res = mlir::success ();
996
998
for (auto *c : compoundStmt->body ()) {
997
999
if (auto *switchCase = dyn_cast<SwitchCase>(c)) {
998
- res = buildSwitchCase (*switchCase, condType, caseAttrs );
1000
+ res = buildSwitchCase (*switchCase);
999
1001
} else if (lastCaseBlock) {
1000
1002
// This means it's a random stmt following up a case, just
1001
1003
// emit it as part of previous known case.
@@ -1045,12 +1047,16 @@ mlir::LogicalResult CIRGenFunction::buildSwitchStmt(const SwitchStmt &S) {
1045
1047
[&](mlir::OpBuilder &b, mlir::Location loc, mlir::OperationState &os) {
1046
1048
currLexScope->setAsSwitch ();
1047
1049
1048
- llvm::SmallVector<mlir::Attribute, 4 > caseAttrs;
1050
+ caseAttrsStack.push_back ({});
1051
+ condTypeStack.push_back (condV.getType ());
1049
1052
1050
- res = buildSwitchBody (S.getBody (), condV. getType (), caseAttrs );
1053
+ res = buildSwitchBody (S.getBody ());
1051
1054
1052
1055
os.addRegions (currLexScope->getSwitchRegions ());
1053
- os.addAttribute (" cases" , builder.getArrayAttr (caseAttrs));
1056
+ os.addAttribute (" cases" , builder.getArrayAttr (caseAttrsStack.back ()));
1057
+
1058
+ caseAttrsStack.pop_back ();
1059
+ condTypeStack.pop_back ();
1054
1060
});
1055
1061
1056
1062
if (res.failed ())
0 commit comments