From e70fbbc73877b68318a474544ce68eefb53f310a Mon Sep 17 00:00:00 2001 From: Kavon Farvardin Date: Thu, 11 Aug 2022 18:04:55 -0700 Subject: [PATCH 1/3] [ConstraintSystem] correct the @preconcurrency adjustment of var references We intended to introduce AST conversions that strip concurrency attributes off of types associated with `@preconcurrency` decls. But for VarDecl references, we stripped it too early, leading to things like a MemberVarDecl that doesn't have `@Sendable` in its result type, but the VarDecl it refers to does have it. That caused crashes in SIL where types didn't match up. This patch fixes things by delaying the stripping until the right point. resolves rdar://98018067 --- include/swift/Sema/ConstraintSystem.h | 4 +++- lib/Sema/CSApply.cpp | 20 ++++++++++++++-- lib/Sema/ConstraintSystem.cpp | 33 +++++++++++++++++++++------ 3 files changed, 47 insertions(+), 10 deletions(-) diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index 12457af2c1849..07fe88459d7a8 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -4764,7 +4764,8 @@ class ConstraintSystem { Type getUnopenedTypeOfReference(VarDecl *value, Type baseType, DeclContext *UseDC, ConstraintLocator *memberLocator = nullptr, - bool wantInterfaceType = false); + bool wantInterfaceType = false, + bool adjustForPreconcurrency = true); /// Return the type-of-reference of the given value. /// @@ -4786,6 +4787,7 @@ class ConstraintSystem { llvm::function_ref getType, ConstraintLocator *memberLocator = nullptr, bool wantInterfaceType = false, + bool adjustForPreconcurrency = true, llvm::function_ref getClosureType = [](const AbstractClosureExpr *) { return Type(); diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 32190a7da0cdf..09f7c4560843d 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -919,6 +919,15 @@ namespace { auto &context = cs.getASTContext(); + // turn LValues into RValues first + if (openedType->hasLValueType()) { + assert(adjustedOpenedType->hasLValueType() && "lvalue-ness mismatch?"); + return adjustTypeForDeclReference(cs.coerceToRValue(expr), + openedType->getRValueType(), + adjustedOpenedType->getRValueType(), + getNewType); + } + // If we have an optional type, wrap it up in a monadic '?' and recurse. if (Type objectType = openedType->getOptionalObjectType()) { Type adjustedRefType = getNewType(adjustedOpenedType); @@ -1669,10 +1678,17 @@ namespace { adjustedRefTy = adjustedRefTy->replaceCovariantResultType( containerTy, 1); } - cs.setType(memberRefExpr, refTy->castTo()->getResult()); + + // \returns result of the given function type + auto resultType = [](Type fnTy) -> Type { + return fnTy->castTo()->getResult(); + }; + + cs.setType(memberRefExpr, resultType(refTy)); Expr *result = memberRefExpr; - result = adjustTypeForDeclReference(result, refTy, adjustedRefTy); + result = adjustTypeForDeclReference(result, resultType(refTy), + resultType(adjustedRefTy)); closeExistentials(result, locator); // If the property is of dynamic 'Self' type, wrap an implicit diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index a675f32f7849d..bba808a464926 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -1259,7 +1259,8 @@ ClosureIsolatedByPreconcurrency::operator()(const ClosureExpr *expr) const { Type ConstraintSystem::getUnopenedTypeOfReference( VarDecl *value, Type baseType, DeclContext *UseDC, - ConstraintLocator *memberLocator, bool wantInterfaceType) { + ConstraintLocator *memberLocator, bool wantInterfaceType, + bool adjustForPreconcurrency) { return ConstraintSystem::getUnopenedTypeOfReference( value, baseType, UseDC, [&](VarDecl *var) -> Type { @@ -1272,22 +1273,25 @@ Type ConstraintSystem::getUnopenedTypeOfReference( return wantInterfaceType ? var->getInterfaceType() : var->getType(); }, - memberLocator, wantInterfaceType, GetClosureType{*this}, + memberLocator, wantInterfaceType, adjustForPreconcurrency, + GetClosureType{*this}, ClosureIsolatedByPreconcurrency{*this}); } Type ConstraintSystem::getUnopenedTypeOfReference( VarDecl *value, Type baseType, DeclContext *UseDC, llvm::function_ref getType, - ConstraintLocator *memberLocator, bool wantInterfaceType, + ConstraintLocator *memberLocator, + bool wantInterfaceType, bool adjustForPreconcurrency, llvm::function_ref getClosureType, llvm::function_ref isolatedByPreconcurrency) { Type requestedType = getType(value)->getWithoutSpecifierType()->getReferenceStorageReferent(); - // Adjust the type for concurrency. - requestedType = adjustVarTypeForConcurrency( - requestedType, value, UseDC, getClosureType, isolatedByPreconcurrency); + // Adjust the type for concurrency if requested. + if (adjustForPreconcurrency) + requestedType = adjustVarTypeForConcurrency( + requestedType, value, UseDC, getClosureType, isolatedByPreconcurrency); // If we're dealing with contextual types, and we referenced this type from // a different context, map the type. @@ -2309,9 +2313,14 @@ ConstraintSystem::getTypeOfMemberReference( FunctionType::ExtInfo info; refType = FunctionType::get(indices, elementTy, info); } else { + // Delay the adjustment for preconcurrency until after we've formed + // the function type for this kind of reference. Otherwise we will lose + // track of the adjustment in the formed function's return type. + refType = getUnopenedTypeOfReference(cast(value), baseTy, useDC, locator, - /*wantInterfaceType=*/true); + /*wantInterfaceType=*/true, + /*adjustForPreconcurrency=*/false); } auto selfTy = outerDC->getSelfInterfaceType(); @@ -2432,6 +2441,16 @@ ConstraintSystem::getTypeOfMemberReference( openedType = adjustFunctionTypeForConcurrency( origOpenedType->castTo(), subscript, useDC, /*numApplies=*/2, /*isMainDispatchQueue=*/false, replacements); + } else if (auto var = dyn_cast(value)) { + // Adjust the function's result type, since that's the Var's actual type. + auto origFnType = origOpenedType->castTo(); + + auto resultTy = adjustVarTypeForConcurrency( + origFnType->getResult(), var, useDC, GetClosureType{*this}, + ClosureIsolatedByPreconcurrency{*this}); + + openedType = FunctionType::get( + origFnType->getParams(), resultTy, origFnType->getExtInfo()); } // Compute the type of the reference. From 6c24bc57cbf318a5092c54943e1e21bbcdb7c537 Mon Sep 17 00:00:00 2001 From: Kavon Farvardin Date: Tue, 23 Aug 2022 18:02:57 -0700 Subject: [PATCH 2/3] [AST][SILGen] model ABI-safe casts of LValues We needed a way to describe an ABI-safe cast of an address representing an LValue to implement `@preconcurrency` and its injection of casts during accesses of members. This new AST node, `ABISafeConversionExpr` models what is essentially an `unchecked_addr_cast` in SIL when accessing the LVAlue. As of now I simply implemented it and the verification of the node for the concurrency needs to ensure that it's not misused by accident. If it finds use outside of that, feel free to update the verifier. --- include/swift/AST/Expr.h | 13 +++++++++++ include/swift/AST/ExprNodes.def | 1 + lib/AST/ASTDumper.cpp | 5 +++++ lib/AST/ASTPrinter.cpp | 3 +++ lib/AST/ASTVerifier.cpp | 39 +++++++++++++++++++++++++++++++++ lib/AST/Expr.cpp | 3 +++ lib/SILGen/LValue.h | 1 + lib/SILGen/SILGenExpr.cpp | 3 +++ lib/SILGen/SILGenLValue.cpp | 37 +++++++++++++++++++++++++++++++ lib/Sema/CSApply.cpp | 29 ++++++++++++------------ 10 files changed, 120 insertions(+), 14 deletions(-) diff --git a/include/swift/AST/Expr.h b/include/swift/AST/Expr.h index a5fa92b7ba9fc..bd20cd41110cc 100644 --- a/include/swift/AST/Expr.h +++ b/include/swift/AST/Expr.h @@ -3157,6 +3157,19 @@ class LoadExpr : public ImplicitConversionExpr { static bool classof(const Expr *E) { return E->getKind() == ExprKind::Load; } }; +/// ABISafeConversion - models a type conversion on an l-value that has no +/// material affect on the ABI of the type, while *preserving* the l-valueness +/// of the type. +class ABISafeConversionExpr : public ImplicitConversionExpr { +public: + ABISafeConversionExpr(Expr *subExpr, Type type) + : ImplicitConversionExpr(ExprKind::ABISafeConversion, subExpr, type) {} + + static bool classof(const Expr *E) { + return E->getKind() == ExprKind::ABISafeConversion; + } +}; + /// This is a conversion from an expression of UnresolvedType to an arbitrary /// other type, and from an arbitrary type to UnresolvedType. This node does /// not appear in valid code, only in code involving diagnostics. diff --git a/include/swift/AST/ExprNodes.def b/include/swift/AST/ExprNodes.def index a1833fca267f3..5475031f03bdf 100644 --- a/include/swift/AST/ExprNodes.def +++ b/include/swift/AST/ExprNodes.def @@ -152,6 +152,7 @@ ABSTRACT_EXPR(Apply, Expr) EXPR_RANGE(Apply, Call, ConstructorRefCall) ABSTRACT_EXPR(ImplicitConversion, Expr) EXPR(Load, ImplicitConversionExpr) + EXPR(ABISafeConversion, ImplicitConversionExpr) EXPR(DestructureTuple, ImplicitConversionExpr) EXPR(UnresolvedTypeConversion, ImplicitConversionExpr) EXPR(FunctionConversion, ImplicitConversionExpr) diff --git a/lib/AST/ASTDumper.cpp b/lib/AST/ASTDumper.cpp index fc0cd7595fe39..017c2cf64e881 100644 --- a/lib/AST/ASTDumper.cpp +++ b/lib/AST/ASTDumper.cpp @@ -2276,6 +2276,11 @@ class PrintExpr : public ExprVisitor { printRec(E->getSubExpr()); PrintWithColorRAII(OS, ParenthesisColor) << ')'; } + void visitABISafeConversionExpr(ABISafeConversionExpr *E) { + printCommon(E, "abi_safe_conversion_expr") << '\n'; + printRec(E->getSubExpr()); + PrintWithColorRAII(OS, ParenthesisColor) << ')'; + } void visitMetatypeConversionExpr(MetatypeConversionExpr *E) { printCommon(E, "metatype_conversion_expr") << '\n'; printRec(E->getSubExpr()); diff --git a/lib/AST/ASTPrinter.cpp b/lib/AST/ASTPrinter.cpp index 268dc6885666e..6e01bf741d864 100644 --- a/lib/AST/ASTPrinter.cpp +++ b/lib/AST/ASTPrinter.cpp @@ -4793,6 +4793,9 @@ void PrintAST::visitConstructorRefCallExpr(ConstructorRefCallExpr *expr) { } } +void PrintAST::visitABISafeConversionExpr(ABISafeConversionExpr *expr) { +} + void PrintAST::visitFunctionConversionExpr(FunctionConversionExpr *expr) { } diff --git a/lib/AST/ASTVerifier.cpp b/lib/AST/ASTVerifier.cpp index 587e6b135d648..fec82a6fe5ea1 100644 --- a/lib/AST/ASTVerifier.cpp +++ b/lib/AST/ASTVerifier.cpp @@ -240,6 +240,17 @@ class Verifier : public ASTWalker { pushScope(DC); } + /// Emit an error message and abort, optionally dumping the expression. + /// \param E if non-null, the expression to dump() followed by a new-line. + void error(llvm::StringRef msg, Expr *E = nullptr) { + Out << msg << "\n"; + if (E) { + E->dump(Out); + Out << "\n"; + } + abort(); + } + public: Verifier(ModuleDecl *M, DeclContext *DC) : Verifier(PointerUnion(M), DC) {} @@ -2306,6 +2317,34 @@ class Verifier : public ASTWalker { verifyCheckedBase(E); } + void verifyChecked(ABISafeConversionExpr *E) { + PrettyStackTraceExpr debugStack(Ctx, "verify ABISafeConversionExpr", E); + + auto toType = E->getType(); + auto fromType = E->getSubExpr()->getType(); + + if (!fromType->hasLValueType()) + error("conversion source must be an l-value", E); + + if (!toType->hasLValueType()) + error("conversion result must be an l-value", E); + + { + // At the moment, "ABI Safe" means concurrency features can be stripped. + // Since we don't know how deeply the stripping is happening, to verify + // in a fuzzy way, strip everything to see if they're the same type. + auto strippedFrom = fromType->getRValueType() + ->stripConcurrency(/*recurse*/true, + /*dropGlobalActor*/true); + auto strippedTo = toType->getRValueType() + ->stripConcurrency(/*recurse*/true, + /*dropGlobalActor*/true); + + if (!strippedFrom->isEqual(strippedTo)) + error("possibly non-ABI safe conversion", E); + } + } + void verifyChecked(ValueDecl *VD) { if (VD->getInterfaceType()->hasError()) { Out << "checked decl cannot have error type\n"; diff --git a/lib/AST/Expr.cpp b/lib/AST/Expr.cpp index ef12e1e8d1c1d..6f2898ca3de80 100644 --- a/lib/AST/Expr.cpp +++ b/lib/AST/Expr.cpp @@ -414,6 +414,7 @@ ConcreteDeclRef Expr::getReferencedDecl(bool stopAtParenExpr) const { PASS_THROUGH_REFERENCE(Load, getSubExpr); NO_REFERENCE(DestructureTuple); NO_REFERENCE(UnresolvedTypeConversion); + PASS_THROUGH_REFERENCE(ABISafeConversion, getSubExpr); PASS_THROUGH_REFERENCE(FunctionConversion, getSubExpr); PASS_THROUGH_REFERENCE(CovariantFunctionConversion, getSubExpr); PASS_THROUGH_REFERENCE(CovariantReturnConversion, getSubExpr); @@ -741,6 +742,7 @@ bool Expr::canAppendPostfixExpression(bool appendingPostfixOperator) const { return false; case ExprKind::Load: + case ExprKind::ABISafeConversion: case ExprKind::DestructureTuple: case ExprKind::UnresolvedTypeConversion: case ExprKind::FunctionConversion: @@ -914,6 +916,7 @@ bool Expr::isValidParentOfTypeExpr(Expr *typeExpr) const { case ExprKind::Load: case ExprKind::DestructureTuple: case ExprKind::UnresolvedTypeConversion: + case ExprKind::ABISafeConversion: case ExprKind::FunctionConversion: case ExprKind::CovariantFunctionConversion: case ExprKind::CovariantReturnConversion: diff --git a/lib/SILGen/LValue.h b/lib/SILGen/LValue.h index 13155beead466..9b16da86bedcb 100644 --- a/lib/SILGen/LValue.h +++ b/lib/SILGen/LValue.h @@ -105,6 +105,7 @@ class PathComponent { CoroutineAccessorKind, // coroutine accessor ValueKind, // random base pointer as an lvalue PhysicalKeyPathApplicationKind, // applying a key path + ABISafeConversionKind, // unchecked_addr_cast // Logical LValue kinds GetterSetterKind, // property or subscript getter/setter diff --git a/lib/SILGen/SILGenExpr.cpp b/lib/SILGen/SILGenExpr.cpp index f204f328a4554..58eb3644e7dad 100644 --- a/lib/SILGen/SILGenExpr.cpp +++ b/lib/SILGen/SILGenExpr.cpp @@ -451,6 +451,9 @@ namespace { RValue visitArchetypeToSuperExpr(ArchetypeToSuperExpr *E, SGFContext C); RValue visitUnresolvedTypeConversionExpr(UnresolvedTypeConversionExpr *E, SGFContext C); + RValue visitABISafeConversionExpr(ABISafeConversionExpr *E, SGFContext C) { + llvm_unreachable("cannot appear in rvalue"); + } RValue visitFunctionConversionExpr(FunctionConversionExpr *E, SGFContext C); RValue visitCovariantFunctionConversionExpr( diff --git a/lib/SILGen/SILGenLValue.cpp b/lib/SILGen/SILGenLValue.cpp index 1256bcd570789..7598cde92713a 100644 --- a/lib/SILGen/SILGenLValue.cpp +++ b/lib/SILGen/SILGenLValue.cpp @@ -327,6 +327,9 @@ class LLVM_LIBRARY_VISIBILITY SILGenLValue LValueOptions options); LValue visitMoveExpr(MoveExpr *e, SGFAccessKind accessKind, LValueOptions options); + LValue visitABISafeConversionExpr(ABISafeConversionExpr *e, + SGFAccessKind accessKind, + LValueOptions options); // Expressions that wrap lvalues @@ -2204,6 +2207,29 @@ namespace { OS.indent(indent) << "PhysicalKeyPathApplicationComponent\n"; } }; + + /// A physical component which performs an unchecked_addr_cast + class ABISafeConversionComponent final : public PhysicalPathComponent { + public: + ABISafeConversionComponent(LValueTypeData typeData) + : PhysicalPathComponent(typeData, ABISafeConversionKind, + /*actorIsolation=*/None) {} + + ManagedValue project(SILGenFunction &SGF, SILLocation loc, + ManagedValue base) && override { + auto toType = SGF.getLoweredType(getTypeData().SubstFormalType) + .getAddressType(); + + if (base.getType() == toType) + return base; // nothing to do + + return SGF.B.createUncheckedAddrCast(loc, base, toType); + } + + void dump(raw_ostream &OS, unsigned indent) const override { + OS.indent(indent) << "ABISafeConversionComponent\n"; + } + }; } // end anonymous namespace RValue @@ -3734,6 +3760,17 @@ LValue SILGenLValue::visitMoveExpr(MoveExpr *e, SGFAccessKind accessKind, toAddr->getType().getASTType()); } +LValue SILGenLValue::visitABISafeConversionExpr(ABISafeConversionExpr *e, + SGFAccessKind accessKind, + LValueOptions options) { + LValue lval = visitRec(e->getSubExpr(), accessKind, options); + auto typeData = getValueTypeData(SGF, accessKind, e); + + lval.add(typeData); + + return lval; +} + /// Emit an lvalue that refers to the given property. This is /// designed to work with ManagedValue 'base's that are either +0 or +1. LValue SILGenFunction::emitPropertyLValue(SILLocation loc, ManagedValue base, diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 09f7c4560843d..9708a01d98f23 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -919,13 +919,22 @@ namespace { auto &context = cs.getASTContext(); - // turn LValues into RValues first + // For an RValue function type, use a standard function conversion. + if (openedType->is()) { + expr = new (context) FunctionConversionExpr( + expr, getNewType(adjustedOpenedType)); + cs.cacheType(expr); + return expr; + } + + // For any kind of LValue, use an ABISafeConversion. if (openedType->hasLValueType()) { - assert(adjustedOpenedType->hasLValueType() && "lvalue-ness mismatch?"); - return adjustTypeForDeclReference(cs.coerceToRValue(expr), - openedType->getRValueType(), - adjustedOpenedType->getRValueType(), - getNewType); + assert(adjustedOpenedType->hasLValueType() && "lvalueness mismatch?"); + + expr = new (context) ABISafeConversionExpr( + expr, getNewType(adjustedOpenedType)); + cs.cacheType(expr); + return expr; } // If we have an optional type, wrap it up in a monadic '?' and recurse. @@ -944,14 +953,6 @@ namespace { return expr; } - // For a function type, perform a function conversion. - if (openedType->is()) { - expr = new (context) FunctionConversionExpr( - expr, getNewType(adjustedOpenedType)); - cs.cacheType(expr); - return expr; - } - assert(false && "Unhandled adjustment"); return expr; } From a4d661d3bbfa5777a2a7a6001e3ab1295455f736 Mon Sep 17 00:00:00 2001 From: Kavon Farvardin Date: Fri, 19 Aug 2022 15:36:24 -0700 Subject: [PATCH 3/3] Add tests for rdar://98018067 --- test/SILGen/Inputs/objc_preconcurrency.h | 9 ++++++++ test/SILGen/objc_preconcurrency.swift | 29 +++++++++++++++++++++++- 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 test/SILGen/Inputs/objc_preconcurrency.h diff --git a/test/SILGen/Inputs/objc_preconcurrency.h b/test/SILGen/Inputs/objc_preconcurrency.h new file mode 100644 index 0000000000000..2e91503a4f264 --- /dev/null +++ b/test/SILGen/Inputs/objc_preconcurrency.h @@ -0,0 +1,9 @@ +@import Foundation; + +#define SENDABLE __attribute__((__swift_attr__("@Sendable"))) + +SENDABLE +@interface NSTouchGrass : NSObject +@property (nullable, copy) void (SENDABLE ^cancellationHandler)(void); +@property (nonnull, copy) void (SENDABLE ^exceptionHandler)(void); +@end diff --git a/test/SILGen/objc_preconcurrency.swift b/test/SILGen/objc_preconcurrency.swift index a99ee36a6f96e..69fa8b9474703 100644 --- a/test/SILGen/objc_preconcurrency.swift +++ b/test/SILGen/objc_preconcurrency.swift @@ -1,9 +1,10 @@ -// RUN: %target-swift-emit-silgen -module-name objc_preconcurrency -sdk %S/Inputs -I %S/Inputs -enable-source-import %s -disable-objc-attr-requires-foundation-module | %FileCheck %s +// RUN: %target-swift-emit-silgen -module-name objc_preconcurrency -sdk %S/Inputs -I %S/Inputs -enable-source-import -import-objc-header %S/Inputs/objc_preconcurrency.h %s -disable-objc-attr-requires-foundation-module | %FileCheck %s // REQUIRES: objc_interop @objc protocol P { @preconcurrency @objc optional func f(_ completionHandler: @Sendable @escaping () -> Void) + @preconcurrency var sendyHandler: @Sendable () -> Void { get set } } // CHECK-LABEL: sil hidden [ossa] @$s19objc_preconcurrency19testDynamicDispatch1p17completionHandleryAA1P_p_yyctF @@ -16,3 +17,29 @@ func testDynamicDispatch(p: P, completionHandler: @escaping () -> Void) { // CHECK: bb{{[0-9]+}}(%{{[0-9]+}} : $@convention(objc_method) (@convention(block) @Sendable () -> (), @opened let _ = p.f } + +// CHECK-LABEL: sil hidden [ossa] @$s19objc_preconcurrency21testOptionalVarAccessyySo12NSTouchGrassCF +// CHECK: unchecked_addr_cast {{.*}} : $*Optional<@Sendable @callee_guaranteed () -> ()> to $*Optional<@callee_guaranteed () -> ()> +// CHECK: } // end sil function '$s19objc_preconcurrency21testOptionalVarAccessyySo12NSTouchGrassCF' +func testOptionalVarAccess(_ grass: NSTouchGrass) { + grass.cancellationHandler?() +} + +func modify(_ v: inout () -> Void) { + v = {} +} + +// CHECK-LABEL: sil hidden [ossa] @$s19objc_preconcurrency15testInoutAccessyySo12NSTouchGrassCF +// CHECK: unchecked_addr_cast {{.*}} : $*@Sendable @callee_guaranteed () -> () to $*@callee_guaranteed () -> () +// CHECK: } // end sil function '$s19objc_preconcurrency15testInoutAccessyySo12NSTouchGrassCF' +func testInoutAccess(_ grass: NSTouchGrass) { + modify(&grass.exceptionHandler) +} + + +// CHECK-LABEL: sil hidden [ossa] @$s19objc_preconcurrency21testProtocolVarAccess1pyAA1P_p_tF +// CHECK: unchecked_addr_cast {{.*}} : $*@Sendable @callee_guaranteed () -> () to $*@callee_guaranteed () -> () +// CHECK: } // end sil function '$s19objc_preconcurrency21testProtocolVarAccess1pyAA1P_p_tF' +func testProtocolVarAccess(p: P) { + modify(&p.sendyHandler) +}