From d6deeccc813f561afb47ecce3879537873199f87 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Wed, 9 Jul 2025 16:19:14 -0700 Subject: [PATCH 1/2] [Strict memory safety] Fix "unsafe" checking for the for..in loop The `$generator` variable we create for the async for..in loop is `nonisolated(unsafe)`, so ensure that we generate an `unsafe` expression when we use it. This uncovered some inconsistencies in how we do `unsafe` checking for for..in loops, so fix those. Fixes rdar://154775389. (cherry picked from commit 35628cb50339df96069eaf838b6ea408defe045c) --- lib/Sema/CSGen.cpp | 13 +++++++++---- lib/Sema/TypeCheckEffects.cpp | 11 +++++------ test/Unsafe/safe.swift | 2 ++ test/Unsafe/unsafe_concurrency.swift | 7 +++++++ 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index a7f6be1327035..13b758e7f45b8 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4692,10 +4692,15 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, } // Wrap the 'next' call in 'unsafe', if the for..in loop has that - // effect. - if (stmt->getUnsafeLoc().isValid()) { - nextCall = new (ctx) UnsafeExpr( - stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true); + // effect or if the loop is async (in which case the iterator variable + // is nonisolated(unsafe). + if (stmt->getUnsafeLoc().isValid() || + (isAsync && + ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { + SourceLoc loc = stmt->getUnsafeLoc(); + if (loc.isInvalid()) + loc = stmt->getForLoc(); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), /*implicit=*/true); } // The iterator type must conform to IteratorProtocol. diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 1d76f6186b9d0..e4d51eac1a1ce 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -2439,7 +2439,7 @@ class ApplyClassifier { return ShouldRecurse; } ShouldRecurse_t checkUnsafe(UnsafeExpr *E) { - return E->isImplicit() ? ShouldRecurse : ShouldNotRecurse; + return ShouldNotRecurse; } ShouldRecurse_t checkTry(TryExpr *E) { return ShouldRecurse; @@ -4573,10 +4573,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker diagnoseUnsafeUse(unsafeUse); } } - } else if (S->getUnsafeLoc().isValid()) { - // Extraneous "unsafe" on the sequence. - Ctx.Diags.diagnose(S->getUnsafeLoc(), diag::no_unsafe_in_unsafe_for) - .fixItRemove(S->getUnsafeLoc()); } return ShouldRecurse; @@ -4636,7 +4632,10 @@ class CheckEffectsCoverage : public EffectsHandlingWalker return; } - Ctx.Diags.diagnose(E->getUnsafeLoc(), diag::no_unsafe_in_unsafe) + Ctx.Diags.diagnose(E->getUnsafeLoc(), + forEachNextCallExprs.contains(E) + ? diag::no_unsafe_in_unsafe_for + : diag::no_unsafe_in_unsafe) .fixItRemove(E->getUnsafeLoc()); } diff --git a/test/Unsafe/safe.swift b/test/Unsafe/safe.swift index 5131dfbfee172..2af1ca48ef309 100644 --- a/test/Unsafe/safe.swift +++ b/test/Unsafe/safe.swift @@ -98,6 +98,8 @@ func testUnsafeAsSequenceForEach() { for _ in unsafe uas { } // expected-warning{{for-in loop uses unsafe constructs but is not marked with 'unsafe'}}{{documentation-file=strict-memory-safety}}{{7-7=unsafe }} for unsafe _ in unsafe uas { } // okay + + for unsafe _ in [1, 2, 3] { } // expected-warning{{no unsafe operations occur within 'unsafe' for-in loop}} } func testForInUnsafeAmbiguity(_ integers: [Int]) { diff --git a/test/Unsafe/unsafe_concurrency.swift b/test/Unsafe/unsafe_concurrency.swift index 00dea6cf54d3e..895c62cdd816f 100644 --- a/test/Unsafe/unsafe_concurrency.swift +++ b/test/Unsafe/unsafe_concurrency.swift @@ -55,3 +55,10 @@ open class SyntaxVisitor { open class SyntaxAnyVisitor: SyntaxVisitor { override open func visit(_ token: TokenSyntax) { } } + +@available(SwiftStdlib 5.1, *) +func testMemorySafetyWithForLoop() async { + let (stream, continuation) = AsyncStream.makeStream() + for await _ in stream {} + _ = continuation +} From 8c34d5787426bd2f6ffcf98bc452e249be5d1885 Mon Sep 17 00:00:00 2001 From: Doug Gregor Date: Thu, 10 Jul 2025 08:48:37 -0700 Subject: [PATCH 2/2] [Strict memory safety] Eliminate spurious warnings with synthesized Codable When synthesizing code for Codable conformances involving unsafe types, make sure to wrap the resulting expressions in "unsafe" when strict memory safety is enabled. Tweak the warning-emission logic to suppress warnings about spurious "unsafe" expressions when the compiler generated the "unsafe" itself, so we don't spam the developer with warnings they can't fix. Also make the checking for other suppression considerations safe when there are no source locations, eliminating a potential assertion. Fixes rdar://153665692. --- lib/Sema/CSGen.cpp | 3 +- .../DerivedConformanceCodable.cpp | 48 ++++++++++++------- lib/Sema/TypeCheckEffects.cpp | 23 +++++---- lib/Sema/TypeCheckPattern.cpp | 2 +- test/Unsafe/codable_synthesis.swift | 26 ++++++++++ test/Unsafe/hashable_synthesis.swift | 18 +++++++ 6 files changed, 92 insertions(+), 28 deletions(-) create mode 100644 test/Unsafe/codable_synthesis.swift create mode 100644 test/Unsafe/hashable_synthesis.swift diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 13b758e7f45b8..7ffa9cb0fb0ec 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4698,9 +4698,10 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, (isAsync && ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { SourceLoc loc = stmt->getUnsafeLoc(); + bool implicit = stmt->getUnsafeLoc().isInvalid(); if (loc.isInvalid()) loc = stmt->getForLoc(); - nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), /*implicit=*/true); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); } // The iterator type must conform to IteratorProtocol. diff --git a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp index 8d9b5bd022931..1b88527d1dc61 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp @@ -772,10 +772,20 @@ lookupVarDeclForCodingKeysCase(DeclContext *conformanceDC, llvm_unreachable("Should have found at least 1 var decl"); } -static TryExpr *createEncodeCall(ASTContext &C, Type codingKeysType, - EnumElementDecl *codingKey, - Expr *containerExpr, Expr *varExpr, - bool useIfPresentVariant) { +/// If strict memory safety checking is enabled, wrap the expression in an +/// implicit "unsafe". +static Expr *wrapInUnsafeIfNeeded(ASTContext &ctx, Expr *expr) { + if (ctx.LangOpts.hasFeature(Feature::StrictMemorySafety, + /*allowMigration=*/true)) + return UnsafeExpr::createImplicit(ctx, expr->getStartLoc(), expr); + + return expr; +} + +static Expr *createEncodeCall(ASTContext &C, Type codingKeysType, + EnumElementDecl *codingKey, + Expr *containerExpr, Expr *varExpr, + bool useIfPresentVariant) { // CodingKeys.x auto *metaTyRef = TypeExpr::createImplicit(codingKeysType, C); auto *keyExpr = new (C) MemberRefExpr(metaTyRef, SourceLoc(), codingKey, @@ -794,7 +804,7 @@ static TryExpr *createEncodeCall(ASTContext &C, Type codingKeysType, // try container.encode(x, forKey: CodingKeys.x) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - return tryExpr; + return wrapInUnsafeIfNeeded(C, tryExpr); } /// Synthesizes the body for `func encode(to encoder: Encoder) throws`. @@ -929,7 +939,7 @@ deriveBodyEncodable_encode(AbstractFunctionDecl *encodeDecl, void *) { // try super.encode(to: container.superEncoder()) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(tryExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, tryExpr)); } auto *body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc(), @@ -1112,8 +1122,10 @@ deriveBodyEncodable_enum_encode(AbstractFunctionDecl *encodeDecl, void *) { // generate: switch self { } auto enumRef = - new (C) DeclRefExpr(ConcreteDeclRef(selfRef), DeclNameLoc(), - /*implicit*/ true, AccessSemantics::Ordinary); + wrapInUnsafeIfNeeded( + C, + new (C) DeclRefExpr(ConcreteDeclRef(selfRef), DeclNameLoc(), + /*implicit*/ true, AccessSemantics::Ordinary)); auto switchStmt = createEnumSwitch( C, funcDC, enumRef, enumDecl, codingKeysEnum, /*createSubpattern*/ true, @@ -1276,11 +1288,11 @@ static FuncDecl *deriveEncodable_encode(DerivedConformance &derived) { return encodeDecl; } -static TryExpr *createDecodeCall(ASTContext &C, Type resultType, - Type codingKeysType, - EnumElementDecl *codingKey, - Expr *containerExpr, - bool useIfPresentVariant) { +static Expr *createDecodeCall(ASTContext &C, Type resultType, + Type codingKeysType, + EnumElementDecl *codingKey, + Expr *containerExpr, + bool useIfPresentVariant) { auto methodName = useIfPresentVariant ? C.Id_decodeIfPresent : C.Id_decode; // Type.self @@ -1470,7 +1482,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { varDecl->getName()); auto *assignExpr = new (C) AssignExpr(varExpr, SourceLoc(), tryExpr, /*Implicit=*/true); - statements.push_back(assignExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } } @@ -1506,7 +1518,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { // try super.init(from: container.superDecoder()) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(tryExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, tryExpr)); } else { // The explicit constructor name is a compound name taking no arguments. DeclName initName(C, DeclBaseName::createConstructor(), @@ -1538,7 +1550,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { callExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(callExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, callExpr)); } } } @@ -1827,7 +1839,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) { new (C) AssignExpr(selfRef, SourceLoc(), selfCaseExpr, /*Implicit=*/true); - caseStatements.push_back(assignExpr); + caseStatements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } else { // Foo.bar(x:) SmallVector scratch; @@ -1845,7 +1857,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) { new (C) AssignExpr(selfRef, SourceLoc(), caseCallExpr, /*Implicit=*/true); - caseStatements.push_back(assignExpr); + caseStatements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } auto body = diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index e4d51eac1a1ce..763baf7df3041 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -4614,18 +4614,25 @@ class CheckEffectsCoverage : public EffectsHandlingWalker } void diagnoseRedundantUnsafe(UnsafeExpr *E) const { - // Silence this warning in the expansion of the _SwiftifyImport macro. - // This is a hack because it's tricky to determine when to insert "unsafe". - unsigned bufferID = - Ctx.SourceMgr.findBufferContainingLoc(E->getUnsafeLoc()); - if (auto sourceInfo = Ctx.SourceMgr.getGeneratedSourceInfo(bufferID)) { - if (sourceInfo->macroName == "_SwiftifyImport") - return; + // Ignore implicitly-generated "unsafe" expressions; they're allowed to be + // overly conservative. + if (E->isImplicit()) + return; + + SourceLoc loc = E->getUnsafeLoc(); + if (loc.isValid()) { + // Silence this warning in the expansion of the _SwiftifyImport macro. + // This is a hack because it's tricky to determine when to insert "unsafe". + unsigned bufferID = Ctx.SourceMgr.findBufferContainingLoc(loc); + if (auto sourceInfo = Ctx.SourceMgr.getGeneratedSourceInfo(bufferID)) { + if (sourceInfo->macroName == "_SwiftifyImport") + return; + } } if (auto *SVE = SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(E)) { // For an if/switch expression, produce a tailored warning. - Ctx.Diags.diagnose(E->getUnsafeLoc(), + Ctx.Diags.diagnose(loc, diag::effect_marker_on_single_value_stmt, "unsafe", SVE->getStmt()->getKind()) .highlight(E->getUnsafeLoc()); diff --git a/lib/Sema/TypeCheckPattern.cpp b/lib/Sema/TypeCheckPattern.cpp index f66ed36fb8b37..1a0ea08c01a3c 100644 --- a/lib/Sema/TypeCheckPattern.cpp +++ b/lib/Sema/TypeCheckPattern.cpp @@ -786,7 +786,7 @@ ExprPatternMatchRequest::evaluate(Evaluator &evaluator, // If there was an "unsafe", put it outside of the match call. if (unsafeExpr) { - matchCall = UnsafeExpr::createImplicit(ctx, unsafeExpr->getLoc(), matchCall); + matchCall = new (ctx) UnsafeExpr(unsafeExpr->getLoc(), matchCall); } return {matchVar, matchCall}; diff --git a/test/Unsafe/codable_synthesis.swift b/test/Unsafe/codable_synthesis.swift new file mode 100644 index 0000000000000..282db90cd2c09 --- /dev/null +++ b/test/Unsafe/codable_synthesis.swift @@ -0,0 +1,26 @@ +// RUN: %target-typecheck-verify-swift -strict-memory-safety + +@unsafe public struct UnsafeStruct: Codable { + public var string: String +} + + +@unsafe public enum UnsafeEnum: Codable { +case something(Int) +} + +@safe public struct SafeStruct: Codable { + public var us: UnsafeStruct +} + +@safe public enum SafeEnum: Codable { +case something(UnsafeEnum) +} + +@unsafe public class C1: Codable { + public var string: String = "" +} + +@unsafe public class C2: C1 { + public var otherString: String = "" +} diff --git a/test/Unsafe/hashable_synthesis.swift b/test/Unsafe/hashable_synthesis.swift new file mode 100644 index 0000000000000..188f433462aad --- /dev/null +++ b/test/Unsafe/hashable_synthesis.swift @@ -0,0 +1,18 @@ +// RUN: %target-typecheck-verify-swift -strict-memory-safety + +@unsafe public struct UnsafeStruct: Hashable { + public var string: String +} + + +@unsafe public enum UnsafeEnum: Hashable { +case something(Int) +} + +@safe public struct SafeStruct: Hashable { + public var us: UnsafeStruct +} + +@safe public enum SafeEnum: Hashable { +case something(UnsafeEnum) +}