From 35628cb50339df96069eaf838b6ea408defe045c Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 9 Jul 2025 16:19:14 -0700 Subject: [PATCH] [Strict memory safety] Fix "unsafe" checking for the for..in loop The `$generator` variable we create for the async for..in loop is `nonisolated(unsafe)`, so ensure that we generate an `unsafe` expression when we use it. This uncovered some inconsistencies in how we do `unsafe` checking for for..in loops, so fix those. Fixes rdar://154775389. --- lib/Sema/CSGen.cpp | 13 +++++++++---- lib/Sema/TypeCheckEffects.cpp | 11 +++++------ test/Unsafe/safe.swift | 2 ++ test/Unsafe/unsafe_concurrency.swift | 7 +++++++ 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 4c67c74246905..411d9dfef68ad 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4728,10 +4728,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, } // Wrap the 'next' call in 'unsafe', if the for..in loop has that - // effect. - if (stmt->getUnsafeLoc().isValid()) { - nextCall = new (ctx) UnsafeExpr( - stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true); + // effect or if the loop is async (in which case the iterator variable + // is nonisolated(unsafe). + if (stmt->getUnsafeLoc().isValid() || + (isAsync && + ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { + SourceLoc loc = stmt->getUnsafeLoc(); + if (loc.isInvalid()) + loc = stmt->getForLoc(); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), /*implicit=*/true); } // The iterator type must conform to IteratorProtocol. diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 4b6727a90916d..f0502c3034c82 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -2457,7 +2457,7 @@ class ApplyClassifier { return ShouldRecurse; } ShouldRecurse_t checkUnsafe(UnsafeExpr *E) { - return E->isImplicit() ? ShouldRecurse : ShouldNotRecurse; + return ShouldNotRecurse; } ShouldRecurse_t checkTry(TryExpr *E) { return ShouldRecurse; @@ -4626,10 +4626,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker diagnoseUnsafeUse(unsafeUse); } } - } else if (S->getUnsafeLoc().isValid()) { - // Extraneous "unsafe" on the sequence. - Ctx.Diags.diagnose(S->getUnsafeLoc(), diag::no_unsafe_in_unsafe_for) - .fixItRemove(S->getUnsafeLoc()); } return ShouldRecurse; @@ -4689,7 +4685,10 @@ class CheckEffectsCoverage : public EffectsHandlingWalker return; } - Ctx.Diags.diagnose(E->getUnsafeLoc(), diag::no_unsafe_in_unsafe) + Ctx.Diags.diagnose(E->getUnsafeLoc(), + forEachNextCallExprs.contains(E) + ? diag::no_unsafe_in_unsafe_for + : diag::no_unsafe_in_unsafe) .fixItRemove(E->getUnsafeLoc()); } diff --git a/test/Unsafe/safe.swift b/test/Unsafe/safe.swift index b2be9e630da76..254431bbd4d91 100644 --- a/test/Unsafe/safe.swift +++ b/test/Unsafe/safe.swift @@ -98,6 +98,8 @@ func testUnsafeAsSequenceForEach() { for _ in unsafe uas { } // expected-warning{{for-in loop uses unsafe constructs but is not marked with 'unsafe'}}{{documentation-file=strict-memory-safety}}{{7-7=unsafe }} for unsafe _ in unsafe uas { } // okay + + for unsafe _ in [1, 2, 3] { } // expected-warning{{no unsafe operations occur within 'unsafe' for-in loop}} } func testForInUnsafeAmbiguity(_ integers: [Int]) { diff --git a/test/Unsafe/unsafe_concurrency.swift b/test/Unsafe/unsafe_concurrency.swift index 00dea6cf54d3e..895c62cdd816f 100644 --- a/test/Unsafe/unsafe_concurrency.swift +++ b/test/Unsafe/unsafe_concurrency.swift @@ -55,3 +55,10 @@ open class SyntaxVisitor { open class SyntaxAnyVisitor: SyntaxVisitor { override open func visit(_ token: TokenSyntax) { } } + +@available(SwiftStdlib 5.1, *) +func testMemorySafetyWithForLoop() async { + let (stream, continuation) = AsyncStream.makeStream() + for await _ in stream {} + _ = continuation +}