From 9db82cbe624038b44bf3353e186035e5a6cd8502 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 01/13] [AST] Avoid walking a few more expressions in WalkToVarDecls Avoid walking TapExprs, SingleValueStmtExprs, and key paths. The latter is important since they can contain invalid VarDecls that will no longer be visited by the ASTWalker after key path resolution, so we don't want to create case body vars for them. --- lib/AST/Pattern.cpp | 10 +++++++--- .../IDE/crashers_fixed/2a9258ff5360857f.swift | 6 ++++++ .../22119c306663e633.swift | 2 +- .../7d3e143ce48c791.swift | 2 +- .../9ae5dcaffa1a80.swift | 2 +- 5 files changed, 16 insertions(+), 6 deletions(-) create mode 100644 validation-test/IDE/crashers_fixed/2a9258ff5360857f.swift rename validation-test/{compiler_crashers_2 => compiler_crashers_2_fixed}/22119c306663e633.swift (74%) rename validation-test/{compiler_crashers_2 => compiler_crashers_2_fixed}/7d3e143ce48c791.swift (82%) rename validation-test/{compiler_crashers_2 => compiler_crashers_2_fixed}/9ae5dcaffa1a80.swift (82%) diff --git a/lib/AST/Pattern.cpp b/lib/AST/Pattern.cpp index 1ceb418fa1fdb..8651f0862fe29 100644 --- a/lib/AST/Pattern.cpp +++ b/lib/AST/Pattern.cpp @@ -208,10 +208,14 @@ namespace { return Action::Continue(P); } - // Only walk into an expression insofar as it doesn't open a new scope - - // that is, don't walk into a closure body. PreWalkResult walkToExprPre(Expr *E) override { - if (isa(E)) { + // Only walk into an expression insofar as it doesn't open a new scope - + // that is, don't walk into a closure body, TapExpr, or + // SingleValueStmtExpr. Also don't walk into key paths since any nested + // VarDecls are invalid there, and after being diagnosed by key path + // resolution the ASTWalker won't visit them. + if (isa(E) || isa(E) || + isa(E) || isa(E)) { return Action::SkipNode(E); } return Action::Continue(E); diff --git a/validation-test/IDE/crashers_fixed/2a9258ff5360857f.swift b/validation-test/IDE/crashers_fixed/2a9258ff5360857f.swift new file mode 100644 index 0000000000000..4437c5b9745c2 --- /dev/null +++ b/validation-test/IDE/crashers_fixed/2a9258ff5360857f.swift @@ -0,0 +1,6 @@ +// {"kind":"complete","original":"19b878fc","signature":"swift::NamingPatternRequest::evaluate(swift::Evaluator&, swift::VarDecl*) const","signatureAssert":"Assertion failed: (foundVarDecl && \"VarDecl not declared in its parent?\"), function evaluate"} +// RUN: %target-swift-ide-test -code-completion -batch-code-completion -skip-filecheck -code-completion-diagnostics -source-filename %s +{ + guard let \a = answer + a + #^^# diff --git a/validation-test/compiler_crashers_2/22119c306663e633.swift b/validation-test/compiler_crashers_2_fixed/22119c306663e633.swift similarity index 74% rename from validation-test/compiler_crashers_2/22119c306663e633.swift rename to validation-test/compiler_crashers_2_fixed/22119c306663e633.swift index 529049852cb1c..719b833fb260d 100644 --- a/validation-test/compiler_crashers_2/22119c306663e633.swift +++ b/validation-test/compiler_crashers_2_fixed/22119c306663e633.swift @@ -1,5 +1,5 @@ // {"signature":"swift::NamingPatternRequest::evaluate(swift::Evaluator&, swift::VarDecl*) const"} -// RUN: not --crash %target-swift-frontend -typecheck %s +// RUN: not %target-swift-frontend -typecheck %s { let a = $0.c ; switch a { diff --git a/validation-test/compiler_crashers_2/7d3e143ce48c791.swift b/validation-test/compiler_crashers_2_fixed/7d3e143ce48c791.swift similarity index 82% rename from validation-test/compiler_crashers_2/7d3e143ce48c791.swift rename to validation-test/compiler_crashers_2_fixed/7d3e143ce48c791.swift index 84ce24fe547d3..3f26f7f33addc 100644 --- a/validation-test/compiler_crashers_2/7d3e143ce48c791.swift +++ b/validation-test/compiler_crashers_2_fixed/7d3e143ce48c791.swift @@ -1,5 +1,5 @@ // {"kind":"typecheck","signature":"swift::NamingPatternRequest::evaluate(swift::Evaluator&, swift::VarDecl*) const","signatureAssert":"Assertion failed: (foundVarDecl && \"VarDecl not declared in its parent?\"), function evaluate"} -// RUN: not --crash %target-swift-frontend -typecheck %s +// RUN: not %target-swift-frontend -typecheck %s { if case.(let \ a) { a diff --git a/validation-test/compiler_crashers_2/9ae5dcaffa1a80.swift b/validation-test/compiler_crashers_2_fixed/9ae5dcaffa1a80.swift similarity index 82% rename from validation-test/compiler_crashers_2/9ae5dcaffa1a80.swift rename to validation-test/compiler_crashers_2_fixed/9ae5dcaffa1a80.swift index 51ce39bd172be..5c064b5476976 100644 --- a/validation-test/compiler_crashers_2/9ae5dcaffa1a80.swift +++ b/validation-test/compiler_crashers_2_fixed/9ae5dcaffa1a80.swift @@ -1,5 +1,5 @@ // {"signature":"void (anonymous namespace)::StmtChecker::checkSiblingCaseStmts(swift::CaseStmt* const*, swift::CaseStmt* const*, swift::CaseParentKind, bool&, swift::Type)"} -// RUN: not --crash %target-swift-frontend -typecheck %s +// RUN: not %target-swift-frontend -typecheck %s enum a func b(c : a) -> Int { switch c { From 805b6d9c39da04206d170f8f10e8fb8722b47d57 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 02/13] [CS] Remove some dead code in `visitCaseItemPattern` We don't wire up the parent variables until after type-checking, and `recordInferredSwitchCasePatternVars` already handles joining the pattern types, so we can remove this. --- lib/Sema/CSSyntacticElement.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index e94ef29160b65..0021e83f565e8 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -738,23 +738,6 @@ class SyntacticElementConstraintGenerator LocatorPathElt::ContextualType(context.purpose)}); cs.addConstraint(ConstraintKind::Equal, context.getType(), patternType, loc); - - // For any pattern variable that has a parent variable (i.e., another - // pattern variable with the same name in the same case), require that - // the types be equivalent. - pattern->forEachNode([&](Pattern *pattern) { - auto namedPattern = dyn_cast(pattern); - if (!namedPattern) - return; - - auto var = namedPattern->getDecl(); - if (auto parentVar = var->getParentVarDecl()) { - cs.addConstraint( - ConstraintKind::Equal, cs.getType(parentVar), cs.getType(var), - cs.getConstraintLocator( - locator, LocatorPathElt::PatternMatch(namedPattern))); - } - }); } void visitPatternBinding(PatternBindingDecl *patternBinding, From 63286ae3f09b67710ca37dda8a18a5d47e694c3f Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 03/13] [SILGen] Fix case emission when there are no case body vars We want to call the `bodyEmitter`, since that has the extra logic necessary to handle `do-catch` statements. Previously this didn't cause any issues since `hasCaseBodyVariables` would have always been true for parsed `do-catch`s, but I'm planning on changing that. --- lib/SILGen/SILGenPattern.cpp | 2 +- test/SILGen/pr-84149.swift | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 test/SILGen/pr-84149.swift diff --git a/lib/SILGen/SILGenPattern.cpp b/lib/SILGen/SILGenPattern.cpp index a201c94989b1d..10bffeda60a91 100644 --- a/lib/SILGen/SILGenPattern.cpp +++ b/lib/SILGen/SILGenPattern.cpp @@ -2958,7 +2958,7 @@ void PatternMatchEmission::emitSharedCaseBlocks( SWIFT_DEFER { assert(SGF.getCleanupsDepth() == PatternMatchStmtDepth); }; if (!caseBlock->hasCaseBodyVariables()) { - emitCaseBody(caseBlock); + bodyEmitter(caseBlock); continue; } diff --git a/test/SILGen/pr-84149.swift b/test/SILGen/pr-84149.swift new file mode 100644 index 0000000000000..c35a4cd0be188 --- /dev/null +++ b/test/SILGen/pr-84149.swift @@ -0,0 +1,15 @@ +// RUN: %target-swift-emit-silgen %s -verify + +enum E : Error { + case a(Int), b(Int) +} + +func bar() throws {} + +// Make sure we can correctly emit this without crashing. +func foo() throws { + do { + try bar() + } catch E.a, E.b { + } +} From 3e97d729e863db69eca56a7c6873697dcd4b09f5 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 04/13] [AST] Turn `CaseBodyVariables` into an `ArrayRef` We don't need to store a `MutableArrayRef`. --- include/swift/AST/Stmt.h | 26 ++++---------------------- lib/AST/Stmt.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 26 deletions(-) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 01e66e1eb1324..95e80b0bc49f5 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1222,12 +1222,12 @@ class CaseStmt final llvm::PointerIntPair BodyAndHasFallthrough; - std::optional> CaseBodyVariables; + std::optional> CaseBodyVariables; CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, - std::optional> CaseBodyVariables, + std::optional> CaseBodyVariables, std::optional Implicit, NullablePtr fallthroughStmt); @@ -1248,7 +1248,7 @@ class CaseStmt final create(ASTContext &C, CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, - std::optional> CaseBodyVariables, + std::optional> CaseBodyVariables, std::optional Implicit = std::nullopt, NullablePtr fallthroughStmt = nullptr); @@ -1350,32 +1350,14 @@ class CaseStmt final /// where one wants a non-asserting version, \see /// getCaseBodyVariablesOrEmptyArray. ArrayRef getCaseBodyVariables() const { - ArrayRef a = *CaseBodyVariables; - return a; + return *CaseBodyVariables; } bool hasCaseBodyVariables() const { return CaseBodyVariables.has_value(); } - /// Return an MutableArrayRef containing the case body variables of this - /// CaseStmt. - /// - /// Asserts if case body variables was not explicitly initialized. In contexts - /// where one wants a non-asserting version, \see - /// getCaseBodyVariablesOrEmptyArray. - MutableArrayRef getCaseBodyVariables() { - return *CaseBodyVariables; - } - ArrayRef getCaseBodyVariablesOrEmptyArray() const { if (!CaseBodyVariables) return ArrayRef(); - ArrayRef a = *CaseBodyVariables; - return a; - } - - MutableArrayRef getCaseBodyVariablesOrEmptyArray() { - if (!CaseBodyVariables) - return MutableArrayRef(); return *CaseBodyVariables; } diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index c570fdd7a5833..0ffcc69d803e3 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -753,7 +753,7 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc itemTerminatorLoc, BraceStmt *body, - std::optional> caseBodyVariables, + std::optional> caseBodyVariables, std::optional implicit, NullablePtr fallthroughStmt) : Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, itemIntroducerLoc)), @@ -781,13 +781,13 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, new (&items[i]) CaseLabelItem(caseLabelItems[i]); items[i].getPattern()->markOwnedByStatement(this); } - for (auto *vd : caseBodyVariables.value_or(MutableArrayRef())) { + for (auto *vd : getCaseBodyVariablesOrEmptyArray()) { vd->setParentPatternStmt(this); } } namespace { -static MutableArrayRef +static ArrayRef getCaseVarDecls(ASTContext &ctx, ArrayRef labelItems) { // Grab the first case label item pattern and use it to initialize the case // body var decls. @@ -871,7 +871,7 @@ CaseStmt * CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, SourceLoc caseLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc colonLoc, BraceStmt *body, - std::optional> caseVarDecls, + std::optional> caseVarDecls, std::optional implicit, NullablePtr fallthroughStmt) { void *mem = From 84befd43abe60776cdc5e3b57c6047c858b46205 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 05/13] [AST] Make case body variables for CaseStmt non-optional We don't really care about the distinction between empty and nil here. --- include/swift/AST/Stmt.h | 23 ++++--------------- lib/AST/ASTScopeLookup.cpp | 6 ++--- lib/AST/ASTVerifier.cpp | 2 +- lib/AST/Decl.cpp | 4 ++-- lib/AST/Stmt.cpp | 8 +++---- lib/SILGen/SILGenPattern.cpp | 4 ++-- lib/Sema/BuilderTransform.cpp | 2 +- lib/Sema/CSSyntacticElement.cpp | 4 ++-- .../DerivedConformanceCodable.cpp | 7 +++--- .../DerivedConformanceComparable.cpp | 4 ++-- .../DerivedConformanceEquatableHashable.cpp | 8 +++---- lib/Sema/MiscDiagnostics.cpp | 2 +- lib/Sema/TypeCheckStmt.cpp | 10 ++++---- 13 files changed, 34 insertions(+), 50 deletions(-) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 95e80b0bc49f5..274d93b4ea6a1 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1222,13 +1222,12 @@ class CaseStmt final llvm::PointerIntPair BodyAndHasFallthrough; - std::optional> CaseBodyVariables; + ArrayRef CaseBodyVariables; CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, - std::optional> CaseBodyVariables, - std::optional Implicit, + ArrayRef CaseBodyVariables, std::optional Implicit, NullablePtr fallthroughStmt); public: @@ -1248,7 +1247,7 @@ class CaseStmt final create(ASTContext &C, CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, - std::optional> CaseBodyVariables, + ArrayRef CaseBodyVariables, std::optional Implicit = std::nullopt, NullablePtr fallthroughStmt = nullptr); @@ -1293,7 +1292,7 @@ class CaseStmt final void setBody(BraceStmt *body) { BodyAndHasFallthrough.setPointer(body); } /// True if the case block declares any patterns with local variable bindings. - bool hasBoundDecls() const { return CaseBodyVariables.has_value(); } + bool hasCaseBodyVariables() const { return !CaseBodyVariables.empty(); } /// Get the source location of the 'case', 'default', or 'catch' of the first /// label. @@ -1345,20 +1344,8 @@ class CaseStmt final } /// Return an ArrayRef containing the case body variables of this CaseStmt. - /// - /// Asserts if case body variables was not explicitly initialized. In contexts - /// where one wants a non-asserting version, \see - /// getCaseBodyVariablesOrEmptyArray. ArrayRef getCaseBodyVariables() const { - return *CaseBodyVariables; - } - - bool hasCaseBodyVariables() const { return CaseBodyVariables.has_value(); } - - ArrayRef getCaseBodyVariablesOrEmptyArray() const { - if (!CaseBodyVariables) - return ArrayRef(); - return *CaseBodyVariables; + return CaseBodyVariables; } /// Find the next case statement within the same 'switch' or 'do-catch', diff --git a/lib/AST/ASTScopeLookup.cpp b/lib/AST/ASTScopeLookup.cpp index 2c1adbac6600c..80e17ecbb2a3a 100644 --- a/lib/AST/ASTScopeLookup.cpp +++ b/lib/AST/ASTScopeLookup.cpp @@ -378,10 +378,10 @@ bool CaseLabelItemScope::lookupLocalsOrMembers(DeclConsumer consumer) const { } bool CaseStmtBodyScope::lookupLocalsOrMembers(DeclConsumer consumer) const { - for (auto *var : stmt->getCaseBodyVariablesOrEmptyArray()) + for (auto *var : stmt->getCaseBodyVariables()) { if (consumer.consume({var})) - return true; - + return true; + } return false; } diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index f7e74dfa54041..039f5db5bd67b 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -2795,7 +2795,7 @@ class Verifier : public ASTWalker { // guarantee that all case label items bind corresponding patterns and // the case body var decls of a case stmt are created from the var decls // of the first case label items. - if (!caseStmt->hasBoundDecls()) { + if (!caseStmt->hasCaseBodyVariables()) { Out << "parent CaseStmt of VarDecl does not have any case body " "decls?!\n"; abort(); diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index 0b3f42a541144..c470d81eee8f8 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8164,7 +8164,7 @@ findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) { auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * { // Check if inputVD is in our case body var decls if we have any. If we do, // treat its pattern as our first case label item pattern. - for (auto *vd : cs->getCaseBodyVariablesOrEmptyArray()) { + for (auto *vd : cs->getCaseBodyVariables()) { if (vd == inputVD) { return cs->getMutableCaseLabelItems().front().getPattern(); } @@ -8345,7 +8345,7 @@ bool VarDecl::isCaseBodyVariable() const { auto *caseStmt = dyn_cast_or_null(getRecursiveParentPatternStmt()); if (!caseStmt) return false; - return llvm::any_of(caseStmt->getCaseBodyVariablesOrEmptyArray(), + return llvm::any_of(caseStmt->getCaseBodyVariables(), [&](VarDecl *vd) { return vd == this; }); } diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 0ffcc69d803e3..f680663caa093 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -752,8 +752,7 @@ SourceLoc CaseLabelItem::getEndLoc() const { CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc itemTerminatorLoc, - BraceStmt *body, - std::optional> caseBodyVariables, + BraceStmt *body, ArrayRef caseBodyVariables, std::optional implicit, NullablePtr fallthroughStmt) : Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, itemIntroducerLoc)), @@ -781,7 +780,7 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, new (&items[i]) CaseLabelItem(caseLabelItems[i]); items[i].getPattern()->markOwnedByStatement(this); } - for (auto *vd : getCaseBodyVariablesOrEmptyArray()) { + for (auto *vd : getCaseBodyVariables()) { vd->setParentPatternStmt(this); } } @@ -871,8 +870,7 @@ CaseStmt * CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, SourceLoc caseLoc, ArrayRef caseLabelItems, SourceLoc unknownAttrLoc, SourceLoc colonLoc, BraceStmt *body, - std::optional> caseVarDecls, - std::optional implicit, + ArrayRef caseVarDecls, std::optional implicit, NullablePtr fallthroughStmt) { void *mem = ctx.Allocate(totalSizeToAlloc( diff --git a/lib/SILGen/SILGenPattern.cpp b/lib/SILGen/SILGenPattern.cpp index 10bffeda60a91..0b07d085ef202 100644 --- a/lib/SILGen/SILGenPattern.cpp +++ b/lib/SILGen/SILGenPattern.cpp @@ -2836,7 +2836,7 @@ void PatternMatchEmission::initSharedCaseBlockDest(CaseStmt *caseBlock, result.first->second.first = block; // Add args for any pattern variables if we have any. - for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) { + for (auto *vd : caseBlock->getCaseBodyVariables()) { if (!vd->hasName()) continue; @@ -2867,7 +2867,7 @@ void PatternMatchEmission::emitAddressOnlyAllocations() { // If we have a shared case with bound decls, setup the arguments for the // shared block by emitting the temporary allocation used for the arguments // of the shared block. - for (auto *vd : caseBlock->getCaseBodyVariablesOrEmptyArray()) { + for (auto *vd : caseBlock->getCaseBodyVariables()) { if (!vd->hasName()) continue; diff --git a/lib/Sema/BuilderTransform.cpp b/lib/Sema/BuilderTransform.cpp index 024aba53e70b9..2cdd76f45cac0 100644 --- a/lib/Sema/BuilderTransform.cpp +++ b/lib/Sema/BuilderTransform.cpp @@ -650,7 +650,7 @@ class ResultBuilderTransform caseStmt->getCaseLabelItems(), caseStmt->hasUnknownAttr() ? caseStmt->getStartLoc() : SourceLoc(), caseStmt->getItemTerminatorLoc(), cloneBraceWith(body, newBody), - caseStmt->getCaseBodyVariablesOrEmptyArray(), caseStmt->isImplicit(), + caseStmt->getCaseBodyVariables(), caseStmt->isImplicit(), caseStmt->getFallthroughStmt()); return std::make_pair(caseVarRef.get(), newCase); diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index 0021e83f565e8..e5bac5ebacd21 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -1378,7 +1378,7 @@ class SyntacticElementConstraintGenerator pattern->forEachVariable([&](VarDecl *var) { recordVar(var); }); } - for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) { + for (auto bodyVar : caseStmt->getCaseBodyVariables()) { if (!bodyVar->hasName()) continue; @@ -2007,7 +2007,7 @@ class SyntacticElementSolutionApplication bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt); - for (auto *expected : caseStmt->getCaseBodyVariablesOrEmptyArray()) { + for (auto *expected : caseStmt->getCaseBodyVariables()) { assert(expected->hasName()); auto prev = expected->getParentVarDecl(); auto type = solution.getResolvedType(prev)->mapTypeOutOfContext(); diff --git a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp index 325fe35f1b086..99de54d096a65 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp @@ -946,7 +946,7 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, // .(let a0, let a1, ...) SmallVector payloadVars; Pattern *subpattern = nullptr; - std::optional> caseBodyVarDecls; + ArrayRef caseBodyVarDecls; if (createSubpattern) { subpattern = DerivedConformance::enumElementPayloadSubpattern( @@ -965,7 +965,7 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, vNew->setImplicit(); copy[i] = vNew; } - caseBodyVarDecls.emplace(copy); + caseBodyVarDecls = copy; } } @@ -988,8 +988,7 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, auto stmt = CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), labelItem, SourceLoc(), SourceLoc(), caseBody, - /*case body vardecls*/ - createSubpattern ? caseBodyVarDecls : std::nullopt); + caseBodyVarDecls); cases.push_back(stmt); } } diff --git a/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp index c5468e83a70ed..93d823fac3bfd 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp @@ -128,7 +128,7 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v enumType, elt, rhsSubpattern, /*DC*/ ltDecl); auto hasBoundDecls = !lhsPayloadVars.empty(); - std::optional> caseBodyVarDecls; + ArrayRef caseBodyVarDecls; if (hasBoundDecls) { // We allocated a direct copy of our lhs var decls for the case // body. @@ -141,7 +141,7 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v vNew->setImplicit(); copy[i] = vNew; } - caseBodyVarDecls.emplace(copy); + caseBodyVarDecls = copy; } // case (.(let l0, let l1, ...), .(let r0, let r1, ...)) diff --git a/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp index 6ae83dc2f8f6c..fc16fd5cf425d 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp @@ -194,7 +194,7 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, enumType, elt, rhsSubpattern, /*DC*/ eqDecl); auto hasBoundDecls = !lhsPayloadVars.empty(); - std::optional> caseBodyVarDecls; + ArrayRef caseBodyVarDecls; if (hasBoundDecls) { // We allocated a direct copy of our lhs var decls for the case // body. @@ -207,7 +207,7 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, vNew->setImplicit(); copy[i] = vNew; } - caseBodyVarDecls.emplace(copy); + caseBodyVarDecls = copy; } // case (.(let l0, let l1, ...), .(let r0, let r1, ...)) @@ -736,7 +736,7 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto( } auto hasBoundDecls = !payloadVars.empty(); - std::optional> caseBodyVarDecls; + ArrayRef caseBodyVarDecls; if (hasBoundDecls) { auto copy = C.Allocate(payloadVars.size()); for (unsigned i : indices(payloadVars)) { @@ -747,7 +747,7 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto( vNew->setImplicit(); copy[i] = vNew; } - caseBodyVarDecls.emplace(copy); + caseBodyVarDecls = copy; } auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc()); diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index ab095aa450efe..be5c7ac80ec64 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -3459,7 +3459,7 @@ class VarDeclUsageChecker : public ASTWalker { // Make sure that we setup our case body variables. if (auto *caseStmt = dyn_cast(S)) { - for (auto *vd : caseStmt->getCaseBodyVariablesOrEmptyArray()) { + for (auto *vd : caseStmt->getCaseBodyVariables()) { VarDecls[vd] |= RK_Defined; } } diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 5d76f4508909b..15fa32c241286 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -126,7 +126,7 @@ namespace { // The ASTWalker doesn't walk the case body variables, contextualize them // ourselves. if (auto *CS = dyn_cast(S)) { - for (auto *CaseVar : CS->getCaseBodyVariablesOrEmptyArray()) + for (auto *CaseVar : CS->getCaseBodyVariables()) CaseVar->setDeclContext(ParentDC); } // A few statements store DeclContexts, update them. @@ -341,7 +341,7 @@ namespace { PreWalkResult walkToStmtPre(Stmt *S) override { if (auto caseStmt = dyn_cast(S)) { - for (auto var : caseStmt->getCaseBodyVariablesOrEmptyArray()) + for (auto var : caseStmt->getCaseBodyVariables()) setLocalDiscriminator(var); } return Action::Continue(S); @@ -970,7 +970,7 @@ bool swift::checkFallthroughStmt(FallthroughStmt *FS) { // decls. So if we match against the case body var decls, // transitively we will match all of the other case label items in // the fallthrough destination as well. - auto previousVars = previousBlock->getCaseBodyVariablesOrEmptyArray(); + auto previousVars = previousBlock->getCaseBodyVariables(); for (auto *expected : vars) { bool matched = false; if (!expected->hasName()) @@ -1600,7 +1600,7 @@ class StmtChecker : public StmtVisitor { } // Setup the types of our case body var decls. - for (auto *expected : caseBlock->getCaseBodyVariablesOrEmptyArray()) { + for (auto *expected : caseBlock->getCaseBodyVariables()) { assert(expected->hasName()); auto prev = expected->getParentVarDecl(); if (prev->hasInterfaceType()) @@ -3302,7 +3302,7 @@ void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { } // Wire up the case body variables to the latest patterns. - for (auto bodyVar : caseStmt->getCaseBodyVariablesOrEmptyArray()) { + for (auto bodyVar : caseStmt->getCaseBodyVariables()) { recordVar(nullptr, bodyVar); } } From 245e2874ae0ac6f527a30cc71f9aa927fce2552b Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 06/13] [AST] Eagerly wire up VarDecl parents when creating CaseStmt Rather than waiting until type-checking, we can set the parents immediately when we create the CaseStmt. This requires fixing up NamingPatternRequest to look at the recursive parent statement as now the VarDecl may have a variable parent. --- lib/AST/Stmt.cpp | 72 +++++++++++++++++++++++++++----------- lib/Sema/TypeCheckDecl.cpp | 2 +- 2 files changed, 52 insertions(+), 22 deletions(-) diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index f680663caa093..81cd39c9c2579 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -773,33 +773,63 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, MutableArrayRef items{getTrailingObjects(), static_cast(Bits.CaseStmt.NumPatterns)}; - // At the beginning mark all of our var decls as being owned by this - // statement. In the typechecker we wireup the case stmt var decl list since - // we know everything is lined up/typechecked then. for (unsigned i : range(Bits.CaseStmt.NumPatterns)) { new (&items[i]) CaseLabelItem(caseLabelItems[i]); - items[i].getPattern()->markOwnedByStatement(this); - } - for (auto *vd : getCaseBodyVariables()) { - vd->setParentPatternStmt(this); + // Mark the CaseStmt as the parent for any canonical VarDecls in the + // pattern. + items[i].getPattern()->forEachVariable([&](VarDecl *VD) { + if (!VD->getParentVarDecl()) + VD->setParentPatternStmt(this); + }); } } namespace { -static ArrayRef -getCaseVarDecls(ASTContext &ctx, ArrayRef labelItems) { - // Grab the first case label item pattern and use it to initialize the case - // body var decls. - SmallVector tmp; - labelItems.front().getPattern()->collectVariables(tmp); - return ctx.AllocateTransform( - llvm::ArrayRef(tmp), [&](VarDecl *vOld) -> VarDecl * { - auto *vNew = new (ctx) VarDecl( - /*IsStatic*/ false, vOld->getIntroducer(), vOld->getNameLoc(), - vOld->getName(), vOld->getDeclContext()); - vNew->setImplicit(); - return vNew; - }); +/// Produces an array of internal case body variables, and binds all of the +/// pattern variables that occur within the case to their "parent" pattern +/// variables, forming chains of variables with the same name. +/// +/// Given a case such as: +/// \code +/// case .a(let x), .b(let x), .c(let x): +/// \endcode +/// +/// Each case item contains a (different) pattern variable named `x`. This +/// function will set the "parent" variable of the second and third `x` +/// variables to the `x` variable immediately to its left. A fourth `x` will be +/// the body case variable, whose parent will be set to the `x` within the final +/// case item. +static ArrayRef getCaseVarDecls(ASTContext &ctx, + ArrayRef labelItems) { + SmallVector caseVars; + llvm::SmallDenseMap allVars; + + auto foundVar = [&](VarDecl *VD) { + if (!VD->hasName()) + return; + + auto &entry = allVars[VD->getName()]; + if (entry) { + VD->setParentVarDecl(entry); + } else { + auto *caseVar = new (ctx) VarDecl( + /*IsStatic*/ false, VD->getIntroducer(), VD->getNameLoc(), + VD->getName(), VD->getDeclContext()); + caseVar->setImplicit(); + caseVars.push_back(caseVar); + } + entry = VD; + }; + + for (auto &caseItem : labelItems) + caseItem.getPattern()->forEachVariable(foundVar); + + // Now that we've collected the case variables, ensure they're parented to + // the last pattern variables we saw. + for (auto caseVar : caseVars) + foundVar(caseVar); + + return ctx.AllocateCopy(caseVars); } struct FallthroughFinder : ASTWalker { diff --git a/lib/Sema/TypeCheckDecl.cpp b/lib/Sema/TypeCheckDecl.cpp index 0a5c0116d4c5b..6ecc1023c0354 100644 --- a/lib/Sema/TypeCheckDecl.cpp +++ b/lib/Sema/TypeCheckDecl.cpp @@ -2715,7 +2715,7 @@ NamingPatternRequest::evaluate(Evaluator &evaluator, VarDecl *VD) const { } if (!namingPattern) { - if (auto parentStmt = VD->getParentPatternStmt()) { + if (auto parentStmt = VD->getRecursiveParentPatternStmt()) { // Try type checking parent control statement. if (auto condStmt = dyn_cast(parentStmt)) { // The VarDecl is defined inside a condition of a `if` or `while` stmt. From c02c69a7832a52e56a54b262128ef65fb80ef9be Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 07/13] [AST] Introduce `CaseStmt::createImplicit` This allows us to re-use the same logic to create the case body variables. --- include/swift/AST/Stmt.h | 10 ++-- lib/AST/Stmt.cpp | 22 +++++++-- .../DerivedConformance/DerivedConformance.cpp | 9 ++-- .../DerivedConformanceCodable.cpp | 23 +-------- .../DerivedConformanceCodingKey.cpp | 17 +++---- .../DerivedConformanceComparable.cpp | 28 ++--------- .../DerivedConformanceEquatableHashable.cpp | 49 +++---------------- .../DerivedConformanceRawRepresentable.cpp | 17 +++---- unittests/Sema/ConstraintGenerationTests.cpp | 11 ++--- 9 files changed, 55 insertions(+), 131 deletions(-) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 274d93b4ea6a1..2b205f46bce96 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -1243,13 +1243,17 @@ class CaseStmt final ArrayRef CaseLabelItems, BraceStmt *Body); + static CaseStmt * + createImplicit(ASTContext &ctx, CaseParentKind parentKind, + ArrayRef caseLabelItems, BraceStmt *body, + NullablePtr fallthroughStmt = nullptr); + static CaseStmt * create(ASTContext &C, CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, - ArrayRef CaseBodyVariables, - std::optional Implicit = std::nullopt, - NullablePtr fallthroughStmt = nullptr); + ArrayRef CaseBodyVariables, std::optional Implicit, + NullablePtr fallthroughStmt); CaseParentKind getParentKind() const { return ParentKind; } diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 81cd39c9c2579..2f6dd77d433d1 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -897,11 +897,23 @@ CaseStmt *CaseStmt::createParsedDoCatch(ASTContext &ctx, SourceLoc catchLoc, } CaseStmt * -CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, SourceLoc caseLoc, - ArrayRef caseLabelItems, - SourceLoc unknownAttrLoc, SourceLoc colonLoc, BraceStmt *body, - ArrayRef caseVarDecls, std::optional implicit, - NullablePtr fallthroughStmt) { +CaseStmt::createImplicit(ASTContext &ctx, CaseParentKind parentKind, + ArrayRef caseLabelItems, + BraceStmt *body, + NullablePtr fallthroughStmt) { + auto caseVarDecls = getCaseVarDecls(ctx, caseLabelItems); + return create(ctx, parentKind, /*catchLoc*/ SourceLoc(), caseLabelItems, + /*unknownAttrLoc*/ SourceLoc(), /*colonLoc*/ SourceLoc(), body, + caseVarDecls, /*implicit*/ true, fallthroughStmt); +} + +CaseStmt *CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, + SourceLoc caseLoc, + ArrayRef caseLabelItems, + SourceLoc unknownAttrLoc, SourceLoc colonLoc, + BraceStmt *body, ArrayRef caseVarDecls, + std::optional implicit, + NullablePtr fallthroughStmt) { void *mem = ctx.Allocate(totalSizeToAlloc( fallthroughStmt.isNonNull(), caseLabelItems.size()), diff --git a/lib/Sema/DerivedConformance/DerivedConformance.cpp b/lib/Sema/DerivedConformance/DerivedConformance.cpp index 16d6d81134402..216c3f6f18ad3 100644 --- a/lib/Sema/DerivedConformance/DerivedConformance.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformance.cpp @@ -805,9 +805,8 @@ DeclRefExpr *DerivedConformance::convertEnumToIndex(SmallVectorImpl &st assignExpr->setType(TupleType::getEmpty(C)); auto body = BraceStmt::create(C, SourceLoc(), ASTNode(assignExpr), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - /*case body vardecls*/ std::nullopt)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } // generate: switch enumVar { } @@ -967,9 +966,7 @@ CaseStmt *DerivedConformance::unavailableEnumElementCaseStmt( auto *callExpr = DerivedConformance::createDiagnoseUnavailableCodeReachedCallExpr(C); auto body = BraceStmt::create(C, SourceLoc(), {callExpr}, SourceLoc()); - return CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), labelItem, - SourceLoc(), SourceLoc(), body, {}, - /*implicit*/ true); + return CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body); } /// Creates a named variable based on a prefix character and a numeric index. diff --git a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp index 99de54d096a65..723df50239c74 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp @@ -946,27 +946,10 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, // .(let a0, let a1, ...) SmallVector payloadVars; Pattern *subpattern = nullptr; - ArrayRef caseBodyVarDecls; if (createSubpattern) { subpattern = DerivedConformance::enumElementPayloadSubpattern( elt, 'a', DC, payloadVars, /* useLabels */ true); - - auto hasBoundDecls = !payloadVars.empty(); - if (hasBoundDecls) { - // We allocated a direct copy of our var decls for the case - // body. - auto copy = C.Allocate(payloadVars.size()); - for (unsigned i : indices(payloadVars)) { - auto *vOld = payloadVars[i]; - auto *vNew = new (C) VarDecl( - /*IsStatic*/ false, vOld->getIntroducer(), vOld->getNameLoc(), - vOld->getName(), vOld->getDeclContext()); - vNew->setImplicit(); - copy[i] = vNew; - } - caseBodyVarDecls = copy; - } } // CodingKeys.x @@ -985,10 +968,8 @@ createEnumSwitch(ASTContext &C, DeclContext *DC, Expr *expr, EnumDecl *enumDecl, subpattern, DC); auto labelItem = CaseLabelItem(pat); - auto stmt = - CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), labelItem, - SourceLoc(), SourceLoc(), caseBody, - caseBodyVarDecls); + auto stmt = CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, + caseBody); cases.push_back(stmt); } } diff --git a/lib/Sema/DerivedConformance/DerivedConformanceCodingKey.cpp b/lib/Sema/DerivedConformance/DerivedConformanceCodingKey.cpp index 0c84f71aa84ac..ecbce38ca3032 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceCodingKey.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceCodingKey.cpp @@ -222,10 +222,8 @@ deriveBodyCodingKey_enum_stringValue(AbstractFunctionDecl *strValDecl, void *) { auto *returnStmt = ReturnStmt::createImplicit(C, caseValue); auto *caseBody = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), - caseBody, - /*case body var decls*/ std::nullopt)); + cases.push_back(CaseStmt::createImplicit(C, CaseParentKind::Switch, + labelItem, caseBody)); } auto *selfRef = DerivedConformance::createSelfDeclRef(strValDecl); @@ -292,9 +290,8 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) { auto *body = BraceStmt::create(C, SourceLoc(), ASTNode(assignment), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - /*case body var decls*/ std::nullopt)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } auto *anyPat = AnyPattern::createImplicit(C); @@ -303,10 +300,8 @@ deriveBodyCodingKey_init_stringValue(AbstractFunctionDecl *initDecl, void *) { auto *dfltReturnStmt = new (C) FailStmt(SourceLoc(), SourceLoc()); auto *dfltBody = BraceStmt::create(C, SourceLoc(), ASTNode(dfltReturnStmt), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - dfltLabelItem, SourceLoc(), SourceLoc(), - dfltBody, - /*case body var decls*/ std::nullopt)); + cases.push_back(CaseStmt::createImplicit(C, CaseParentKind::Switch, + dfltLabelItem, dfltBody)); auto *stringValueDecl = initDecl->getParameters()->get(0); auto *stringValueRef = new (C) DeclRefExpr(stringValueDecl, DeclNameLoc(), diff --git a/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp index 93d823fac3bfd..01287cd0e2bdc 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceComparable.cpp @@ -127,23 +127,6 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v auto *rhsElemPat = EnumElementPattern::createImplicit( enumType, elt, rhsSubpattern, /*DC*/ ltDecl); - auto hasBoundDecls = !lhsPayloadVars.empty(); - ArrayRef caseBodyVarDecls; - if (hasBoundDecls) { - // We allocated a direct copy of our lhs var decls for the case - // body. - auto copy = C.Allocate(lhsPayloadVars.size()); - for (unsigned i : indices(lhsPayloadVars)) { - auto *vOld = lhsPayloadVars[i]; - auto *vNew = new (C) VarDecl( - /*IsStatic*/ false, vOld->getIntroducer(), - vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext()); - vNew->setImplicit(); - copy[i] = vNew; - } - caseBodyVarDecls = copy; - } - // case (.(let l0, let l1, ...), .(let r0, let r1, ...)) auto caseTuplePattern = TuplePattern::createImplicit(C, { TuplePatternElt(lhsElemPat), TuplePatternElt(rhsElemPat) }); @@ -177,9 +160,8 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v auto body = BraceStmt::create(C, SourceLoc(), statementsInCase, SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - caseBodyVarDecls)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } // default: result = (lhs) < (rhs) @@ -190,10 +172,8 @@ deriveBodyComparable_enum_hasAssociatedValues_lt(AbstractFunctionDecl *ltDecl, v auto defaultPattern = AnyPattern::createImplicit(C); auto defaultItem = CaseLabelItem::getDefault(defaultPattern); auto body = deriveBodyComparable_enum_noAssociatedValues_lt(ltDecl, nullptr).first; - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - defaultItem, SourceLoc(), SourceLoc(), - body, - /*case body var decls*/ std::nullopt)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, defaultItem, body)); } // switch (a, b) { } diff --git a/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp index fc16fd5cf425d..b1cd116446201 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceEquatableHashable.cpp @@ -193,23 +193,6 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, auto *rhsElemPat = EnumElementPattern::createImplicit( enumType, elt, rhsSubpattern, /*DC*/ eqDecl); - auto hasBoundDecls = !lhsPayloadVars.empty(); - ArrayRef caseBodyVarDecls; - if (hasBoundDecls) { - // We allocated a direct copy of our lhs var decls for the case - // body. - auto copy = C.Allocate(lhsPayloadVars.size()); - for (unsigned i : indices(lhsPayloadVars)) { - auto *vOld = lhsPayloadVars[i]; - auto *vNew = new (C) VarDecl( - /*IsStatic*/ false, vOld->getIntroducer(), - vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext()); - vNew->setImplicit(); - copy[i] = vNew; - } - caseBodyVarDecls = copy; - } - // case (.(let l0, let l1, ...), .(let r0, let r1, ...)) auto caseTuplePattern = TuplePattern::createImplicit(C, { TuplePatternElt(lhsElemPat), TuplePatternElt(rhsElemPat) }); @@ -244,9 +227,8 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, auto body = BraceStmt::create(C, SourceLoc(), statementsInCase, SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - caseBodyVarDecls)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } // default: result = false @@ -261,10 +243,8 @@ deriveBodyEquatable_enum_hasAssociatedValues_eq(AbstractFunctionDecl *eqDecl, auto *returnStmt = ReturnStmt::createImplicit(C, falseExpr); auto body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - defaultItem, SourceLoc(), SourceLoc(), - body, - /*case body var decls*/ std::nullopt)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, defaultItem, body)); } // switch (a, b) { } @@ -735,26 +715,9 @@ deriveBodyHashable_enum_hasAssociatedValues_hashInto( statements.emplace_back(ASTNode(combineExpr)); } - auto hasBoundDecls = !payloadVars.empty(); - ArrayRef caseBodyVarDecls; - if (hasBoundDecls) { - auto copy = C.Allocate(payloadVars.size()); - for (unsigned i : indices(payloadVars)) { - auto *vOld = payloadVars[i]; - auto *vNew = new (C) VarDecl( - /*IsStatic*/ false, vOld->getIntroducer(), - vOld->getNameLoc(), vOld->getName(), vOld->getDeclContext()); - vNew->setImplicit(); - copy[i] = vNew; - } - caseBodyVarDecls = copy; - } - auto body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - caseBodyVarDecls, - /*implicit*/ true)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } // generate: switch enumVar { } diff --git a/lib/Sema/DerivedConformance/DerivedConformanceRawRepresentable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceRawRepresentable.cpp index 4e98dbfd01144..b8a3ec8eefc06 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceRawRepresentable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceRawRepresentable.cpp @@ -129,9 +129,8 @@ deriveBodyRawRepresentable_raw(AbstractFunctionDecl *toRawDecl, void *) { auto body = BraceStmt::create(C, SourceLoc(), ASTNode(returnStmt), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - labelItem, SourceLoc(), SourceLoc(), body, - /*case body var decls*/ std::nullopt)); + cases.push_back( + CaseStmt::createImplicit(C, CaseParentKind::Switch, labelItem, body)); } auto selfRef = DerivedConformance::createSelfDeclRef(toRawDecl); @@ -363,10 +362,8 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl, void *) { stmts, SourceLoc()); // cases.append("case \(litPat): \(body)") - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - CaseLabelItem(litPat), SourceLoc(), - SourceLoc(), body, - /*case body var decls*/ std::nullopt)); + cases.push_back(CaseStmt::createImplicit(C, CaseParentKind::Switch, + CaseLabelItem(litPat), body)); ++Idx; } @@ -376,10 +373,8 @@ deriveBodyRawRepresentable_init(AbstractFunctionDecl *initDecl, void *) { auto dfltReturnStmt = new (C) FailStmt(SourceLoc(), SourceLoc()); auto dfltBody = BraceStmt::create(C, SourceLoc(), ASTNode(dfltReturnStmt), SourceLoc()); - cases.push_back(CaseStmt::create(C, CaseParentKind::Switch, SourceLoc(), - dfltLabelItem, SourceLoc(), SourceLoc(), - dfltBody, - /*case body var decls*/ std::nullopt)); + cases.push_back(CaseStmt::createImplicit(C, CaseParentKind::Switch, + dfltLabelItem, dfltBody)); auto rawDecl = initDecl->getParameters()->get(0); auto rawRef = new (C) DeclRefExpr(rawDecl, DeclNameLoc(), /*implicit*/true); diff --git a/unittests/Sema/ConstraintGenerationTests.cpp b/unittests/Sema/ConstraintGenerationTests.cpp index 7d4717869395e..e6b36257ae4d5 100644 --- a/unittests/Sema/ConstraintGenerationTests.cpp +++ b/unittests/Sema/ConstraintGenerationTests.cpp @@ -286,10 +286,8 @@ TEST_F(SemaTest, TestSwitchExprLocator) { {IntegerLiteralExpr::createFromUnsigned(Context, 1, SourceLoc())}); auto *truePattern = ExprPattern::createImplicit( Context, new (Context) BooleanLiteralExpr(true, SourceLoc()), DC); - auto *trueCase = - CaseStmt::create(Context, CaseParentKind::Switch, SourceLoc(), - {CaseLabelItem(truePattern)}, SourceLoc(), SourceLoc(), - trueBrace, /*caseBodyVars*/ std::nullopt); + auto *trueCase = CaseStmt::createImplicit( + Context, CaseParentKind::Switch, {CaseLabelItem(truePattern)}, trueBrace); // case false: 2 auto *falseBrace = BraceStmt::createImplicit( @@ -298,9 +296,8 @@ TEST_F(SemaTest, TestSwitchExprLocator) { auto *falsePattern = ExprPattern::createImplicit( Context, new (Context) BooleanLiteralExpr(false, SourceLoc()), DC); auto *falseCase = - CaseStmt::create(Context, CaseParentKind::Switch, SourceLoc(), - {CaseLabelItem(falsePattern)}, SourceLoc(), SourceLoc(), - falseBrace, /*caseBodyVars*/ std::nullopt); + CaseStmt::createImplicit(Context, CaseParentKind::Switch, + {CaseLabelItem(falsePattern)}, falseBrace); auto *subject = new (Context) BooleanLiteralExpr(true, SourceLoc()); From dc13b1f442af900b568932010931f06ad5aa778e Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 08/13] [AST] Tail-allocate case body variables on CaseStmt --- include/swift/AST/Stmt.h | 18 +++++++++++------- lib/AST/Stmt.cpp | 18 ++++++++++++------ 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/include/swift/AST/Stmt.h b/include/swift/AST/Stmt.h index 2b205f46bce96..c8f9d4bb3e76d 100644 --- a/include/swift/AST/Stmt.h +++ b/include/swift/AST/Stmt.h @@ -84,8 +84,9 @@ class alignas(8) Stmt : public ASTAllocated { NumElements : 32 ); - SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 32, + SWIFT_INLINE_BITFIELD_FULL(CaseStmt, Stmt, 16+32, : NumPadBits, + NumCaseBodyVars : 16, NumPatterns : 32 ); @@ -1210,8 +1211,8 @@ enum CaseParentKind { Switch, DoCatch }; /// class CaseStmt final : public Stmt, - private llvm::TrailingObjects { + private llvm::TrailingObjects { friend TrailingObjects; Stmt *ParentStmt = nullptr; @@ -1222,14 +1223,17 @@ class CaseStmt final llvm::PointerIntPair BodyAndHasFallthrough; - ArrayRef CaseBodyVariables; - CaseStmt(CaseParentKind ParentKind, SourceLoc ItemIntroducerLoc, ArrayRef CaseLabelItems, SourceLoc UnknownAttrLoc, SourceLoc ItemTerminatorLoc, BraceStmt *Body, ArrayRef CaseBodyVariables, std::optional Implicit, NullablePtr fallthroughStmt); + MutableArrayRef getCaseBodyVariablesBuffer() { + return {getTrailingObjects(), + static_cast(Bits.CaseStmt.NumCaseBodyVars)}; + } + public: /// Create a parsed 'case'/'default' for 'switch' statement. static CaseStmt * @@ -1296,7 +1300,7 @@ class CaseStmt final void setBody(BraceStmt *body) { BodyAndHasFallthrough.setPointer(body); } /// True if the case block declares any patterns with local variable bindings. - bool hasCaseBodyVariables() const { return !CaseBodyVariables.empty(); } + bool hasCaseBodyVariables() const { return !getCaseBodyVariables().empty(); } /// Get the source location of the 'case', 'default', or 'catch' of the first /// label. @@ -1349,7 +1353,7 @@ class CaseStmt final /// Return an ArrayRef containing the case body variables of this CaseStmt. ArrayRef getCaseBodyVariables() const { - return CaseBodyVariables; + return const_cast(this)->getCaseBodyVariablesBuffer(); } /// Find the next case statement within the same 'switch' or 'do-catch', diff --git a/lib/AST/Stmt.cpp b/lib/AST/Stmt.cpp index 2f6dd77d433d1..782dffd49e2ce 100644 --- a/lib/AST/Stmt.cpp +++ b/lib/AST/Stmt.cpp @@ -758,9 +758,11 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, : Stmt(StmtKind::Case, getDefaultImplicitFlag(implicit, itemIntroducerLoc)), UnknownAttrLoc(unknownAttrLoc), ItemIntroducerLoc(itemIntroducerLoc), ItemTerminatorLoc(itemTerminatorLoc), ParentKind(parentKind), - BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()), - CaseBodyVariables(caseBodyVariables) { + BodyAndHasFallthrough(body, fallthroughStmt.isNonNull()) { Bits.CaseStmt.NumPatterns = caseLabelItems.size(); + Bits.CaseStmt.NumCaseBodyVars = caseBodyVariables.size(); + ASSERT(Bits.CaseStmt.NumCaseBodyVars == caseBodyVariables.size() && + "too many case body vars"); assert(Bits.CaseStmt.NumPatterns > 0 && "case block must have at least one pattern"); assert( @@ -770,6 +772,9 @@ CaseStmt::CaseStmt(CaseParentKind parentKind, SourceLoc itemIntroducerLoc, *getTrailingObjects() = fallthroughStmt.get(); } + std::uninitialized_copy(caseBodyVariables.begin(), caseBodyVariables.end(), + getCaseBodyVariablesBuffer().begin()); + MutableArrayRef items{getTrailingObjects(), static_cast(Bits.CaseStmt.NumPatterns)}; @@ -914,10 +919,11 @@ CaseStmt *CaseStmt::create(ASTContext &ctx, CaseParentKind ParentKind, BraceStmt *body, ArrayRef caseVarDecls, std::optional implicit, NullablePtr fallthroughStmt) { - void *mem = - ctx.Allocate(totalSizeToAlloc( - fallthroughStmt.isNonNull(), caseLabelItems.size()), - alignof(CaseStmt)); + void *mem = ctx.Allocate( + totalSizeToAlloc( + fallthroughStmt.isNonNull(), caseLabelItems.size(), + caseVarDecls.size()), + alignof(CaseStmt)); return ::new (mem) CaseStmt(ParentKind, caseLoc, caseLabelItems, unknownAttrLoc, colonLoc, body, caseVarDecls, implicit, fallthroughStmt); From 19831722788c85d0ab932bad5ab2840dadc748ee Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 09/13] [Sema] Rename `bindSwitchCasePatternVars` -> `diagnoseCaseVarMutabilityMismatch` Now that we wire up the parents up-front, this no longer needs to set the parents. As such, remove the logic and rename to reflect the fact that it now just diagnoses mutability mismatches. --- lib/Sema/CSSyntacticElement.cpp | 2 +- lib/Sema/TypeCheckStmt.cpp | 28 +++++++--------------------- lib/Sema/TypeChecker.h | 23 +++-------------------- 3 files changed, 11 insertions(+), 42 deletions(-) diff --git a/lib/Sema/CSSyntacticElement.cpp b/lib/Sema/CSSyntacticElement.cpp index e5bac5ebacd21..8981d5004397f 100644 --- a/lib/Sema/CSSyntacticElement.cpp +++ b/lib/Sema/CSSyntacticElement.cpp @@ -2005,7 +2005,7 @@ class SyntacticElementSolutionApplication } } - bindSwitchCasePatternVars(context.getAsDeclContext(), caseStmt); + diagnoseCaseVarMutabilityMismatch(context.getAsDeclContext(), caseStmt); for (auto *expected : caseStmt->getCaseBodyVariables()) { assert(expected->hasName()); diff --git a/lib/Sema/TypeCheckStmt.cpp b/lib/Sema/TypeCheckStmt.cpp index 15fa32c241286..08477dc6a42fa 100644 --- a/lib/Sema/TypeCheckStmt.cpp +++ b/lib/Sema/TypeCheckStmt.cpp @@ -1581,9 +1581,7 @@ class StmtChecker : public StmtVisitor { // First pass: check all of the bindings. for (auto *caseBlock : make_range(casesBegin, casesEnd)) { - // Bind all of the pattern variables together so we can follow the - // "parent" pointers later on. - bindSwitchCasePatternVars(DC, caseBlock); + diagnoseCaseVarMutabilityMismatch(DC, caseBlock); auto caseLabelItemArray = caseBlock->getMutableCaseLabelItems(); for (auto &labelItem : caseLabelItemArray) { @@ -3238,7 +3236,8 @@ void swift::checkUnknownAttrRestrictions( } } -void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { +void swift::diagnoseCaseVarMutabilityMismatch(DeclContext *dc, + CaseStmt *caseStmt) { llvm::SmallDenseMap, 4> latestVars; auto recordVar = [&](Pattern *pattern, VarDecl *var) { if (!var->hasName()) @@ -3248,16 +3247,10 @@ void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { // parent of this new variable. auto &entry = latestVars[var->getName()]; if (entry.first) { - assert(!var->getParentVarDecl() || - var->getParentVarDecl() == entry.first); - var->setParentVarDecl(entry.first); - // Check for a mutability mismatch. - if (pattern && entry.second != var->isLet()) { + if (entry.second != var->isLet()) { // Find the original declaration. - auto initialCaseVarDecl = entry.first; - while (auto parentVar = initialCaseVarDecl->getParentVarDecl()) - initialCaseVarDecl = parentVar; + auto initialCaseVarDecl = entry.first->getCanonicalVarDecl(); auto diag = var->diagnose(diag::mutability_mismatch_multiple_pattern_list, var->isLet(), initialCaseVarDecl->isLet()); @@ -3268,10 +3261,10 @@ void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { if (VP->getSingleVar() == var) foundVP = VP; }); - if (foundVP) + if (foundVP) { diag.fixItReplace(foundVP->getLoc(), initialCaseVarDecl->isLet() ? "let" : "var"); - + } var->setInvalid(); initialCaseVarDecl->setInvalid(); } @@ -3283,8 +3276,6 @@ void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { entry.first = var; }; - // Wire up the parent var decls for each variable that occurs within - // the patterns of each case item. in source order. for (auto &caseItem : caseStmt->getMutableCaseLabelItems()) { // Resolve the pattern. auto *pattern = caseItem.getPattern(); @@ -3300,11 +3291,6 @@ void swift::bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *caseStmt) { recordVar(pattern, var); }); } - - // Wire up the case body variables to the latest patterns. - for (auto bodyVar : caseStmt->getCaseBodyVariables()) { - recordVar(nullptr, bodyVar); - } } FuncDecl *TypeChecker::getForEachIteratorNextFunction( diff --git a/lib/Sema/TypeChecker.h b/lib/Sema/TypeChecker.h index 188edb66b0f98..4cd7ca3f0eaa4 100644 --- a/lib/Sema/TypeChecker.h +++ b/lib/Sema/TypeChecker.h @@ -1416,26 +1416,9 @@ bool checkFallthroughStmt(FallthroughStmt *stmt); void checkUnknownAttrRestrictions( ASTContext &ctx, CaseStmt *caseBlock, bool &limitExhaustivityChecks); -/// Bind all of the pattern variables that occur within a case statement and -/// all of its case items to their "parent" pattern variables, forming chains -/// of variables with the same name. -/// -/// Given a case such as: -/// \code -/// case .a(let x), .b(let x), .c(let x): -/// \endcode -/// -/// Each case item contains a (different) pattern variable named. -/// "x". This function will set the "parent" variable of the -/// second and third "x" variables to the "x" variable immediately -/// to its left. A fourth "x" will be the body case variable, -/// whose parent will be set to the "x" within the final case -/// item. -/// -/// Each of the "x" variables must eventually have the same type, and agree on -/// let vs. var. This function does not perform any of that validation, leaving -/// it to later stages. -void bindSwitchCasePatternVars(DeclContext *dc, CaseStmt *stmt); +/// Diagnoses any mutability mismatches for any same-named variables bound by +/// given CaseStmt. +void diagnoseCaseVarMutabilityMismatch(DeclContext *dc, CaseStmt *stmt); /// If \p attr was added by an access note, wraps the error in /// \c diag::wrap_invalid_attr_added_by_access_note and limits it as an access From 413824c082a4248eda6ae865af402e4f0bf0424a Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 10/13] [AST] Simplify `getRecursiveParentPatternStmt` `findParentPatternCaseStmtAndPattern` does more work than is necessary for this, we just want the parent of the canonical var. --- lib/AST/Decl.cpp | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index c470d81eee8f8..f79a9a1a2dd9c 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8227,16 +8227,7 @@ VarDecl *VarDecl::getCanonicalVarDecl() const { } Stmt *VarDecl::getRecursiveParentPatternStmt() const { - // If our parent is already a pattern stmt, just return that. - if (auto *stmt = getParentPatternStmt()) - return stmt; - - // Otherwise, see if we have a parent var decl. If we do not, then return - // nullptr. Otherwise, return the case stmt that we found. - auto result = findParentPatternCaseStmtAndPattern(this); - if (!result.has_value()) - return nullptr; - return result->first; + return getCanonicalVarDecl()->getParentPatternStmt(); } /// Return the Pattern involved in initializing this VarDecl. Recall that the From 79fe1b354885b3c0664e039128c836a3e88e3956 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 11/13] [AST] Remove `findParentPatternCaseStmtAndPattern` Add the extra logic to `VarDecl::getParentPattern` necessary to handle fallthrough and case body variables instead. This also changes the behavior for case body vars - previously we would return the first pattern in the CaseStmt, but that's not necessarily correct. Instead, return the first pattern that actually binds the variable. --- lib/AST/Decl.cpp | 81 ++++++++++++++---------------------------------- 1 file changed, 23 insertions(+), 58 deletions(-) diff --git a/lib/AST/Decl.cpp b/lib/AST/Decl.cpp index f79a9a1a2dd9c..a8d3cb7644400 100644 --- a/lib/AST/Decl.cpp +++ b/lib/AST/Decl.cpp @@ -8159,50 +8159,6 @@ SourceRange AbstractStorageDecl::getTypeSourceRangeForDiagnostics() const { return SourceRange(); } -static std::optional> -findParentPatternCaseStmtAndPattern(const VarDecl *inputVD) { - auto getMatchingPattern = [&](CaseStmt *cs) -> Pattern * { - // Check if inputVD is in our case body var decls if we have any. If we do, - // treat its pattern as our first case label item pattern. - for (auto *vd : cs->getCaseBodyVariables()) { - if (vd == inputVD) { - return cs->getMutableCaseLabelItems().front().getPattern(); - } - } - - // Then check the rest of our case label items. - for (auto &item : cs->getMutableCaseLabelItems()) { - if (item.getPattern()->containsVarDecl(inputVD)) { - return item.getPattern(); - } - } - - // Otherwise return false if we do not find anything. - return nullptr; - }; - - // First find our canonical var decl. This is the VarDecl corresponding to the - // first case label item of the first case block in the fallthrough chain that - // our case block is within. Grab the case stmt associated with that var decl - // and start traveling down the fallthrough chain looking for the case - // statement that the input VD belongs to by using getMatchingPattern(). - auto *canonicalVD = inputVD->getCanonicalVarDecl(); - auto *caseStmt = - dyn_cast_or_null(canonicalVD->getParentPatternStmt()); - if (!caseStmt) - return std::nullopt; - - if (auto *p = getMatchingPattern(caseStmt)) - return std::make_pair(caseStmt, p); - - while ((caseStmt = caseStmt->getFallthroughDest().getPtrOrNull())) { - if (auto *p = getMatchingPattern(caseStmt)) - return std::make_pair(caseStmt, p); - } - - return std::nullopt; -} - VarDecl *VarDecl::getCanonicalVarDecl() const { // Any var decl without a parent var decl is canonical. This means that before // type checking, all var decls are canonical. @@ -8247,17 +8203,34 @@ Pattern *VarDecl::getParentPattern() const { } // If this is a statement parent, dig the pattern out of it. - if (auto *stmt = getParentPatternStmt()) { + const auto *canonicalVD = getCanonicalVarDecl(); + if (auto *stmt = canonicalVD->getParentPatternStmt()) { if (auto *FES = dyn_cast(stmt)) return FES->getPattern(); if (auto *cs = dyn_cast(stmt)) { - // In a case statement, search for the pattern that contains it. This is - // a bit silly, because you can't have something like "case x, y:" anyway. - for (auto items : cs->getCaseLabelItems()) { - if (items.getPattern()->containsVarDecl(this)) - return items.getPattern(); + // In a case statement, search for the pattern that contains it. + auto findPattern = [](CaseStmt *cs, const VarDecl *VD) -> Pattern * { + for (auto items : cs->getCaseLabelItems()) { + if (items.getPattern()->containsVarDecl(VD)) + return items.getPattern(); + } + return nullptr; + }; + if (auto *P = findPattern(cs, this)) + return P; + + // If it's not in the CaseStmt, check its fallthrough destination. + if (auto fallthrough = cs->getFallthroughDest()) { + if (auto *P = findPattern(fallthrough.get(), this)) + return P; } + + // Finally, check the canonical variable, this is necessary to correctly + // handle case body vars, we just want to take the first pattern that + // declares it in that case. + if (auto *P = findPattern(cs, canonicalVD)) + return P; } if (auto *LCS = dyn_cast(stmt)) { @@ -8268,14 +8241,6 @@ Pattern *VarDecl::getParentPattern() const { } } - // Otherwise, check if we have to walk our case stmt's var decl list to find - // the pattern. - if (auto caseStmtPatternPair = findParentPatternCaseStmtAndPattern(this)) { - return caseStmtPatternPair->second; - } - - // Otherwise, this is a case we do not know or understand. Return nullptr to - // signal we do not have any information. return nullptr; } From 84847bcd0693a5b1722239f7d7723fc3027a22da Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 12/13] [Sema] Relax a check in `VarDeclUsageChecker` We don't want to just check the first pattern, we're interested in the first pattern that binds the given variable. That can be determined by checking if it's canonical or not. --- lib/Sema/MiscDiagnostics.cpp | 8 ++------ test/Parse/switch.swift | 6 +++--- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/lib/Sema/MiscDiagnostics.cpp b/lib/Sema/MiscDiagnostics.cpp index be5c7ac80ec64..9e9619e7fce91 100644 --- a/lib/Sema/MiscDiagnostics.cpp +++ b/lib/Sema/MiscDiagnostics.cpp @@ -4007,12 +4007,8 @@ VarDeclUsageChecker::~VarDeclUsageChecker() { if (auto *caseStmt = dyn_cast_or_null(var->getRecursiveParentPatternStmt())) { - // Only diagnose VarDecls from the first CaseLabelItem in CaseStmts, as - // the remaining items must match it anyway. - auto caseItems = caseStmt->getCaseLabelItems(); - assert(!caseItems.empty() && - "If we have any case stmt var decls, we should have a case item"); - if (!caseItems.front().getPattern()->containsVarDecl(var)) + // Only diagnose for the parent-most VarDecl. + if (var->getParentVarDecl()) continue; auto *childVar = var->getCorrespondingCaseBodyVariable().get(); diff --git a/test/Parse/switch.swift b/test/Parse/switch.swift index 78801b1b44b00..b181bfbaa1c56 100644 --- a/test/Parse/switch.swift +++ b/test/Parse/switch.swift @@ -200,10 +200,10 @@ switch t { case (var a, 2), (1, _): // expected-error {{'a' must be bound in every pattern}} expected-warning {{variable 'a' was never used; consider replacing with '_' or removing it}} () -case (_, 2), (var a, _): // expected-error {{'a' must be bound in every pattern}} +case (_, 2), (var a, _): // expected-error {{'a' must be bound in every pattern}} expected-warning {{variable 'a' was never used; consider replacing with '_' or removing it}} () -case (var a, 2), (1, var b): // expected-error {{'a' must be bound in every pattern}} expected-error {{'b' must be bound in every pattern}} expected-warning {{variable 'a' was never used; consider replacing with '_' or removing it}} +case (var a, 2), (1, var b): // expected-error {{'a' must be bound in every pattern}} expected-error {{'b' must be bound in every pattern}} expected-warning {{variable 'a' was never used; consider replacing with '_' or removing it}} expected-warning {{variable 'b' was never used; consider replacing with '_' or removing it}} () case (var a, 2): // expected-error {{'case' label in a 'switch' must have at least one executable statement}} {{17-17= break}} expected-warning {{variable 'a' was never used; consider replacing with '_' or removing it}} @@ -221,7 +221,7 @@ case (1, var b): // expected-warning {{variable 'b' was never used; consider rep case (1, let b): // let bindings expected-warning {{immutable value 'b' was never used; consider replacing with '_' or removing it}} () -case (_, 2), (let a, _): // expected-error {{'a' must be bound in every pattern}} expected-warning {{case is already handled by previous patterns; consider removing it}} +case (_, 2), (let a, _): // expected-error {{'a' must be bound in every pattern}} expected-warning {{case is already handled by previous patterns; consider removing it}} expected-warning {{immutable value 'a' was never used; consider replacing with '_' or removing it}} () // OK From 10ed17549c04839ce67af3dd2dce5726d9aa22c6 Mon Sep 17 00:00:00 2001 From: Hamish Knight Date: Tue, 9 Sep 2025 13:48:40 +0100 Subject: [PATCH 13/13] [CS] Set the naming pattern in `markInvalid` This normally gets populated by successful type-checking, we still want to populate it if we fail though to avoid attempting to type-check the parent statement again. --- lib/Sema/SyntacticElementTarget.cpp | 6 ++++++ test/Constraints/issue-66553.swift | 2 +- test/Parse/matching_patterns.swift | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/lib/Sema/SyntacticElementTarget.cpp b/lib/Sema/SyntacticElementTarget.cpp index 60b53dc60b8e2..d86a4861a3d26 100644 --- a/lib/Sema/SyntacticElementTarget.cpp +++ b/lib/Sema/SyntacticElementTarget.cpp @@ -307,6 +307,12 @@ void SyntacticElementTarget::markInvalid() const { PreWalkResult walkToPatternPre(Pattern *P) override { P->setType(ErrorType::get(Ctx)); + + // For a named pattern, set it on the variable. This stops us from + // attempting to double-type-check variables we've already type-checked. + if (auto *NP = dyn_cast(P)) + NP->getDecl()->setNamingPattern(NP); + return Action::Continue(P); } diff --git a/test/Constraints/issue-66553.swift b/test/Constraints/issue-66553.swift index 1933e9c22551e..983aca34aa61d 100644 --- a/test/Constraints/issue-66553.swift +++ b/test/Constraints/issue-66553.swift @@ -4,7 +4,7 @@ func baz(y: [Int], z: Int) -> Int { switch z { - case y[let z]: // expected-error 2{{'let' binding pattern cannot appear in an expression}} + case y[let z]: // expected-error {{'let' binding pattern cannot appear in an expression}} z default: z diff --git a/test/Parse/matching_patterns.swift b/test/Parse/matching_patterns.swift index 7ef93355a4c26..5da5d0ba9bd8e 100644 --- a/test/Parse/matching_patterns.swift +++ b/test/Parse/matching_patterns.swift @@ -406,7 +406,7 @@ func testNonBinding5(_ x: Int, y: [Int]) { func testNonBinding6(y: [Int], z: Int) -> Int { switch 0 { // We treat 'z' here as a binding, which is invalid. - case let y[z]: // expected-error 2{{pattern variable binding cannot appear in an expression}} + case let y[z]: // expected-error {{pattern variable binding cannot appear in an expression}} z case y[z]: // This is fine 0