diff --git a/include/swift/Sema/ConstraintSystem.h b/include/swift/Sema/ConstraintSystem.h index abb99fe9a1d21..3dfb878e3f890 100644 --- a/include/swift/Sema/ConstraintSystem.h +++ b/include/swift/Sema/ConstraintSystem.h @@ -64,6 +64,7 @@ namespace constraints { class ConstraintSystem; class SyntacticElementTarget; +struct PreparedOverload; } // end namespace constraints @@ -2211,6 +2212,8 @@ class ConstraintSystem { unsigned CountDisjunctions = 0; private: + bool PreparingOverload = false; + /// A constraint that has failed along the current solver path. /// Do not set directly, call \c recordFailedConstraint instead. Constraint *failedConstraint = nullptr; @@ -2752,6 +2755,7 @@ class ConstraintSystem { SolverState *solverState = nullptr; void recordChange(SolverTrail::Change change) { + ASSERT(!PreparingOverload); solverState->Trail.recordChange(change); } @@ -2920,15 +2924,15 @@ class ConstraintSystem { SolverTrail *getTrail() const { return solverState ? &solverState->Trail : nullptr; } - - /// Add a new type variable that was already created. - void addTypeVariable(TypeVariableType *typeVar); /// Add a constraint from the subscript base to the root of the key /// path literal to the constraint system. void addKeyPathApplicationRootConstraint(Type root, ConstraintLocatorBuilder locator); public: + /// Add a new type variable that was already created. + void addTypeVariable(TypeVariableType *typeVar); + /// Lookup for a member with the given name which is in the given base type. /// /// This routine caches the results of member lookups in the top constraint @@ -2949,7 +2953,9 @@ class ConstraintSystem { /// Create a new type variable. TypeVariableType *createTypeVariable(ConstraintLocator *locator, - unsigned options); + unsigned options, + PreparedOverload *preparedOverload + = nullptr); /// Retrieve the set of active type variables. ArrayRef getTypeVariables() const { @@ -3407,7 +3413,8 @@ class ConstraintSystem { /// Update OpenedExistentials and record a change in the trail. void recordOpenedExistentialType(ConstraintLocator *locator, - ExistentialArchetypeType *opened); + ExistentialArchetypeType *opened, + PreparedOverload *preparedOverload = nullptr); /// Retrieve the generic environment for the opened element of a given pack /// expansion, or \c nullptr if no environment was recorded yet. @@ -3614,7 +3621,8 @@ class ConstraintSystem { /// Log and record the application of the fix. Return true iff any /// subsequent solution would be worse than the best known solution. - bool recordFix(ConstraintFix *fix, unsigned impact = 1); + bool recordFix(ConstraintFix *fix, unsigned impact = 1, + PreparedOverload *preparedOverload = nullptr); void recordPotentialHole(TypeVariableType *typeVar); void recordAnyTypeVarAsPotentialHole(Type type); @@ -3689,12 +3697,14 @@ class ConstraintSystem { /// Add a constraint to the constraint system. void addConstraint(ConstraintKind kind, Type first, Type second, ConstraintLocatorBuilder locator, - bool isFavored = false); + bool isFavored = false, + PreparedOverload *preparedOverload = nullptr); /// Add a requirement as a constraint to the constraint system. void addConstraint(Requirement req, ConstraintLocatorBuilder locator, bool isFavored, - bool prohibitNonisolatedConformance); + bool prohibitNonisolatedConformance, + PreparedOverload *preparedOverload = nullptr); void addApplicationConstraint( FunctionType *appliedFn, Type calleeType, @@ -4308,7 +4318,8 @@ class ConstraintSystem { /// \returns The opened type. Type openUnboundGenericType(GenericTypeDecl *decl, Type parentTy, ConstraintLocatorBuilder locator, - bool isTypeResolution); + bool isTypeResolution, + PreparedOverload *preparedOverload = nullptr); /// Replace placeholder types with fresh type variables, and unbound generic /// types with bound generic types whose generic args are fresh type @@ -4318,7 +4329,9 @@ class ConstraintSystem { /// /// \returns The converted type. Type replaceInferableTypesWithTypeVars(Type type, - ConstraintLocatorBuilder locator); + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload + = nullptr); /// "Open" the given type by replacing any occurrences of generic /// parameter types and dependent member types with fresh type variables. @@ -4329,7 +4342,8 @@ class ConstraintSystem { /// /// \returns The opened type, or \c type if there are no archetypes in it. Type openType(Type type, ArrayRef replacements, - ConstraintLocatorBuilder locator); + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload); /// "Open" an opaque archetype type, similar to \c openType. Type openOpaqueType(OpaqueTypeArchetypeType *type, @@ -4345,11 +4359,14 @@ class ConstraintSystem { /// aforementioned variable via special constraints. Type openPackExpansionType(PackExpansionType *expansion, ArrayRef replacements, - ConstraintLocatorBuilder locator); + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload); /// Update OpenedPackExpansionTypes and record a change in the trail. void recordOpenedPackExpansionType(PackExpansionType *expansion, - TypeVariableType *expansionVar); + TypeVariableType *expansionVar, + PreparedOverload *preparedOverload + = nullptr); /// Undo the above change. void removeOpenedPackExpansionType(PackExpansionType *expansion) { @@ -4374,26 +4391,30 @@ class ConstraintSystem { FunctionType *openFunctionType(AnyFunctionType *funcType, ConstraintLocatorBuilder locator, SmallVectorImpl &replacements, - DeclContext *outerDC); + DeclContext *outerDC, + PreparedOverload *preparedOverload); /// Open the generic parameter list and its requirements, /// creating type variables for each of the type parameters. void openGeneric(DeclContext *outerDC, GenericSignature signature, ConstraintLocatorBuilder locator, - SmallVectorImpl &replacements); + SmallVectorImpl &replacements, + PreparedOverload *preparedOverload); /// Open the generic parameter list creating type variables for each of the /// type parameters. void openGenericParameters(DeclContext *outerDC, GenericSignature signature, SmallVectorImpl &replacements, - ConstraintLocatorBuilder locator); + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload); /// Open a generic parameter into a type variable and record /// it in \c replacements. TypeVariableType *openGenericParameter(GenericTypeParamType *parameter, - ConstraintLocatorBuilder locator); + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload); /// Given generic signature open its generic requirements, /// using substitution function, and record them in the @@ -4402,7 +4423,8 @@ class ConstraintSystem { GenericSignature signature, bool skipProtocolSelfConstraint, ConstraintLocatorBuilder locator, - llvm::function_ref subst); + llvm::function_ref subst, + PreparedOverload *preparedOverload); // Record the given requirement in the constraint system. void openGenericRequirement(DeclContext *outerDC, @@ -4411,17 +4433,20 @@ class ConstraintSystem { const Requirement &requirement, bool skipProtocolSelfConstraint, ConstraintLocatorBuilder locator, - llvm::function_ref subst); + llvm::function_ref subst, + PreparedOverload *preparedOverload); /// Update OpenedTypes and record a change in the trail. void recordOpenedType( - ConstraintLocator *locator, ArrayRef openedTypes); + ConstraintLocator *locator, ArrayRef openedTypes, + PreparedOverload *preparedOverload = nullptr); /// Record the set of opened types for the given locator. void recordOpenedTypes( ConstraintLocatorBuilder locator, - SmallVectorImpl &replacements, - bool fixmeAllowDuplicates=false); + const SmallVectorImpl &replacements, + PreparedOverload *preparedOverload = nullptr, + bool fixmeAllowDuplicates = false); /// Check whether the given type conforms to the given protocol and if /// so return a valid conformance reference. @@ -4432,7 +4457,8 @@ class ConstraintSystem { FunctionType *adjustFunctionTypeForConcurrency( FunctionType *fnType, Type baseType, ValueDecl *decl, DeclContext *dc, unsigned numApplies, bool isMainDispatchQueue, - ArrayRef replacements, ConstraintLocatorBuilder locator); + ArrayRef replacements, ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload); /// Retrieve the type of a reference to the given value declaration. /// @@ -4447,7 +4473,8 @@ class ConstraintSystem { ValueDecl *decl, FunctionRefInfo functionRefInfo, ConstraintLocatorBuilder locator, - DeclContext *useDC); + DeclContext *useDC, + PreparedOverload *preparedOverload); /// Return the type-of-reference of the given value. /// @@ -4488,7 +4515,8 @@ class ConstraintSystem { DeclReferenceType getTypeOfMemberReference( Type baseTy, ValueDecl *decl, DeclContext *useDC, bool isDynamicLookup, FunctionRefInfo functionRefInfo, ConstraintLocator *locator, - SmallVectorImpl *replacements = nullptr); + SmallVectorImpl *replacements = nullptr, + PreparedOverload *preparedOverload = nullptr); /// Retrieve a list of generic parameter types solver has "opened" (replaced /// with a type variable) at the given location. @@ -5307,13 +5335,20 @@ class ConstraintSystem { /// Matches a wrapped or projected value parameter type to its backing /// property wrapper type by applying the property wrapper. TypeMatchResult applyPropertyWrapperToParameter( - Type wrapperType, Type paramType, ParamDecl *param, Identifier argLabel, - ConstraintKind matchKind, ConstraintLocator *locator, - ConstraintLocator *calleeLocator); + Type wrapperType, + Type paramType, + ParamDecl *param, + Identifier argLabel, + ConstraintKind matchKind, + ConstraintLocator *locator, + ConstraintLocator *calleeLocator, + PreparedOverload *preparedOverload = nullptr); /// Used by applyPropertyWrapperToParameter() to update appliedPropertyWrappers /// and record a change in the trail. - void applyPropertyWrapper(Expr *anchor, AppliedPropertyWrapper applied); + void applyPropertyWrapper(Expr *anchor, + AppliedPropertyWrapper applied, + PreparedOverload *preparedOverload = nullptr); /// Undo the above change. void removePropertyWrapper(Expr *anchor); diff --git a/include/swift/Sema/PreparedOverload.h b/include/swift/Sema/PreparedOverload.h new file mode 100644 index 0000000000000..78962f8927081 --- /dev/null +++ b/include/swift/Sema/PreparedOverload.h @@ -0,0 +1,165 @@ +//===--- PreparedOverload.h - A Choice from an Overload Set ----*- C++ -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2025 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +#ifndef SWIFT_SEMA_PREPAREDOVERLOAD_H +#define SWIFT_SEMA_PREPAREDOVERLOAD_H + +#include "swift/AST/PropertyWrappers.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/SmallVector.h" + +namespace swift { + +class ExistentialArchetypeType; +class GenericTypeParamType; +class TypeVariableType; + +namespace constraints { + +class ConstraintLocatorBuilder; +class ConstraintSystem; + +/// Describes a dependent type that has been opened to a particular type +/// variable. +using OpenedType = std::pair; + +/// A "pre-cooked" representation of all type variables and constraints +/// that are generated as part of an overload choice. +struct PreparedOverload { + /// A change to be introduced into the constraint system when this + /// overload choice is chosen. + struct Change { + enum ChangeKind : unsigned { + /// A generic parameter was opened to a type variable. + AddedTypeVariable, + + /// A generic requirement was opened to a constraint. + AddedConstraint, + + /// A mapping of generic parameter types to type variables + /// was recorded. + OpenedTypes, + + /// An existential type was opened. + OpenedExistentialType, + + /// A pack expansion type was opened. + OpenedPackExpansionType, + + /// A property wrapper was applied to a parameter. + AppliedPropertyWrapper, + + /// A fix was recorded because a property wrapper application failed. + AddedFix + }; + + /// The kind of change. + ChangeKind Kind; + + union { + /// For ChangeKind::AddedTypeVariable. + TypeVariableType *TypeVar; + + /// For ChangeKind::AddedConstraint. + Constraint *TheConstraint; + + /// For ChangeKind::OpenedTypes. + struct { + const OpenedType *Data; + size_t Count; + } Replacements; + + /// For ChangeKind::OpenedExistentialType. + ExistentialArchetypeType *TheExistential; + + /// For ChangeKind::OpenedPackExpansionType. + struct { + PackExpansionType *TheExpansion; + TypeVariableType *TypeVar; + } PackExpansion; + + /// For ChangeKind::AppliedPropertyWrapper. + struct { + TypeBase *WrapperType; + PropertyWrapperInitKind InitKind; + } PropertyWrapper; + + /// For ChangeKind::Fix. + struct { + ConstraintFix *TheFix; + unsigned Impact; + } Fix; + }; + }; + + SmallVector Changes; + + void addedTypeVariable(TypeVariableType *typeVar) { + Change change; + change.Kind = Change::AddedTypeVariable; + change.TypeVar = typeVar; + Changes.push_back(change); + } + + void addedConstraint(Constraint *constraint) { + Change change; + change.Kind = Change::AddedConstraint; + change.TheConstraint = constraint; + Changes.push_back(change); + } + + void openedTypes(ArrayRef replacements) { + Change change; + change.Kind = Change::OpenedTypes; + change.Replacements.Data = replacements.data(); + change.Replacements.Count = replacements.size(); + Changes.push_back(change); + } + + void openedExistentialType(ExistentialArchetypeType *openedExistential) { + Change change; + change.Kind = Change::OpenedExistentialType; + change.TheExistential = openedExistential; + Changes.push_back(change); + } + + void openedPackExpansionType(PackExpansionType *packExpansion, + TypeVariableType *typeVar) { + Change change; + change.Kind = Change::OpenedPackExpansionType; + change.PackExpansion.TheExpansion = packExpansion; + change.PackExpansion.TypeVar = typeVar; + Changes.push_back(change); + } + + void appliedPropertyWrapper(AppliedPropertyWrapper wrapper) { + Change change; + change.Kind = Change::AppliedPropertyWrapper; + change.PropertyWrapper.WrapperType = wrapper.wrapperType.getPointer(); + change.PropertyWrapper.InitKind = wrapper.initKind; + Changes.push_back(change); + } + + void addedFix(ConstraintFix *fix, unsigned impact) { + Change change; + change.Kind = Change::AddedFix; + change.Fix.TheFix = fix; + change.Fix.Impact = impact; + Changes.push_back(change); + } + + void discharge(ConstraintSystem &cs, ConstraintLocatorBuilder locator) const; +}; + +} +} + +#endif diff --git a/lib/Sema/CSGen.cpp b/lib/Sema/CSGen.cpp index 4c67c74246905..634348ce756fa 100644 --- a/lib/Sema/CSGen.cpp +++ b/lib/Sema/CSGen.cpp @@ -32,6 +32,7 @@ #include "swift/Sema/ConstraintGraph.h" #include "swift/Sema/ConstraintSystem.h" #include "swift/Sema/IDETypeChecking.h" +#include "swift/Sema/PreparedOverload.h" #include "swift/Subsystems.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/SetVector.h" @@ -5117,7 +5118,16 @@ bool ConstraintSystem::generateConstraints(StmtCondition condition, } void ConstraintSystem::applyPropertyWrapper( - Expr *anchor, AppliedPropertyWrapper applied) { + Expr *anchor, AppliedPropertyWrapper applied, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + ASSERT(PreparingOverload); + preparedOverload->appliedPropertyWrapper(applied); + return; + } + + ASSERT(!PreparingOverload); + appliedPropertyWrappers[anchor].push_back(applied); if (solverState) @@ -5138,9 +5148,14 @@ void ConstraintSystem::removePropertyWrapper(Expr *anchor) { ConstraintSystem::TypeMatchResult ConstraintSystem::applyPropertyWrapperToParameter( - Type wrapperType, Type paramType, ParamDecl *param, Identifier argLabel, - ConstraintKind matchKind, ConstraintLocator *locator, - ConstraintLocator *calleeLocator) { + Type wrapperType, + Type paramType, + ParamDecl *param, + Identifier argLabel, + ConstraintKind matchKind, + ConstraintLocator *locator, + ConstraintLocator *calleeLocator, + PreparedOverload *preparedOverload) { Expr *anchor = getAsExpr(calleeLocator->getAnchor()); auto recordPropertyWrapperFix = [&](ConstraintFix *fix) -> TypeMatchResult { @@ -5149,7 +5164,7 @@ ConstraintSystem::applyPropertyWrapperToParameter( recordAnyTypeVarAsPotentialHole(paramType); - if (recordFix(fix)) + if (recordFix(fix, /*impact=*/1, preparedOverload)) return getTypeMatchFailure(locator); return getTypeMatchSuccess(); @@ -5176,21 +5191,33 @@ ConstraintSystem::applyPropertyWrapperToParameter( if (argLabel.hasDollarPrefix()) { Type projectionType = computeProjectedValueType(param, wrapperType); - addConstraint(matchKind, paramType, projectionType, locator); + addConstraint(matchKind, paramType, projectionType, locator, + /*isFavored=*/false, + preparedOverload); if (param->hasImplicitPropertyWrapper()) { auto wrappedValueType = getType(param->getPropertyWrapperWrappedValueVar()); - addConstraint(ConstraintKind::PropertyWrapper, projectionType, wrappedValueType, - getConstraintLocator(param)); + addConstraint(ConstraintKind::PropertyWrapper, + projectionType, wrappedValueType, + getConstraintLocator(param), + /*isFavored=*/false, + preparedOverload); setType(param->getPropertyWrapperProjectionVar(), projectionType); } - applyPropertyWrapper(anchor, { wrapperType, PropertyWrapperInitKind::ProjectedValue }); + applyPropertyWrapper(anchor, + { wrapperType, PropertyWrapperInitKind::ProjectedValue }, + preparedOverload); } else if (param->hasExternalPropertyWrapper()) { Type wrappedValueType = computeWrappedValueType(param, wrapperType); - addConstraint(matchKind, paramType, wrappedValueType, locator); + addConstraint(matchKind, paramType, wrappedValueType, + locator, + /*isFavored=*/false, + preparedOverload); setType(param->getPropertyWrapperWrappedValueVar(), wrappedValueType); - applyPropertyWrapper(anchor, { wrapperType, PropertyWrapperInitKind::WrappedValue }); + applyPropertyWrapper(anchor, + { wrapperType, PropertyWrapperInitKind::WrappedValue }, + preparedOverload); } else { return getTypeMatchFailure(locator); } diff --git a/lib/Sema/CSRanking.cpp b/lib/Sema/CSRanking.cpp index ff7a42cc1112f..b3e99eacc0aaa 100644 --- a/lib/Sema/CSRanking.cpp +++ b/lib/Sema/CSRanking.cpp @@ -423,7 +423,8 @@ static bool isProtocolExtensionAsSpecializedAs(DeclContext *dc1, // the second protocol extension. ConstraintSystem cs(dc1, std::nullopt); SmallVector replacements; - cs.openGeneric(dc2, sig2, ConstraintLocatorBuilder(nullptr), replacements); + cs.openGeneric(dc2, sig2, ConstraintLocatorBuilder(nullptr), replacements, + /*preparedOverload=*/nullptr); // Bind the 'Self' type from the first extension to the type parameter from // opening 'Self' of the second extension. @@ -581,13 +582,15 @@ bool CompareDeclSpecializationRequest::evaluate( SmallVectorImpl &replacements, ConstraintLocator *locator) -> Type { if (auto *funcType = type->getAs()) { - return cs.openFunctionType(funcType, locator, replacements, outerDC); + return cs.openFunctionType(funcType, locator, replacements, outerDC, + /*preparedOverload=*/nullptr); } cs.openGeneric(outerDC, innerDC->getGenericSignatureOfContext(), locator, - replacements); + replacements, /*preparedOverload=*/nullptr); - return cs.openType(type, replacements, locator); + return cs.openType(type, replacements, locator, + /*preparedOverload=*/nullptr); }; bool knownNonSubtype = false; diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index f511c4881c6d8..ba40039ae9ab3 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -40,6 +40,7 @@ #include "swift/Sema/CSFix.h" #include "swift/Sema/ConstraintSystem.h" #include "swift/Sema/IDETypeChecking.h" +#include "swift/Sema/PreparedOverload.h" #include "llvm/ADT/SetVector.h" #include "llvm/Support/Compiler.h" @@ -12046,8 +12047,10 @@ static Type getOpenedResultBuilderTypeFor(ConstraintSystem &cs, // Find the opened type for this callee and substitute in the type // parameters. auto substitutions = cs.getOpenedTypes(calleeLocator); - if (!substitutions.empty()) - builderType = cs.openType(builderType, substitutions, locator); + if (!substitutions.empty()) { + builderType = cs.openType(builderType, substitutions, locator, + /*preparedOverload=*/nullptr); + } assert(!builderType->hasTypeParameter()); } @@ -15400,7 +15403,16 @@ static bool isAugmentingFix(ConstraintFix *fix) { } } -bool ConstraintSystem::recordFix(ConstraintFix *fix, unsigned impact) { +bool ConstraintSystem::recordFix(ConstraintFix *fix, unsigned impact, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + ASSERT(PreparingOverload); + preparedOverload->addedFix(fix, impact); + return true; + } + + ASSERT(!PreparingOverload); + if (isDebugMode()) { auto &log = llvm::errs(); log.indent(solverState ? solverState->getCurrentIndent() : 0) @@ -16013,6 +16025,8 @@ ConstraintSystem::addConstraintImpl(ConstraintKind kind, Type first, Type second, ConstraintLocatorBuilder locator, bool isFavored) { + ASSERT(!PreparingOverload); + assert(first && "Missing first type"); assert(second && "Missing second type"); @@ -16265,7 +16279,8 @@ ConstraintSystem::addKeyPathConstraint( void ConstraintSystem::addConstraint(Requirement req, ConstraintLocatorBuilder locator, bool isFavored, - bool prohibitNonisolatedConformance) { + bool prohibitNonisolatedConformance, + PreparedOverload *preparedOverload) { bool conformsToAnyObject = false; std::optional kind; switch (req.getKind()) { @@ -16273,7 +16288,8 @@ void ConstraintSystem::addConstraint(Requirement req, auto type1 = req.getFirstType(); auto type2 = req.getSecondType(); - addConstraint(ConstraintKind::SameShape, type1, type2, locator); + addConstraint(ConstraintKind::SameShape, type1, type2, locator, + /*isFavored=*/false, preparedOverload); return; } @@ -16314,19 +16330,32 @@ void ConstraintSystem::addConstraint(Requirement req, auto firstType = req.getFirstType(); if (kind) { addConstraint(*kind, req.getFirstType(), req.getSecondType(), locator, - isFavored); + isFavored, preparedOverload); } if (conformsToAnyObject) { auto anyObject = getASTContext().getAnyObjectConstraint(); - addConstraint(ConstraintKind::ConformsTo, firstType, anyObject, locator); + addConstraint(ConstraintKind::ConformsTo, firstType, anyObject, locator, + /*isFavored=*/false, preparedOverload); } } void ConstraintSystem::addConstraint(ConstraintKind kind, Type first, Type second, ConstraintLocatorBuilder locator, - bool isFavored) { + bool isFavored, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + ASSERT(PreparingOverload); + auto c = Constraint::create(*this, kind, first, second, + getConstraintLocator(locator)); + if (isFavored) c->setFavored(); + preparedOverload->addedConstraint(c); + return; + } + + ASSERT(!PreparingOverload); + switch (addConstraintImpl(kind, first, second, locator, isFavored)) { case SolutionKind::Error: // Add a failing constraint, if needed. diff --git a/lib/Sema/CSSolver.cpp b/lib/Sema/CSSolver.cpp index 13efcb2406db1..b3f96ee3f8819 100644 --- a/lib/Sema/CSSolver.cpp +++ b/lib/Sema/CSSolver.cpp @@ -22,6 +22,7 @@ #include "swift/Basic/Defer.h" #include "swift/Sema/ConstraintGraph.h" #include "swift/Sema/ConstraintSystem.h" +#include "swift/Sema/PreparedOverload.h" #include "swift/Sema/SolutionResult.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -58,11 +59,17 @@ STATISTIC(LargestSolutionAttemptNumber, "# of the largest solution attempt"); TypeVariableType *ConstraintSystem::createTypeVariable( ConstraintLocator *locator, - unsigned options) { + unsigned options, + PreparedOverload *preparedOverload) { ++TotalNumTypeVariables; auto tv = TypeVariableType::getNew(getASTContext(), assignTypeVariableID(), locator, options); - addTypeVariable(tv); + if (preparedOverload) { + ASSERT(PreparingOverload); + preparedOverload->addedTypeVariable(tv); + } else { + addTypeVariable(tv); + } return tv; } diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index fec96cf5f8a02..177093469d121 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -38,6 +38,7 @@ #include "swift/Sema/CSFix.h" #include "swift/Sema/ConstraintGraph.h" #include "swift/Sema/IDETypeChecking.h" +#include "swift/Sema/PreparedOverload.h" #include "swift/Sema/SolutionResult.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" @@ -176,6 +177,8 @@ bool ConstraintSystem::hasFreeTypeVariables() { } void ConstraintSystem::addTypeVariable(TypeVariableType *typeVar) { + ASSERT(!PreparingOverload); + TypeVariables.insert(typeVar); // Notify the constraint graph. @@ -304,6 +307,8 @@ void ConstraintSystem::removeConversionRestriction( } void ConstraintSystem::addFix(ConstraintFix *fix) { + ASSERT(!PreparingOverload); + bool inserted = Fixes.insert(fix); ASSERT(inserted); @@ -897,7 +902,14 @@ ConstraintSystem::openAnyExistentialType(Type type, } void ConstraintSystem::recordOpenedExistentialType( - ConstraintLocator *locator, ExistentialArchetypeType *opened) { + ConstraintLocator *locator, + ExistentialArchetypeType *opened, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + preparedOverload->openedExistentialType(opened); + return; + } + bool inserted = OpenedExistentialTypes.insert({locator, opened}).second; ASSERT(inserted); diff --git a/lib/Sema/TypeCheckConstraints.cpp b/lib/Sema/TypeCheckConstraints.cpp index 6414950917ed0..56560b7a82feb 100644 --- a/lib/Sema/TypeCheckConstraints.cpp +++ b/lib/Sema/TypeCheckConstraints.cpp @@ -627,7 +627,7 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue, if (auto *typeVar = findParam(GP)) return typeVar; - auto *typeVar = cs.openGenericParameter(GP, locator); + auto *typeVar = cs.openGenericParameter(GP, locator, nullptr); genericParameters.emplace_back(GP, typeVar); return typeVar; @@ -720,9 +720,9 @@ Type TypeChecker::typeCheckParameterDefault(Expr *&defaultValue, cs.openGenericRequirement(DC->getParent(), signature, index, requirement, /*skipSelfProtocolConstraint=*/false, locator, [&](Type type) -> Type { - return cs.openType(type, genericParameters, - locator); - }); + return cs.openType(type, genericParameters, locator, + /*preparedOverload=*/nullptr); + }, /*preparedOverload=*/nullptr); }; auto diagnoseInvalidRequirement = [&](Requirement requirement) { diff --git a/lib/Sema/TypeCheckProtocol.cpp b/lib/Sema/TypeCheckProtocol.cpp index c9e09d4274138..679b3c282a523 100644 --- a/lib/Sema/TypeCheckProtocol.cpp +++ b/lib/Sema/TypeCheckProtocol.cpp @@ -1224,7 +1224,7 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache, openWitnessType = cs->getTypeOfReference( witness, FunctionRefInfo::doubleBaseNameApply(), witnessLocator, - /*useDC=*/nullptr) + /*useDC=*/nullptr, /*preparedOverload=*/nullptr) .adjustedReferenceType; } openWitnessType = openWitnessType->getRValueType(); @@ -1251,11 +1251,11 @@ swift::matchWitness(WitnessChecker::RequirementEnvironmentCache &reqEnvCache, reqThrownError = getThrownErrorType(reqASD); reqThrownError = cs->openType(reqThrownError, reqReplacements, - reqLocator); + reqLocator, /*preparedOverload=*/nullptr); witnessThrownError = getThrownErrorType(witnessASD); witnessThrownError = cs->openType(witnessThrownError, witnessReplacements, - witnessLocator); + witnessLocator, /*preparedOverload=*/nullptr); } return std::make_tuple(std::nullopt, reqType, openWitnessType, diff --git a/lib/Sema/TypeOfReference.cpp b/lib/Sema/TypeOfReference.cpp index 313da3e045a03..1d7d9c4c9e35b 100644 --- a/lib/Sema/TypeOfReference.cpp +++ b/lib/Sema/TypeOfReference.cpp @@ -29,8 +29,10 @@ #include "swift/AST/ProtocolConformance.h" #include "swift/AST/TypeTransform.h" #include "swift/Sema/ConstraintSystem.h" +#include "swift/Sema/PreparedOverload.h" #include "swift/Basic/Assertions.h" #include "swift/Basic/Statistic.h" +#include "swift/Basic/Defer.h" using namespace swift; using namespace constraints; @@ -43,18 +45,21 @@ using namespace inference; Type ConstraintSystem::openUnboundGenericType(GenericTypeDecl *decl, Type parentTy, ConstraintLocatorBuilder locator, - bool isTypeResolution) { + bool isTypeResolution, + PreparedOverload *preparedOverload) { if (parentTy) { - parentTy = replaceInferableTypesWithTypeVars(parentTy, locator); + parentTy = replaceInferableTypesWithTypeVars( + parentTy, locator, preparedOverload); } // Open up the generic type. SmallVector replacements; openGeneric(decl->getDeclContext(), decl->getGenericSignature(), locator, - replacements); + replacements, preparedOverload); // FIXME: Get rid of fixmeAllowDuplicates. - recordOpenedTypes(locator, replacements, /*fixmeAllowDuplicates=*/true); + recordOpenedTypes(locator, replacements, preparedOverload, + /*fixmeAllowDuplicates=*/true); if (parentTy) { const auto parentTyInContext = @@ -81,7 +86,7 @@ Type ConstraintSystem::openUnboundGenericType(GenericTypeDecl *decl, continue; addConstraint(ConstraintKind::Bind, pair.second, found->second, - locator); + locator, /*isFavored=*/false, preparedOverload); } } @@ -115,7 +120,8 @@ Type ConstraintSystem::openUnboundGenericType(GenericTypeDecl *decl, } static void checkNestedTypeConstraints(ConstraintSystem &cs, Type type, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { // If this is a type defined inside of constrained extension, let's add all // of the generic requirements to the constraint system to make sure that it's // something we can use. @@ -184,43 +190,44 @@ static void checkNestedTypeConstraints(ConstraintSystem &cs, Type type, // U is an associated type of protocol P. return type.subst(QuerySubstitutionMap{contextSubMap}, LookUpConformanceInSubstitutionMap(subMap)); - }); + }, preparedOverload); } } // And now make sure the parent is okay, for things like X.Y.Z. - checkNestedTypeConstraints(cs, parentTy, locator); + checkNestedTypeConstraints(cs, parentTy, locator, preparedOverload); } Type ConstraintSystem::replaceInferableTypesWithTypeVars( - Type type, ConstraintLocatorBuilder locator) { + Type type, ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { if (!type->hasUnboundGenericType() && !type->hasPlaceholder()) return type; + auto flags = TVO_CanBindToNoEscape | TVO_PrefersSubtypeBinding | TVO_CanBindToHole; + type = type.transformRec([&](Type type) -> std::optional { if (auto unbound = type->getAs()) { return openUnboundGenericType(unbound->getDecl(), unbound->getParent(), - locator, /*isTypeResolution=*/false); + locator, /*isTypeResolution=*/false, + preparedOverload); } else if (auto *placeholderTy = type->getAs()) { if (auto *typeRepr = placeholderTy->getOriginator().dyn_cast()) { if (isa(typeRepr)) { - return Type(createTypeVariable( - getConstraintLocator(locator, - LocatorPathElt::PlaceholderType(typeRepr)), - TVO_CanBindToNoEscape | TVO_PrefersSubtypeBinding | - TVO_CanBindToHole)); + return Type( + createTypeVariable( + getConstraintLocator(locator, LocatorPathElt::PlaceholderType(typeRepr)), + flags, preparedOverload)); } - } else if (auto *var = - placeholderTy->getOriginator().dyn_cast()) { + } else if (auto *var = placeholderTy->getOriginator().dyn_cast()) { if (var->getName().hasDollarPrefix()) { auto *repr = new (type->getASTContext()) PlaceholderTypeRepr(var->getLoc()); - return Type(createTypeVariable( - getConstraintLocator(locator, - LocatorPathElt::PlaceholderType(repr)), - TVO_CanBindToNoEscape | TVO_PrefersSubtypeBinding | - TVO_CanBindToHole)); + return Type( + createTypeVariable( + getConstraintLocator(locator, LocatorPathElt::PlaceholderType(repr)), + flags, preparedOverload)); } } } @@ -239,13 +246,16 @@ namespace { struct TypeOpener : public TypeTransform { ArrayRef replacements; ConstraintLocatorBuilder locator; + PreparedOverload *preparedOverload; ConstraintSystem &cs; TypeOpener(ArrayRef replacements, ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload, ConstraintSystem &cs) : TypeTransform(cs.getASTContext()), - replacements(replacements), locator(locator), cs(cs) {} + replacements(replacements), locator(locator), + preparedOverload(preparedOverload), cs(cs) {} std::optional transform(TypeBase *type, TypePosition pos) { if (!type->hasTypeParameter()) @@ -268,7 +278,8 @@ struct TypeOpener : public TypeTransform { Type transformPackExpansionType(PackExpansionType *expansion, TypePosition pos) { - return cs.openPackExpansionType(expansion, replacements, locator); + return cs.openPackExpansionType(expansion, replacements, locator, + preparedOverload); } bool shouldUnwrapVanishingTuples() const { @@ -283,47 +294,78 @@ struct TypeOpener : public TypeTransform { } Type ConstraintSystem::openType(Type type, ArrayRef replacements, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { assert(!type->hasUnboundGenericType()); if (!type->hasTypeParameter()) return type; - return TypeOpener(replacements, locator, *this) + return TypeOpener(replacements, locator, preparedOverload, *this) .doIt(type, TypePosition::Invariant); } Type ConstraintSystem::openPackExpansionType(PackExpansionType *expansion, ArrayRef replacements, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { auto patternType = - openType(expansion->getPatternType(), replacements, locator); - auto shapeType = openType(expansion->getCountType(), replacements, locator); + openType(expansion->getPatternType(), replacements, locator, + preparedOverload); + auto shapeType = openType(expansion->getCountType(), replacements, locator, + preparedOverload); auto openedPackExpansion = PackExpansionType::get(patternType, shapeType); - auto known = OpenedPackExpansionTypes.find(openedPackExpansion); - if (known != OpenedPackExpansionTypes.end()) - return known->second; + // FIXME: It's silly that we need to do this. The whole concept of + // "opening" pack expansions is broken. + { + if (preparedOverload) { + for (auto change : preparedOverload->Changes) { + if (change.Kind == PreparedOverload::Change::OpenedPackExpansionType && + change.PackExpansion.TheExpansion == openedPackExpansion) + return change.PackExpansion.TypeVar; + } + } + + auto known = OpenedPackExpansionTypes.find(openedPackExpansion); + if (known != OpenedPackExpansionTypes.end()) + return known->second; + } auto *expansionLoc = getConstraintLocator(locator.withPathElement( LocatorPathElt::PackExpansionType(openedPackExpansion))); - auto *expansionVar = createTypeVariable(expansionLoc, TVO_PackExpansion); + auto *expansionVar = createTypeVariable(expansionLoc, TVO_PackExpansion, + preparedOverload); // This constraint is important to make sure that pack expansion always // has a binding and connect pack expansion var to any type variables // that appear in pattern and shape types. - addUnsolvedConstraint(Constraint::create(*this, ConstraintKind::FallbackType, - expansionVar, openedPackExpansion, - expansionLoc)); - - recordOpenedPackExpansionType(openedPackExpansion, expansionVar); + auto *c = Constraint::create(*this, ConstraintKind::FallbackType, + expansionVar, openedPackExpansion, + expansionLoc); + if (preparedOverload) + preparedOverload->addedConstraint(c); + else + addUnsolvedConstraint(c); + + recordOpenedPackExpansionType(openedPackExpansion, expansionVar, + preparedOverload); return expansionVar; } void ConstraintSystem::recordOpenedPackExpansionType(PackExpansionType *expansion, - TypeVariableType *expansionVar) { + TypeVariableType *expansionVar, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + ASSERT(PreparingOverload); + preparedOverload->openedPackExpansionType(expansion, expansionVar); + return; + } + + ASSERT(!PreparingOverload); + bool inserted = OpenedPackExpansionTypes.insert({expansion, expansionVar}).second; ASSERT(inserted); @@ -359,11 +401,13 @@ Type ConstraintSystem::openOpaqueType(OpaqueTypeArchetypeType *opaque, SmallVector replacements; openGeneric(DC, opaqueDecl->getOpaqueInterfaceGenericSignature(), - opaqueLocator, replacements); + opaqueLocator, replacements, /*preparedOverload=*/nullptr); - recordOpenedTypes(opaqueLocatorKey, replacements); + recordOpenedTypes(opaqueLocatorKey, replacements, + /*preparedOverload=*/nullptr); - return openType(opaque->getInterfaceType(), replacements, locator); + return openType(opaque->getInterfaceType(), replacements, locator, + /*preparedOverload=*/nullptr); } Type ConstraintSystem::openOpaqueType(Type type, ContextualTypePurpose context, @@ -432,19 +476,24 @@ FunctionType *ConstraintSystem::openFunctionType( AnyFunctionType *funcType, ConstraintLocatorBuilder locator, SmallVectorImpl &replacements, - DeclContext *outerDC) { + DeclContext *outerDC, + PreparedOverload *preparedOverload) { if (auto *genericFn = funcType->getAs()) { auto signature = genericFn->getGenericSignature(); - openGenericParameters(outerDC, signature, replacements, locator); + openGenericParameters(outerDC, signature, replacements, locator, + preparedOverload); openGenericRequirements(outerDC, signature, /*skipProtocolSelfConstraint=*/false, locator, [&](Type type) -> Type { - return openType(type, replacements, locator); - }); + return openType(type, replacements, locator, + preparedOverload); + }, preparedOverload); funcType = substGenericArgs(genericFn, - [&](Type type) { return openType(type, replacements, locator); }); + [&](Type type) { + return openType(type, replacements, locator, preparedOverload); + }); } return funcType->castTo(); @@ -651,8 +700,17 @@ Type ConstraintSystem::getUnopenedTypeOfReference( } void ConstraintSystem::recordOpenedType( - ConstraintLocator *locator, ArrayRef openedTypes) { - bool inserted = OpenedTypes.insert({locator, openedTypes}).second; + ConstraintLocator *locator, ArrayRef replacements, + PreparedOverload *preparedOverload) { + if (preparedOverload) { + ASSERT(PreparingOverload); + preparedOverload->openedTypes(replacements); + return; + } + + ASSERT(!PreparingOverload); + + bool inserted = OpenedTypes.insert({locator, replacements}).second; ASSERT(inserted); if (solverState) @@ -661,7 +719,8 @@ void ConstraintSystem::recordOpenedType( void ConstraintSystem::recordOpenedTypes( ConstraintLocatorBuilder locator, - SmallVectorImpl &replacements, + const SmallVectorImpl &replacements, + PreparedOverload *preparedOverload, bool fixmeAllowDuplicates) { if (replacements.empty()) return; @@ -680,14 +739,15 @@ void ConstraintSystem::recordOpenedTypes( ConstraintLocator *locatorPtr = getConstraintLocator(locator); assert(locatorPtr && "No locator for opened types?"); - OpenedType* openedTypes + OpenedType *openedTypes = Allocator.Allocate(replacements.size()); std::copy(replacements.begin(), replacements.end(), openedTypes); // FIXME: Get rid of fixmeAllowDuplicates. if (!fixmeAllowDuplicates || OpenedTypes.count(locatorPtr) == 0) recordOpenedType( - locatorPtr, llvm::ArrayRef(openedTypes, replacements.size())); + locatorPtr, llvm::ArrayRef(openedTypes, replacements.size()), + preparedOverload); } /// Determine how many levels of argument labels should be removed from the @@ -745,9 +805,12 @@ unsigned constraints::getNumApplications(bool hasAppliedSelf, /// Replaces property wrapper types in the parameter list of the given function type /// with the wrapped-value or projected-value types (depending on argument label). static FunctionType * -unwrapPropertyWrapperParameterTypes(ConstraintSystem &cs, AbstractFunctionDecl *funcDecl, - FunctionRefInfo functionRefInfo, FunctionType *functionType, - ConstraintLocatorBuilder locator) { +unwrapPropertyWrapperParameterTypes(ConstraintSystem &cs, + AbstractFunctionDecl *funcDecl, + FunctionRefInfo functionRefInfo, + FunctionType *functionType, + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { // Only apply property wrappers to unapplied references to functions. if (!functionRefInfo.isUnapplied()) return functionType; @@ -783,14 +846,15 @@ unwrapPropertyWrapperParameterTypes(ConstraintSystem &cs, AbstractFunctionDecl * } auto *loc = cs.getConstraintLocator(locator); - auto *wrappedType = cs.createTypeVariable(loc, 0); + auto *wrappedType = cs.createTypeVariable(loc, 0, preparedOverload); auto paramType = paramTypes[i].getParameterType(); auto paramLabel = paramTypes[i].getLabel(); auto paramInternalLabel = paramTypes[i].getInternalLabel(); adjustedParamTypes.push_back(AnyFunctionType::Param( wrappedType, paramLabel, ParameterTypeFlags(), paramInternalLabel)); - cs.applyPropertyWrapperToParameter(paramType, wrappedType, paramDecl, argLabel, - ConstraintKind::Equal, loc, loc); + cs.applyPropertyWrapperToParameter( + paramType, wrappedType, paramDecl, argLabel, ConstraintKind::Equal, + loc, loc, preparedOverload); } return FunctionType::get(adjustedParamTypes, functionType->getResult(), @@ -806,7 +870,7 @@ static bool isRequirementOrWitness(const ConstraintLocatorBuilder &locator) { FunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency( FunctionType *fnType, Type baseType, ValueDecl *decl, DeclContext *dc, unsigned numApplies, bool isMainDispatchQueue, ArrayRef replacements, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, PreparedOverload *preparedOverload) { auto *adjustedTy = swift::adjustFunctionTypeForConcurrency( fnType, decl, dc, numApplies, isMainDispatchQueue, GetClosureType{*this}, @@ -814,7 +878,7 @@ FunctionType *ConstraintSystem::adjustFunctionTypeForConcurrency( if (replacements.empty()) return type; - return openType(type, replacements, locator); + return openType(type, replacements, locator, preparedOverload); }); // Infer @Sendable for global actor isolated function types under the @@ -932,21 +996,25 @@ DeclReferenceType ConstraintSystem::getTypeOfReference(ValueDecl *value, FunctionRefInfo functionRefInfo, ConstraintLocatorBuilder locator, - DeclContext *useDC) { + DeclContext *useDC, + PreparedOverload *preparedOverload) { + ASSERT(!!preparedOverload == PreparingOverload); + if (value->getDeclContext()->isTypeContext() && isa(value)) { // Unqualified lookup can find operator names within nominal types. auto func = cast(value); assert(func->isOperator() && "Lookup should only find operators"); - SmallVector replacements; + SmallVector replacements; AnyFunctionType *funcType = func->getInterfaceType() ->castTo(); auto openedType = openFunctionType( - funcType, locator, replacements, func->getDeclContext()); + funcType, locator, replacements, func->getDeclContext(), + preparedOverload); // If we opened up any type variables, record the replacements. - recordOpenedTypes(locator, replacements); + recordOpenedTypes(locator, replacements, preparedOverload); // If this is a method whose result type is dynamic Self, replace // DynamicSelf with the actual object type. @@ -964,7 +1032,7 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, functionRefInfo); openedType = adjustFunctionTypeForConcurrency( origOpenedType, /*baseType=*/Type(), func, useDC, numApplies, false, - replacements, locator); + replacements, locator, preparedOverload); } // The reference implicitly binds 'self'. @@ -981,11 +1049,12 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, funcDecl, /*isCurriedInstanceReference=*/false, functionRefInfo); auto openedType = openFunctionType(funcType, locator, replacements, - funcDecl->getDeclContext()) + funcDecl->getDeclContext(), + preparedOverload) ->removeArgumentLabels(numLabelsToRemove); openedType = unwrapPropertyWrapperParameterTypes( *this, funcDecl, functionRefInfo, openedType->castTo(), - locator); + locator, preparedOverload); auto origOpenedType = openedType; if (!isRequirementOrWitness(locator)) { @@ -993,7 +1062,7 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, functionRefInfo); openedType = adjustFunctionTypeForConcurrency( origOpenedType->castTo(), /*baseType=*/Type(), funcDecl, - useDC, numApplies, false, replacements, locator); + useDC, numApplies, false, replacements, locator, preparedOverload); } if (isForCodeCompletion() && openedType->hasError()) { @@ -1004,7 +1073,7 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, } // If we opened up any type variables, record the replacements. - recordOpenedTypes(locator, replacements); + recordOpenedTypes(locator, replacements, preparedOverload); return { origOpenedType, openedType, origOpenedType, openedType, Type() }; } @@ -1021,10 +1090,10 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, /*isSpecialized=*/false); type = useDC->mapTypeIntoContext(type); - checkNestedTypeConstraints(*this, type, locator); + checkNestedTypeConstraints(*this, type, locator, preparedOverload); // Convert any placeholder types and open generics. - type = replaceInferableTypesWithTypeVars(type, locator); + type = replaceInferableTypesWithTypeVars(type, locator, preparedOverload); // Module types are not wrapped in metatypes. if (type->is()) @@ -1043,10 +1112,11 @@ ConstraintSystem::getTypeOfReference(ValueDecl *value, SmallVector replacements; Type openedType = openFunctionType( macroType->castTo(), locator, replacements, - macro->getDeclContext()); + macro->getDeclContext(), + preparedOverload); // If we opened up any type variables, record the replacements. - recordOpenedTypes(locator, replacements); + recordOpenedTypes(locator, replacements, preparedOverload); // FIXME: Should we use replaceParamErrorTypeByPlaceholder() here? @@ -1110,7 +1180,8 @@ static void bindArchetypesFromContext( ConstraintSystem &cs, DeclContext *outerDC, ConstraintLocator *locatorPtr, - ArrayRef replacements) { + ArrayRef replacements, + PreparedOverload *preparedOverload) { auto bindPrimaryArchetype = [&](Type paramTy, Type contextTy) { // We might not have a type variable for this generic parameter @@ -1121,7 +1192,7 @@ static void bindArchetypesFromContext( for (auto pair : replacements) { if (pair.first->isEqual(paramTy)) { cs.addConstraint(ConstraintKind::Bind, pair.second, contextTy, - locatorPtr); + locatorPtr, /*isFavored=*/false, preparedOverload); return; } } @@ -1155,39 +1226,44 @@ void ConstraintSystem::openGeneric( DeclContext *outerDC, GenericSignature sig, ConstraintLocatorBuilder locator, - SmallVectorImpl &replacements) { + SmallVectorImpl &replacements, + PreparedOverload *preparedOverload) { if (!sig) return; - openGenericParameters(outerDC, sig, replacements, locator); + openGenericParameters(outerDC, sig, replacements, locator, preparedOverload); // Add the requirements as constraints. openGenericRequirements( outerDC, sig, /*skipProtocolSelfConstraint=*/false, locator, - [&](Type type) { return openType(type, replacements, locator); }); + [&](Type type) { + return openType(type, replacements, locator, preparedOverload); + }, preparedOverload); } void ConstraintSystem::openGenericParameters(DeclContext *outerDC, GenericSignature sig, SmallVectorImpl &replacements, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { ASSERT(sig); ASSERT(replacements.empty()); // Create the type variables for the generic parameters. for (auto gp : sig.getGenericParams()) { - auto *typeVar = openGenericParameter(gp, locator); + auto *typeVar = openGenericParameter(gp, locator, preparedOverload); replacements.emplace_back(gp, typeVar); } auto *baseLocator = getConstraintLocator( locator.withPathElement(LocatorPathElt::OpenedGeneric(sig))); - bindArchetypesFromContext(*this, outerDC, baseLocator, replacements); + bindArchetypesFromContext(*this, outerDC, baseLocator, replacements, preparedOverload); } TypeVariableType *ConstraintSystem::openGenericParameter(GenericTypeParamType *parameter, - ConstraintLocatorBuilder locator) { + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { auto *paramLocator = getConstraintLocator( locator.withPathElement(LocatorPathElt::GenericParameter(parameter))); @@ -1199,20 +1275,21 @@ TypeVariableType *ConstraintSystem::openGenericParameter(GenericTypeParamType *p if (shouldAttemptFixes()) options |= TVO_CanBindToHole; - return createTypeVariable(paramLocator, options); + return createTypeVariable(paramLocator, options, preparedOverload); } void ConstraintSystem::openGenericRequirements( DeclContext *outerDC, GenericSignature signature, bool skipProtocolSelfConstraint, ConstraintLocatorBuilder locator, - llvm::function_ref substFn) { + llvm::function_ref substFn, + PreparedOverload *preparedOverload) { auto requirements = signature.getRequirements(); for (unsigned pos = 0, n = requirements.size(); pos != n; ++pos) { auto openedGenericLoc = locator.withPathElement(LocatorPathElt::OpenedGeneric(signature)); openGenericRequirement(outerDC, signature, pos, requirements[pos], skipProtocolSelfConstraint, openedGenericLoc, - substFn); + substFn, preparedOverload); } } @@ -1220,7 +1297,8 @@ void ConstraintSystem::openGenericRequirement( DeclContext *outerDC, GenericSignature signature, unsigned index, const Requirement &req, bool skipProtocolSelfConstraint, ConstraintLocatorBuilder locator, - llvm::function_ref substFn) { + llvm::function_ref substFn, + PreparedOverload *preparedOverload) { std::optional openedReq; auto openedFirst = substFn(req.getFirstType()); @@ -1257,7 +1335,8 @@ void ConstraintSystem::openGenericRequirement( addConstraint(*openedReq, locator.withPathElement( LocatorPathElt::TypeParameterRequirement(index, kind)), - /*isFavored=*/false, prohibitIsolatedConformance); + /*isFavored=*/false, prohibitIsolatedConformance, + preparedOverload); } /// Add the constraint on the type used for the 'Self' type for a member @@ -1271,19 +1350,22 @@ void ConstraintSystem::openGenericRequirement( /// \param selfTy The instance type of the context in which the member is /// declared. static void addSelfConstraint(ConstraintSystem &cs, Type objectTy, Type selfTy, - ConstraintLocatorBuilder locator){ + ConstraintLocatorBuilder locator, + PreparedOverload *preparedOverload) { assert(!selfTy->is()); // Otherwise, use a subtype constraint for classes to cope with inheritance. if (selfTy->getClassOrBoundGenericClass()) { cs.addConstraint(ConstraintKind::Subtype, objectTy, selfTy, - cs.getConstraintLocator(locator)); + cs.getConstraintLocator(locator), /*isFavored=*/false, + preparedOverload); return; } // Otherwise, the types must be equivalent. cs.addConstraint(ConstraintKind::Bind, objectTy, selfTy, - cs.getConstraintLocator(locator)); + cs.getConstraintLocator(locator), /*isFavored=*/false, + preparedOverload); } Type constraints::getDynamicSelfReplacementType( @@ -1507,13 +1589,17 @@ Type ConstraintSystem::getMemberReferenceTypeFromOpenedType( DeclReferenceType ConstraintSystem::getTypeOfMemberReference( Type baseTy, ValueDecl *value, DeclContext *useDC, bool isDynamicLookup, FunctionRefInfo functionRefInfo, ConstraintLocator *locator, - SmallVectorImpl *replacementsPtr) { + SmallVectorImpl *replacementsPtr, + PreparedOverload *preparedOverload) { + ASSERT(!!preparedOverload == PreparingOverload); + // Figure out the instance type used for the base. Type resolvedBaseTy = getFixedTypeRecursive(baseTy, /*wantRValue=*/true); // If the base is a module type, just use the type of the decl. if (resolvedBaseTy->is()) { - return getTypeOfReference(value, functionRefInfo, locator, useDC); + return getTypeOfReference(value, functionRefInfo, locator, useDC, + preparedOverload); } // Check to see if the self parameter is applied, in which case we'll want to @@ -1548,10 +1634,11 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( if (memberTy->isConstraintType()) memberTy = ExistentialType::get(memberTy); - checkNestedTypeConstraints(*this, memberTy, locator); + checkNestedTypeConstraints(*this, memberTy, locator, preparedOverload); // Convert any placeholders and open any generics. - memberTy = replaceInferableTypesWithTypeVars(memberTy, locator); + memberTy = replaceInferableTypesWithTypeVars( + memberTy, locator, preparedOverload); // Wrap it in a metatype. memberTy = MetatypeType::get(memberTy); @@ -1586,8 +1673,10 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( // If we have a generic signature, open the parameters. We delay opening // requirements to allow contextual types to affect the situation. auto genericSig = innerDC->getGenericSignatureOfContext(); - if (genericSig) - openGenericParameters(outerDC, genericSig, replacements, locator); + if (genericSig) { + openGenericParameters(outerDC, genericSig, replacements, locator, + preparedOverload); + } Type thrownErrorType; if (isa(value) || isa(value)) { @@ -1596,7 +1685,9 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( if (auto *genericFn = openedType->getAs()) { openedType = substGenericArgs(genericFn, - [&](Type type) { return openType(type, replacements, locator); }); + [&](Type type) { + return openType(type, replacements, locator, preparedOverload); + }); } } else { // If the storage has a throwing getter, save the thrown error type.. @@ -1653,11 +1744,13 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( // If the storage is generic, open the self and ref types. if (genericSig) { - selfTy = openType(selfTy, replacements, locator); - refType = openType(refType, replacements, locator); + selfTy = openType(selfTy, replacements, locator, preparedOverload); + refType = openType(refType, replacements, locator, preparedOverload); - if (thrownErrorType) - thrownErrorType = openType(thrownErrorType, replacements, locator); + if (thrownErrorType) { + thrownErrorType = openType(thrownErrorType, replacements, locator, + preparedOverload); + } } FunctionType::Param selfParam(selfTy, Identifier(), selfFlags); @@ -1684,14 +1777,16 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( getConcreteReplacementForProtocolSelfType(value)) { // Concrete type replacing `Self` could be generic, so we need // to make sure that it's opened before use. - baseOpenedTy = openType(concreteSelf, replacements, locator); + baseOpenedTy = openType(concreteSelf, replacements, locator, + preparedOverload); baseObjTy = baseOpenedTy; } } } else if (baseObjTy->isExistentialType()) { auto openedArchetype = ExistentialArchetypeType::get(baseObjTy->getCanonicalType()); - recordOpenedExistentialType(getConstraintLocator(locator), openedArchetype); + recordOpenedExistentialType(getConstraintLocator(locator), openedArchetype, + preparedOverload); baseOpenedTy = openedArchetype; } @@ -1705,9 +1800,10 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( // conformance constraint because we wouldn't have found the declaration // if it didn't conform. addConstraint(ConstraintKind::Bind, baseOpenedTy, selfObjTy, - getConstraintLocator(locator)); + getConstraintLocator(locator), /*isFavored=*/false, + preparedOverload); } else if (!isDynamicLookup) { - addSelfConstraint(*this, baseOpenedTy, selfObjTy, locator); + addSelfConstraint(*this, baseOpenedTy, selfObjTy, locator, preparedOverload); } // Open generic requirements after self constraint has been @@ -1720,7 +1816,9 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( openGenericRequirements( outerDC, genericSig, /*skipProtocolSelfConstraint=*/true, locator, - [&](Type type) { return openType(type, replacements, locator); }); + [&](Type type) { + return openType(type, replacements, locator, preparedOverload); + }, preparedOverload); } if (auto *funcDecl = dyn_cast(value)) { @@ -1728,8 +1826,9 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( // Strip off the 'self' parameter auto *functionType = fullFunctionType->getResult()->getAs(); - functionType = unwrapPropertyWrapperParameterTypes(*this, funcDecl, functionRefInfo, - functionType, locator); + functionType = unwrapPropertyWrapperParameterTypes( + *this, funcDecl, functionRefInfo, functionType, + locator, preparedOverload); // FIXME: Verify ExtInfo state is correct, not working by accident. FunctionType::ExtInfo info; @@ -1750,12 +1849,13 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( unsigned numApplies = getNumApplications(hasAppliedSelf, functionRefInfo); openedType = adjustFunctionTypeForConcurrency( origOpenedType->castTo(), resolvedBaseTy, value, useDC, - numApplies, isMainDispatchQueueMember(locator), replacements, locator); + numApplies, isMainDispatchQueueMember(locator), replacements, locator, + preparedOverload); } else if (auto subscript = dyn_cast(value)) { openedType = adjustFunctionTypeForConcurrency( - origOpenedType->castTo(), resolvedBaseTy, subscript, - useDC, - /*numApplies=*/2, /*isMainDispatchQueue=*/false, replacements, locator); + origOpenedType->castTo(), resolvedBaseTy, subscript, useDC, + /*numApplies=*/2, /*isMainDispatchQueue=*/false, replacements, locator, + preparedOverload); } else if (auto var = dyn_cast(value)) { // Adjust the function's result type, since that's the Var's actual type. auto origFnType = origOpenedType->castTo(); @@ -1782,7 +1882,7 @@ DeclReferenceType ConstraintSystem::getTypeOfMemberReference( } // If we opened up any type variables, record the replacements. - recordOpenedTypes(locator, replacements); + recordOpenedTypes(locator, replacements, preparedOverload); return { origOpenedType, openedType, origType, type, thrownErrorType }; } @@ -1887,7 +1987,7 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator, FunctionType::get(indices, elementTy, info), overload.getBaseType(), subscript, useDC, /*numApplies=*/1, /*isMainDispatchQueue=*/false, emptyReplacements, - locator); + locator, /*preparedOverload=*/nullptr); } else if (auto var = dyn_cast(decl)) { type = var->getValueInterfaceType(); if (doesStorageProduceLValue( @@ -1940,8 +2040,8 @@ Type ConstraintSystem::getEffectiveOverloadType(ConstraintLocator *locator, type = adjustFunctionTypeForConcurrency( type->castTo(), overload.getBaseType(), decl, - useDC, numApplies, - /*isMainDispatchQueue=*/false, emptyReplacements, locator) + useDC, numApplies, /*isMainDispatchQueue=*/false, + emptyReplacements, locator, /*preparedOverload=*/nullptr) ->getResult(); } } @@ -2395,6 +2495,60 @@ void ConstraintSystem::recordResolvedOverload(ConstraintLocator *locator, recordChange(SolverTrail::Change::ResolvedOverload(locator)); } +void PreparedOverload::discharge(ConstraintSystem &cs, + ConstraintLocatorBuilder locator) const { + for (auto change : Changes) { + switch (change.Kind) { + case PreparedOverload::Change::AddedTypeVariable: + cs.addTypeVariable(change.TypeVar); + break; + + case PreparedOverload::Change::AddedConstraint: + cs.addUnsolvedConstraint(change.TheConstraint); + cs.activateConstraint(change.TheConstraint); + break; + + case PreparedOverload::Change::OpenedTypes: { + auto *locatorPtr = cs.getConstraintLocator(locator); + ArrayRef replacements( + change.Replacements.Data, + change.Replacements.Count); + + // FIXME: Get rid of this conditional. + if (cs.getOpenedTypes(locatorPtr).empty()) + cs.recordOpenedType(locatorPtr, replacements); + break; + } + + case PreparedOverload::Change::OpenedExistentialType: { + auto *locatorPtr = cs.getConstraintLocator(locator); + cs.recordOpenedExistentialType(locatorPtr, + change.TheExistential); + break; + } + + case PreparedOverload::Change::OpenedPackExpansionType: + cs.recordOpenedPackExpansionType( + change.PackExpansion.TheExpansion, + change.PackExpansion.TypeVar); + break; + + case PreparedOverload::Change::AppliedPropertyWrapper: { + auto *locatorPtr = cs.getConstraintLocator(locator); + Expr *anchor = getAsExpr(locatorPtr->getAnchor()); + cs.applyPropertyWrapper(anchor, + { Type(change.PropertyWrapper.WrapperType), + change.PropertyWrapper.InitKind }); + break; + } + + case PreparedOverload::Change::AddedFix: + cs.recordFix(change.Fix.TheFix, change.Fix.Impact); + break; + } + } +} + void ConstraintSystem::resolveOverload(ConstraintLocator *locator, Type boundType, OverloadChoice choice, DeclContext *useDC) { @@ -2427,10 +2581,12 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, declRefType = getTypeOfMemberReference( baseTy, choice.getDecl(), useDC, (kind == OverloadChoiceKind::DeclViaDynamic), - choice.getFunctionRefInfo(), locator, nullptr); + choice.getFunctionRefInfo(), locator, nullptr, + /*preparedOverload=*/nullptr); } else { declRefType = getTypeOfReference( - choice.getDecl(), choice.getFunctionRefInfo(), locator, useDC); + choice.getDecl(), choice.getFunctionRefInfo(), locator, useDC, + /*preparedOverload=*/nullptr); } openedType = declRefType.openedType;