diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index a7f6be1327035..7ffa9cb0fb0ec 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -4692,10 +4692,16 @@ generateForEachStmtConstraints(ConstraintSystem &cs, DeclContext *dc, } // Wrap the 'next' call in 'unsafe', if the for..in loop has that - // effect. - if (stmt->getUnsafeLoc().isValid()) { - nextCall = new (ctx) UnsafeExpr( - stmt->getUnsafeLoc(), nextCall, Type(), /*implicit=*/true); + // effect or if the loop is async (in which case the iterator variable + // is nonisolated(unsafe). + if (stmt->getUnsafeLoc().isValid() || + (isAsync && + ctx.LangOpts.StrictConcurrencyLevel == StrictConcurrency::Complete)) { + SourceLoc loc = stmt->getUnsafeLoc(); + bool implicit = stmt->getUnsafeLoc().isInvalid(); + if (loc.isInvalid()) + loc = stmt->getForLoc(); + nextCall = new (ctx) UnsafeExpr(loc, nextCall, Type(), implicit); } // The iterator type must conform to IteratorProtocol. diff --git a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp index 8d9b5bd022931..1b88527d1dc61 100644 --- a/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp +++ b/lib/Sema/DerivedConformance/DerivedConformanceCodable.cpp @@ -772,10 +772,20 @@ lookupVarDeclForCodingKeysCase(DeclContext *conformanceDC, llvm_unreachable("Should have found at least 1 var decl"); } -static TryExpr *createEncodeCall(ASTContext &C, Type codingKeysType, - EnumElementDecl *codingKey, - Expr *containerExpr, Expr *varExpr, - bool useIfPresentVariant) { +/// If strict memory safety checking is enabled, wrap the expression in an +/// implicit "unsafe". +static Expr *wrapInUnsafeIfNeeded(ASTContext &ctx, Expr *expr) { + if (ctx.LangOpts.hasFeature(Feature::StrictMemorySafety, + /*allowMigration=*/true)) + return UnsafeExpr::createImplicit(ctx, expr->getStartLoc(), expr); + + return expr; +} + +static Expr *createEncodeCall(ASTContext &C, Type codingKeysType, + EnumElementDecl *codingKey, + Expr *containerExpr, Expr *varExpr, + bool useIfPresentVariant) { // CodingKeys.x auto *metaTyRef = TypeExpr::createImplicit(codingKeysType, C); auto *keyExpr = new (C) MemberRefExpr(metaTyRef, SourceLoc(), codingKey, @@ -794,7 +804,7 @@ static TryExpr *createEncodeCall(ASTContext &C, Type codingKeysType, // try container.encode(x, forKey: CodingKeys.x) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - return tryExpr; + return wrapInUnsafeIfNeeded(C, tryExpr); } /// Synthesizes the body for `func encode(to encoder: Encoder) throws`. @@ -929,7 +939,7 @@ deriveBodyEncodable_encode(AbstractFunctionDecl *encodeDecl, void *) { // try super.encode(to: container.superEncoder()) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(tryExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, tryExpr)); } auto *body = BraceStmt::create(C, SourceLoc(), statements, SourceLoc(), @@ -1112,8 +1122,10 @@ deriveBodyEncodable_enum_encode(AbstractFunctionDecl *encodeDecl, void *) { // generate: switch self { } auto enumRef = - new (C) DeclRefExpr(ConcreteDeclRef(selfRef), DeclNameLoc(), - /*implicit*/ true, AccessSemantics::Ordinary); + wrapInUnsafeIfNeeded( + C, + new (C) DeclRefExpr(ConcreteDeclRef(selfRef), DeclNameLoc(), + /*implicit*/ true, AccessSemantics::Ordinary)); auto switchStmt = createEnumSwitch( C, funcDC, enumRef, enumDecl, codingKeysEnum, /*createSubpattern*/ true, @@ -1276,11 +1288,11 @@ static FuncDecl *deriveEncodable_encode(DerivedConformance &derived) { return encodeDecl; } -static TryExpr *createDecodeCall(ASTContext &C, Type resultType, - Type codingKeysType, - EnumElementDecl *codingKey, - Expr *containerExpr, - bool useIfPresentVariant) { +static Expr *createDecodeCall(ASTContext &C, Type resultType, + Type codingKeysType, + EnumElementDecl *codingKey, + Expr *containerExpr, + bool useIfPresentVariant) { auto methodName = useIfPresentVariant ? C.Id_decodeIfPresent : C.Id_decode; // Type.self @@ -1470,7 +1482,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { varDecl->getName()); auto *assignExpr = new (C) AssignExpr(varExpr, SourceLoc(), tryExpr, /*Implicit=*/true); - statements.push_back(assignExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } } @@ -1506,7 +1518,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { // try super.init(from: container.superDecoder()) auto *tryExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(tryExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, tryExpr)); } else { // The explicit constructor name is a compound name taking no arguments. DeclName initName(C, DeclBaseName::createConstructor(), @@ -1538,7 +1550,7 @@ deriveBodyDecodable_init(AbstractFunctionDecl *initDecl, void *) { callExpr = new (C) TryExpr(SourceLoc(), callExpr, Type(), /*Implicit=*/true); - statements.push_back(callExpr); + statements.push_back(wrapInUnsafeIfNeeded(C, callExpr)); } } } @@ -1827,7 +1839,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) { new (C) AssignExpr(selfRef, SourceLoc(), selfCaseExpr, /*Implicit=*/true); - caseStatements.push_back(assignExpr); + caseStatements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } else { // Foo.bar(x:) SmallVector scratch; @@ -1845,7 +1857,7 @@ deriveBodyDecodable_enum_init(AbstractFunctionDecl *initDecl, void *) { new (C) AssignExpr(selfRef, SourceLoc(), caseCallExpr, /*Implicit=*/true); - caseStatements.push_back(assignExpr); + caseStatements.push_back(wrapInUnsafeIfNeeded(C, assignExpr)); } auto body = diff --git a/lib/Sema/TypeCheckEffects.cpp b/lib/Sema/TypeCheckEffects.cpp index 1d76f6186b9d0..763baf7df3041 100644 --- a/lib/Sema/TypeCheckEffects.cpp +++ b/lib/Sema/TypeCheckEffects.cpp @@ -2439,7 +2439,7 @@ class ApplyClassifier { return ShouldRecurse; } ShouldRecurse_t checkUnsafe(UnsafeExpr *E) { - return E->isImplicit() ? ShouldRecurse : ShouldNotRecurse; + return ShouldNotRecurse; } ShouldRecurse_t checkTry(TryExpr *E) { return ShouldRecurse; @@ -4573,10 +4573,6 @@ class CheckEffectsCoverage : public EffectsHandlingWalker diagnoseUnsafeUse(unsafeUse); } } - } else if (S->getUnsafeLoc().isValid()) { - // Extraneous "unsafe" on the sequence. - Ctx.Diags.diagnose(S->getUnsafeLoc(), diag::no_unsafe_in_unsafe_for) - .fixItRemove(S->getUnsafeLoc()); } return ShouldRecurse; @@ -4618,25 +4614,35 @@ class CheckEffectsCoverage : public EffectsHandlingWalker } void diagnoseRedundantUnsafe(UnsafeExpr *E) const { - // Silence this warning in the expansion of the _SwiftifyImport macro. - // This is a hack because it's tricky to determine when to insert "unsafe". - unsigned bufferID = - Ctx.SourceMgr.findBufferContainingLoc(E->getUnsafeLoc()); - if (auto sourceInfo = Ctx.SourceMgr.getGeneratedSourceInfo(bufferID)) { - if (sourceInfo->macroName == "_SwiftifyImport") - return; + // Ignore implicitly-generated "unsafe" expressions; they're allowed to be + // overly conservative. + if (E->isImplicit()) + return; + + SourceLoc loc = E->getUnsafeLoc(); + if (loc.isValid()) { + // Silence this warning in the expansion of the _SwiftifyImport macro. + // This is a hack because it's tricky to determine when to insert "unsafe". + unsigned bufferID = Ctx.SourceMgr.findBufferContainingLoc(loc); + if (auto sourceInfo = Ctx.SourceMgr.getGeneratedSourceInfo(bufferID)) { + if (sourceInfo->macroName == "_SwiftifyImport") + return; + } } if (auto *SVE = SingleValueStmtExpr::tryDigOutSingleValueStmtExpr(E)) { // For an if/switch expression, produce a tailored warning. - Ctx.Diags.diagnose(E->getUnsafeLoc(), + Ctx.Diags.diagnose(loc, diag::effect_marker_on_single_value_stmt, "unsafe", SVE->getStmt()->getKind()) .highlight(E->getUnsafeLoc()); return; } - Ctx.Diags.diagnose(E->getUnsafeLoc(), diag::no_unsafe_in_unsafe) + Ctx.Diags.diagnose(E->getUnsafeLoc(), + forEachNextCallExprs.contains(E) + ? diag::no_unsafe_in_unsafe_for + : diag::no_unsafe_in_unsafe) .fixItRemove(E->getUnsafeLoc()); } diff --git a/lib/Sema/TypeCheckPattern.cpp b/lib/Sema/TypeCheckPattern.cpp index f66ed36fb8b37..1a0ea08c01a3c 100644 --- a/lib/Sema/TypeCheckPattern.cpp +++ b/lib/Sema/TypeCheckPattern.cpp @@ -786,7 +786,7 @@ ExprPatternMatchRequest::evaluate(Evaluator &evaluator, // If there was an "unsafe", put it outside of the match call. if (unsafeExpr) { - matchCall = UnsafeExpr::createImplicit(ctx, unsafeExpr->getLoc(), matchCall); + matchCall = new (ctx) UnsafeExpr(unsafeExpr->getLoc(), matchCall); } return {matchVar, matchCall}; diff --git a/test/Unsafe/codable_synthesis.swift b/test/Unsafe/codable_synthesis.swift new file mode 100644 index 0000000000000..282db90cd2c09 --- /dev/null +++ b/test/Unsafe/codable_synthesis.swift @@ -0,0 +1,26 @@ +// RUN: %target-typecheck-verify-swift -strict-memory-safety + +@unsafe public struct UnsafeStruct: Codable { + public var string: String +} + + +@unsafe public enum UnsafeEnum: Codable { +case something(Int) +} + +@safe public struct SafeStruct: Codable { + public var us: UnsafeStruct +} + +@safe public enum SafeEnum: Codable { +case something(UnsafeEnum) +} + +@unsafe public class C1: Codable { + public var string: String = "" +} + +@unsafe public class C2: C1 { + public var otherString: String = "" +} diff --git a/test/Unsafe/hashable_synthesis.swift b/test/Unsafe/hashable_synthesis.swift new file mode 100644 index 0000000000000..188f433462aad --- /dev/null +++ b/test/Unsafe/hashable_synthesis.swift @@ -0,0 +1,18 @@ +// RUN: %target-typecheck-verify-swift -strict-memory-safety + +@unsafe public struct UnsafeStruct: Hashable { + public var string: String +} + + +@unsafe public enum UnsafeEnum: Hashable { +case something(Int) +} + +@safe public struct SafeStruct: Hashable { + public var us: UnsafeStruct +} + +@safe public enum SafeEnum: Hashable { +case something(UnsafeEnum) +} diff --git a/test/Unsafe/safe.swift b/test/Unsafe/safe.swift index 5131dfbfee172..2af1ca48ef309 100644 --- a/test/Unsafe/safe.swift +++ b/test/Unsafe/safe.swift @@ -98,6 +98,8 @@ func testUnsafeAsSequenceForEach() { for _ in unsafe uas { } // expected-warning{{for-in loop uses unsafe constructs but is not marked with 'unsafe'}}{{documentation-file=strict-memory-safety}}{{7-7=unsafe }} for unsafe _ in unsafe uas { } // okay + + for unsafe _ in [1, 2, 3] { } // expected-warning{{no unsafe operations occur within 'unsafe' for-in loop}} } func testForInUnsafeAmbiguity(_ integers: [Int]) { diff --git a/test/Unsafe/unsafe_concurrency.swift b/test/Unsafe/unsafe_concurrency.swift index 00dea6cf54d3e..895c62cdd816f 100644 --- a/test/Unsafe/unsafe_concurrency.swift +++ b/test/Unsafe/unsafe_concurrency.swift @@ -55,3 +55,10 @@ open class SyntaxVisitor { open class SyntaxAnyVisitor: SyntaxVisitor { override open func visit(_ token: TokenSyntax) { } } + +@available(SwiftStdlib 5.1, *) +func testMemorySafetyWithForLoop() async { + let (stream, continuation) = AsyncStream.makeStream() + for await _ in stream {} + _ = continuation +}