From e072ad2b65ca7a52d31d9e2b8f2634f68abf8559 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 01:42:58 -0400 Subject: [PATCH 01/30] [NFC] Change the printing of AbstractionPattern to include the sub map --- lib/SIL/IR/AbstractionPattern.cpp | 32 ++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 5ce151a7cf576..92bff9908b4e2 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -1376,6 +1376,26 @@ void AbstractionPattern::dump() const { llvm::errs() << "\n"; } +static void printGenerics(raw_ostream &out, const AbstractionPattern &pattern) { + if (auto sig = pattern.getGenericSignature()) { + sig->print(out); + } + // It'd be really nice if we could get these interleaved with the types. + if (auto subs = pattern.getGenericSubstitutions()) { + out << "@<"; + bool first = false; + for (auto sub : subs.getReplacementTypes()) { + if (!first) { + out << ","; + } else { + first = true; + } + out << sub; + } + out << ">"; + } +} + void AbstractionPattern::print(raw_ostream &out) const { switch (getKind()) { case Kind::Invalid: @@ -1396,9 +1416,7 @@ void AbstractionPattern::print(raw_ostream &out) const { ? "AP::Type" : getKind() == Kind::Discard ? "AP::Discard" : "<>"); - if (auto sig = getGenericSignature()) { - sig->print(out); - } + printGenerics(out, *this); out << '('; getType().dump(out); out << ')'; @@ -1425,9 +1443,7 @@ void AbstractionPattern::print(raw_ostream &out) const { getKind() == Kind::ObjCCompletionHandlerArgumentsType ? "AP::ObjCCompletionHandlerArgumentsType(" : "AP::CFunctionAsMethodType("); - if (auto sig = getGenericSignature()) { - sig->print(out); - } + printGenerics(out, *this); getType().dump(out); out << ", "; // [TODO: Improve-Clang-type-printing] @@ -1459,9 +1475,7 @@ void AbstractionPattern::print(raw_ostream &out) const { getKind() == Kind::CurriedCXXMethodType ? "AP::CurriedCXXMethodType(" : "AP::PartialCurriedCXXMethodType"); - if (auto sig = getGenericSignature()) { - sig->print(out); - } + printGenerics(out, *this); getType().dump(out); out << ", "; getCXXMethod()->dump(); From 3b10e59cd07033596230591763ccf9615d1ec820 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 01:45:27 -0400 Subject: [PATCH 02/30] Fix unsafeGetSubstFieldType to propagate a substitution map in whichever case it happens to be in. This is a basic fix so that parallel walks on tuples and function types in the substituted type will work . Separately, though, I do not think the places that use this really need to be passed an orig type; this is used for computing type properties, and I am not aware of any reason we should need an orig type to compute type properties. Additionally, the orig types computed by this function are not really correct because of the substitution being done in some cases, so it'd be very nice to rip this all out. I'm not good to look into that right now, though. --- include/swift/SIL/AbstractionPattern.h | 3 ++- lib/SIL/IR/AbstractionPattern.cpp | 9 ++++++--- lib/SIL/IR/TypeLowering.cpp | 11 +++++++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/include/swift/SIL/AbstractionPattern.h b/include/swift/SIL/AbstractionPattern.h index e9771df716956..6f31004f6233e 100644 --- a/include/swift/SIL/AbstractionPattern.h +++ b/include/swift/SIL/AbstractionPattern.h @@ -783,7 +783,8 @@ class AbstractionPattern { /// Note that, for most purposes, you should lower a field's type against its /// *unsubstituted* interface type. AbstractionPattern - unsafeGetSubstFieldType(ValueDecl *member, CanType origMemberType) const; + unsafeGetSubstFieldType(ValueDecl *member, CanType origMemberType, + SubstitutionMap subMap) const; private: /// Return an abstraction pattern for the curried type of an diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 92bff9908b4e2..1896258930562 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -1583,13 +1583,14 @@ bool AbstractionPattern::hasSameBasicTypeStructure(CanType l, CanType r) { AbstractionPattern AbstractionPattern::unsafeGetSubstFieldType(ValueDecl *member, - CanType origMemberInterfaceType) + CanType origMemberInterfaceType, + SubstitutionMap subMap) const { assert(origMemberInterfaceType); if (isTypeParameterOrOpaqueArchetype()) { // Fall back to the generic abstraction pattern for the member. auto sig = member->getDeclContext()->getGenericSignatureOfContext(); - return AbstractionPattern(sig.getCanonicalSignature(), + return AbstractionPattern(subMap, sig.getCanonicalSignature(), origMemberInterfaceType); } @@ -1626,7 +1627,9 @@ const { member, origMemberInterfaceType) ->getReducedType(getGenericSignature()); - return AbstractionPattern(getGenericSignature(), memberTy); + return AbstractionPattern(getGenericSubstitutions(), + getGenericSignature(), + memberTy); } llvm_unreachable("invalid abstraction pattern kind"); } diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index fe71adc24b2d0..dcb0242bc741b 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -2343,7 +2343,8 @@ namespace { auto sig = field->getDeclContext()->getGenericSignatureOfContext(); auto interfaceTy = field->getInterfaceType()->getReducedType(sig); auto origFieldType = origType.unsafeGetSubstFieldType(field, - interfaceTy); + interfaceTy, + subMap); properties.addSubobject(classifyType(origFieldType, substFieldType, TC, Expansion)); @@ -2422,7 +2423,8 @@ namespace { auto origEltType = origType.unsafeGetSubstFieldType(elt, elt->getArgumentInterfaceType() - ->getReducedType(D->getGenericSignature())); + ->getReducedType(D->getGenericSignature()), + subMap); properties.addSubobject(classifyType(origEltType, substEltType, TC, Expansion)); properties = @@ -2765,7 +2767,8 @@ bool TypeConverter::visitAggregateLeaves( auto interfaceTy = structField->getInterfaceType()->getReducedType(sig); auto origFieldType = - origTy.unsafeGetSubstFieldType(structField, interfaceTy); + origTy.unsafeGetSubstFieldType(structField, interfaceTy, + subMap); insertIntoWorklist(substFieldTy, origFieldType, structField, llvm::None); } @@ -2782,7 +2785,7 @@ bool TypeConverter::visitAggregateLeaves( ->getCanonicalType(); auto origElementTy = origTy.unsafeGetSubstFieldType( element, element->getArgumentInterfaceType()->getReducedType( - decl->getGenericSignature())); + decl->getGenericSignature()), subMap); insertIntoWorklist(substElementType, origElementTy, element, llvm::None); From 75782df77b63c9bb6074f2a13a892148c6b524b6 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 01:53:35 -0400 Subject: [PATCH 03/30] When we are performing SIL substitution, and we reach a type that needs to be lowered, use an opaque abstraction pattern. As I argue in the comment, we know that the orig type is now either an opaque type or a type with high-level structure that is invariant to lowering. Substitution will not change the latter property, and an opaque abstraction pattern is correct for the former. Attempting to create a "truer" abstraction pattern that preserves more structure from the orig type is both pointless and problematic. The substitutions we just did may have replaced pack references with non-pack types if there are active expansions in progress; this cannot be easily explained in terms of substitutions. (In theory, we could pass a more opaque concept of substitutions through AbstractionPattern, which might help with this. That would also make it harder to catch bugs with signature mismatches, though.) --- lib/SIL/IR/SILFunctionType.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 4ed98fb82dcf6..28f21ac4f49da 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -4748,7 +4748,17 @@ class SILTypeSubstituter : return origType; } - AbstractionPattern abstraction(Sig, origType); + // We've looked through all the top-level structure in the orig + // type that's affected by type lowering. If substitution has + // given us a type with top-level structure that's affected by + // type lowering, it must be because the orig type was a type + // variable of some sort, and we should lower using an opaque + // abstraction pattern. If substitution hasn't given us such a + // type, it doesn't matter what abstraction pattern we use, + // lowering will just come back with substType. So we can just + // use an opaque abstraction pattern here and not put any effort + // into computing a more "honest" abstraction pattern. + AbstractionPattern abstraction = AbstractionPattern::getOpaque(); return TC.getLoweredRValueType(typeExpansionContext, abstraction, substType); } From 9fd961ae768949e19ae005c8f6023aa7efc06c28 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 02:04:05 -0400 Subject: [PATCH 04/30] When computing the field type of a SILType, add substitutions to the orig type. --- lib/SIL/IR/SILType.cpp | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/lib/SIL/IR/SILType.cpp b/lib/SIL/IR/SILType.cpp index e4ee158320678..a0cbcdb204453 100644 --- a/lib/SIL/IR/SILType.cpp +++ b/lib/SIL/IR/SILType.cpp @@ -301,9 +301,30 @@ bool SILType::canRefCast(SILType operTy, SILType resultTy, SILModule &M) { && toTy.isHeapObjectReferenceType(); } +static bool needsFieldSubstitutions(const AbstractionPattern &origType) { + if (origType.isTypeParameter()) return false; + auto type = origType.getType(); + if (!type->hasTypeParameter()) return false; + return type.findIf([](CanType type) { + return isa(type); + }); +} + +static void addFieldSubstitutionsIfNeeded(TypeConverter &TC, SILType ty, + ValueDecl *field, + AbstractionPattern &origType) { + if (needsFieldSubstitutions(origType)) { + auto subMap = ty.getASTType()->getContextSubstitutionMap( + &TC.M, field->getDeclContext()); + origType = origType.withSubstitutions(subMap); + } +} + SILType SILType::getFieldType(VarDecl *field, TypeConverter &TC, TypeExpansionContext context) const { AbstractionPattern origFieldTy = TC.getAbstractionPattern(field); + addFieldSubstitutionsIfNeeded(TC, *this, field, origFieldTy); + CanType substFieldTy; if (field->hasClangNode()) { substFieldTy = origFieldTy.getType(); @@ -372,6 +393,9 @@ SILType SILType::getEnumElementType(EnumElementDecl *elt, TypeConverter &TC, getCategory()); } + auto origEltType = TC.getAbstractionPattern(elt); + addFieldSubstitutionsIfNeeded(TC, *this, elt, origEltType); + auto substEltTy = getASTType()->getTypeOfMember( &TC.M, elt, elt->getArgumentInterfaceType()); auto loweredTy = TC.getLoweredRValueType( From a1b49bd6b538026c0491c3eeef3c021f1d80d1dd Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 02:05:21 -0400 Subject: [PATCH 05/30] Do proper parallel walks of orig+subst types when computing type properties for struct and enum types. --- lib/SIL/IR/TypeLowering.cpp | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/lib/SIL/IR/TypeLowering.cpp b/lib/SIL/IR/TypeLowering.cpp index dcb0242bc741b..26bc61f04e11b 100644 --- a/lib/SIL/IR/TypeLowering.cpp +++ b/lib/SIL/IR/TypeLowering.cpp @@ -749,11 +749,12 @@ namespace { RetTy visitTupleType(CanTupleType type, AbstractionPattern origType, IsTypeExpansionSensitive_t isSensitive) { RecursiveProperties props; - for (unsigned i = 0, e = type->getNumElements(); i < e; ++i) { - props.addSubobject(classifyType(origType.getTupleElementType(i), - type.getElementType(i), - TC, Expansion)); - } + origType.forEachExpandedTupleElement(type, + [&](AbstractionPattern origEltType, CanType substEltType, + const TupleTypeElt &elt) { + props.addSubobject( + classifyType(origEltType, substEltType, TC, Expansion)); + }); props = mergeIsTypeExpansionSensitive(isSensitive, props); return asImpl().handleAggregateByProperties(type, props); } @@ -2250,12 +2251,12 @@ namespace { AbstractionPattern origType, IsTypeExpansionSensitive_t isSensitive) { RecursiveProperties properties; - for (unsigned i = 0, e = tupleType->getNumElements(); i < e; ++i) { - auto eltType = tupleType.getElementType(i); - auto origEltType = origType.getTupleElementType(i); - auto &lowering = TC.getTypeLowering(origEltType, eltType, Expansion); - properties.addSubobject(lowering.getRecursiveProperties()); - } + origType.forEachExpandedTupleElement(tupleType, + [&](AbstractionPattern origEltType, CanType substEltType, + const TupleTypeElt &elt) { + properties.addSubobject( + classifyType(origEltType, substEltType, TC, Expansion)); + }); properties = mergeIsTypeExpansionSensitive(isSensitive, properties); return handleAggregateByProperties(tupleType, From d648cf166622fc428551dfe298054953d6ceec9b Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 02:06:15 -0400 Subject: [PATCH 06/30] Do a proper orig+subst walk of tuple expression elements in call argument emission. --- lib/SILGen/SILGenApply.cpp | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 78eda39d4ce3d..5fc3d2caa5ed5 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -3389,10 +3389,23 @@ class ArgEmitter { // If the source expression is a tuple literal, we can break it // up directly. if (auto tuple = dyn_cast(e)) { - for (auto i : indices(tuple->getElements())) { - emit(tuple->getElement(i), - origParamType.getTupleElementType(i)); - } + auto substTupleType = + cast(e->getType()->getCanonicalType()); + origParamType.forEachTupleElement(substTupleType, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origEltType, CanType substEltType) { + emit(tuple->getElement(substEltIndex), origEltType); + }, + [&](unsigned origEltIndex, unsigned substEltIndex, + AbstractionPattern origExpansionType, + CanTupleEltTypeArrayRef substEltTypes) { + SmallVector eltArgs; + eltArgs.reserve(substEltTypes.size()); + for (auto i : range(substEltIndex, substEltTypes.size())) { + eltArgs.emplace_back(tuple->getElement(i)); + } + emitPackArg(eltArgs, origExpansionType); + }); return; } From 1f43cf2213c20d37cc0a17322a07ca2a1e03ffe9 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 15:38:31 -0400 Subject: [PATCH 07/30] [NFC] Rename isPackExpansion -> isOrigPackExpansion for clarity --- include/swift/SIL/AbstractionPatternGenerators.h | 4 ++-- lib/SIL/IR/AbstractionPattern.cpp | 2 +- lib/SIL/IR/SILFunctionType.cpp | 2 +- lib/SILGen/FunctionInputGenerator.h | 4 ++-- lib/SILGen/SILGenPoly.cpp | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/swift/SIL/AbstractionPatternGenerators.h b/include/swift/SIL/AbstractionPatternGenerators.h index b0cbd0624f432..99ebf5ad98774 100644 --- a/include/swift/SIL/AbstractionPatternGenerators.h +++ b/include/swift/SIL/AbstractionPatternGenerators.h @@ -51,7 +51,7 @@ class FunctionParamGenerator { unsigned substParamIndex = 0; /// The number of subst parameters corresponding to the current - /// subst parameter. + /// orig parameter. unsigned numSubstParamsForOrigParam; /// Whether the orig function type is opaque, i.e. does not permit us to @@ -125,7 +125,7 @@ class FunctionParamGenerator { } /// Return whether the current orig parameter type is a pack expansion. - bool isPackExpansion() const { + bool isOrigPackExpansion() const { assert(!isFinished()); return origParamIsExpansion; } diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 1896258930562..6fc6d2d34d9a1 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -2236,7 +2236,7 @@ class SubstFunctionTypePatternVisitor pattern.forEachFunctionParam(func.getParams(), /*ignore self*/ false, [&](FunctionParamGenerator ¶m) { - if (!param.isPackExpansion()) { + if (!param.isOrigPackExpansion()) { auto newParamTy = visit(param.getSubstParams()[0].getParameterType(), param.getOrigType()); addParam(param.getOrigFlags(), newParamTy); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 28f21ac4f49da..bea349cc85f6f 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -1576,7 +1576,7 @@ class DestructureInputs { // If the parameter is not a pack expansion, just pull off the // next parameter and destructure it in parallel with the abstraction // pattern for the type. - if (!param.isPackExpansion()) { + if (!param.isOrigPackExpansion()) { visit(param.getOrigType(), param.getSubstParams()[0], /*forSelf*/false); return; diff --git a/lib/SILGen/FunctionInputGenerator.h b/lib/SILGen/FunctionInputGenerator.h index 4131748443427..1af627be8d21c 100644 --- a/lib/SILGen/FunctionInputGenerator.h +++ b/lib/SILGen/FunctionInputGenerator.h @@ -89,7 +89,7 @@ class FunctionInputGenerator { /// Ready the current orig parameter. void readyOrigParameter() { substParamIndex = 0; - if (origParam.isPackExpansion()) { + if (origParam.isOrigPackExpansion()) { // The pack value exists in the lowered parameters and must be // claimed whether it contains formal parameters or not. packValue = inputs.claimNext(); @@ -125,7 +125,7 @@ class FunctionInputGenerator { } bool isOrigPackExpansion() const { - return origParam.isPackExpansion(); + return origParam.isOrigPackExpansion(); } AbstractionPattern getOrigType() const { diff --git a/lib/SILGen/SILGenPoly.cpp b/lib/SILGen/SILGenPoly.cpp index 84dbfab354ce5..f7f64070c152d 100644 --- a/lib/SILGen/SILGenPoly.cpp +++ b/lib/SILGen/SILGenPoly.cpp @@ -1094,7 +1094,7 @@ class TranslateArguments { // output type, it corresponds to N formal parameters in the // substituted output type. translateToPackParam will pull off // N substituted formal parameters from the input type. - if (outputParams.isPackExpansion()) { + if (outputParams.isOrigPackExpansion()) { auto outputPackParam = claimNextOutputType(); auto output = translateToPackParam(inputParams, From 98b60e36fbc3cee1d403dc1bea1cd2e39b26b775 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 02:08:03 -0400 Subject: [PATCH 08/30] Add a test case for the work in this PR. My original test case here used a memberwise initializer, but those use their own logic for binding and forward parameters which will need to be updated separately. --- test/SILGen/variadic-generic-tuples.swift | 49 +++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/test/SILGen/variadic-generic-tuples.swift b/test/SILGen/variadic-generic-tuples.swift index a0fe0fad7bb03..fbcfbe6ec51e7 100644 --- a/test/SILGen/variadic-generic-tuples.swift +++ b/test/SILGen/variadic-generic-tuples.swift @@ -204,3 +204,52 @@ func projectTupleElements(_ value: repeat Wrapper) { let tuple = (repeat (each value).value) } + +func takesVariadicTuple(tuple: (repeat each T)) {} + +// CHECK-LABEL: sil{{.*}} @$s4main28testConcreteVariadicTupleArg1i1sySi_SStF : +// CHECK: [[PACK:%.*]] = alloc_pack $Pack{Int, String} +// CHECK-NEXT: [[I_COPY:%.*]] = alloc_stack $Int +// CHECK-NEXT: store %0 to [trivial] [[I_COPY]] : $*Int +// CHECK-NEXT: [[I_INDEX:%.*]] = scalar_pack_index 0 of $Pack{Int, String} +// CHECK-NEXT: pack_element_set [[I_COPY]] : $*Int into [[I_INDEX]] of [[PACK]] : +// CHECK-NEXT: [[S_COPY:%.*]] = alloc_stack $String +// CHECK-NEXT: [[T0:%.*]] = copy_value %1 : $String +// CHECK-NEXT: store [[T0]] to [init] [[S_COPY]] : $*String +// CHECK-NEXT: [[S_INDEX:%.*]] = scalar_pack_index 1 of $Pack{Int, String} +// CHECK-NEXT: pack_element_set [[S_COPY]] : $*String into [[S_INDEX]] of [[PACK]] : +// CHECK-NEXT: // function_ref +// CHECK-NEXT: [[FN:%.*]] = function_ref @$s4main18takesVariadicTuple5tupleyxxQp_t_tRvzlF : $@convention(thin) (@pack_guaranteed Pack{repeat each τ_0_0}) -> () +// CHECK-NEXT: apply [[FN]]([[PACK]]) +// CHECK-NEXT: destroy_addr [[S_COPY]] : +// CHECK-NEXT: dealloc_stack [[S_COPY]] : +// CHECK-NEXT: dealloc_stack [[I_COPY]] : +// CHECK-NEXT: dealloc_pack [[PACK]] : +func testConcreteVariadicTupleArg(i: Int, s: String) { + takesVariadicTuple(tuple: (i, s)) +} + +struct TupleHolder { + var content: (repeat each T) + + // Suppress the memberwise initializer + init(values: repeat each T) { + content = (repeat each values) + } +} + +// CHECK-LABEL: sil{{.*}} @$s4main31takesConcreteTupleHolderFactory7factoryyAA0dE0VySi_SSQPGyXE_tF : +// CHECK-SAME: $@convention(thin) (@guaranteed @noescape @callee_guaranteed () -> @owned TupleHolder) -> () +// CHECK: [[T0:%.*]] = copy_value %0 : +// CHECK: [[T1:%.*]] = begin_borrow [[T0]] +// CHECK: [[RESULT:%.*]] = apply [[T1]]() : +// CHECK: destroy_value [[RESULT]] +func takesConcreteTupleHolderFactory(factory: () -> TupleHolder) { + let holder = factory() +} + +/* We still crash with memberwise initializers +func generateConcreteMemberTuple() -> TupleHolder { + return HasMemberTuple(content: (0, "hello")) +} + */ From b82c81cb1f5e9512465b0ef70bcf4d7b8139240f Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 15:39:18 -0400 Subject: [PATCH 09/30] [NFC] Move forEachTupleElement to use a generator --- include/swift/SIL/AbstractionPattern.h | 19 +-- .../swift/SIL/AbstractionPatternGenerators.h | 123 ++++++++++++++++++ lib/SIL/IR/AbstractionPattern.cpp | 79 +++++------ lib/SIL/IR/SILFunctionType.cpp | 33 ++--- lib/SILGen/ResultPlan.cpp | 28 ++-- lib/SILGen/SILGenApply.cpp | 40 +++--- lib/SILGen/SILGenProlog.cpp | 36 ++--- lib/SILGen/SILGenStmt.cpp | 48 ++++--- 8 files changed, 251 insertions(+), 155 deletions(-) diff --git a/include/swift/SIL/AbstractionPattern.h b/include/swift/SIL/AbstractionPattern.h index 6f31004f6233e..3ece2ef05de9c 100644 --- a/include/swift/SIL/AbstractionPattern.h +++ b/include/swift/SIL/AbstractionPattern.h @@ -36,6 +36,7 @@ namespace clang { namespace swift { namespace Lowering { class FunctionParamGenerator; +class TupleElementGenerator; /// A pattern for the abstraction of a value. /// @@ -1174,6 +1175,10 @@ class AbstractionPattern { return CXXMethod; } + bool isOpaqueTuple() const { + return getKind() == Kind::Tuple; + } + bool isOpaqueFunctionOrOpaqueDerivativeFunction() const { return (getKind() == Kind::OpaqueFunction || getKind() == Kind::OpaqueDerivativeFunction); @@ -1356,20 +1361,8 @@ class AbstractionPattern { /// expand to. /// /// This pattern must be a tuple pattern. - /// - /// Calls handleScalar or handleExpansion as appropriate for each - /// element of the original tuple, in order. void forEachTupleElement(CanTupleType substType, - llvm::function_ref - handleScalar, - llvm::function_ref - handleExpansion) const; + llvm::function_ref fn) const; /// Perform a parallel visitation of the elements of a tuple type, /// expanding the elements of the type. This preserves the structure diff --git a/include/swift/SIL/AbstractionPatternGenerators.h b/include/swift/SIL/AbstractionPatternGenerators.h index 99ebf5ad98774..e25636e096d05 100644 --- a/include/swift/SIL/AbstractionPatternGenerators.h +++ b/include/swift/SIL/AbstractionPatternGenerators.h @@ -148,6 +148,129 @@ class FunctionParamGenerator { } }; +/// A generator for traversing the formal elements of a tuple type +/// while properly respecting variadic generics. +class TupleElementGenerator { + // The steady state of the generator. + + /// The abstraction pattern of the entire tuple type. Set once + /// during construction. + AbstractionPattern origTupleType; + + /// The substitute tuple type. Set once during construction. + CanTupleType substTupleType; + + /// The number of orig elements to traverse. Set once during + /// construction. + unsigned numOrigElts; + + /// The index of the current orig element. + /// Incremented during advance(). + unsigned origEltIndex = 0; + + /// The (start) index of the current subst elements. + /// Incremented during advance(). + unsigned substEltIndex = 0; + + /// The number of subst elements corresponding to the current + /// orig element. + unsigned numSubstEltsForOrigElt; + + /// Whether the orig tuple type is opaque, i.e. does not permit us to + /// call getNumTupleElements() and similar accessors. Set once during + /// construction. + bool origTupleTypeIsOpaque; + + /// Whether the current orig element is a pack expansion. + bool origEltIsExpansion; + + /// The abstraction pattern of the current orig element. + /// If it is a pack expansion, this is the expansion type, not the + /// pattern type. + AbstractionPattern origEltType = AbstractionPattern::getInvalid(); + + /// Load the informaton for the current orig element into the + /// fields above for it. + void loadElement() { + origEltType = origTupleType.getTupleElementType(origEltIndex); + origEltIsExpansion = origEltType.isPackExpansion(); + numSubstEltsForOrigElt = + (origEltIsExpansion + ? origEltType.getNumPackExpandedComponents() + : 1); + } + +public: + TupleElementGenerator(AbstractionPattern origTupleType, + CanTupleType substTupleType); + + /// Is the traversal finished? If so, none of the getters below + /// are allowed to be called. + bool isFinished() const { + return origEltIndex == numOrigElts; + } + + /// Advance to the next orig element. + void advance() { + assert(!isFinished()); + origEltIndex++; + substEltIndex += numSubstEltsForOrigElt; + if (!isFinished()) loadElement(); + } + + /// Return the index of the current orig element. + unsigned getOrigIndex() const { + assert(!isFinished()); + return origEltIndex; + } + + /// Return the index of the (first) subst element corresponding + /// to the current orig element. + unsigned getSubstIndex() const { + assert(!isFinished()); + return origEltIndex; + } + + /// Return a tuple element for the current orig element. + TupleTypeElt getOrigElement() const { + assert(!isFinished()); + return (origTupleTypeIsOpaque + ? substTupleType->getElement(substEltIndex) + : cast(origTupleType.getType()) + ->getElement(origEltIndex)); + } + + /// Return the type of the current orig element. + const AbstractionPattern &getOrigType() const { + assert(!isFinished()); + return origEltType; + } + + /// Return whether the current orig element type is a pack expansion. + bool isOrigPackExpansion() const { + assert(!isFinished()); + return origEltIsExpansion; + } + + /// Return the substituted elements corresponding to the current + /// orig element type. If the current orig element is not a + /// pack expansion, this will have exactly one element. + CanTupleEltTypeArrayRef getSubstTypes() const { + assert(!isFinished()); + return substTupleType.getElementTypes().slice(substEltIndex, + numSubstEltsForOrigElt); + } + + /// Call this to finalize the traversal and assert that it was done + /// properly. + void finish() { + assert(isFinished() && "didn't finish the traversal"); + assert(substEltIndex == substTupleType->getNumElements() && + "didn't exhaust subst elements; possible missing subs on " + "orig tuple type"); + } +}; + } // end namespace Lowering } // end namespace swift diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 6fc6d2d34d9a1..6f8e375959366 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -470,35 +470,25 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const { } void AbstractionPattern::forEachTupleElement(CanTupleType substType, - llvm::function_ref - handleScalar, - llvm::function_ref - handleExpansion) const { - assert(isTuple() && "can only call on a tuple expansion"); - assert(matchesTuple(substType)); - - size_t substEltIndex = 0; - auto substEltTypes = substType.getElementTypes(); - for (size_t origEltIndex : range(getNumTupleElements())) { - auto origEltType = getTupleElementType(origEltIndex); - if (!origEltType.isPackExpansion()) { - handleScalar(origEltIndex, substEltIndex, - origEltType, substEltTypes[substEltIndex]); - substEltIndex++; - } else { - auto numComponents = origEltType.getNumPackExpandedComponents(); - handleExpansion(origEltIndex, substEltIndex, origEltType, - substEltTypes.slice(substEltIndex, numComponents)); - substEltIndex += numComponents; - } + llvm::function_ref handleElement) const { + TupleElementGenerator elt(*this, substType); + for (; !elt.isFinished(); elt.advance()) { + handleElement(elt); } - assert(substEltIndex == substEltTypes.size()); + elt.finish(); +} + +TupleElementGenerator::TupleElementGenerator( + AbstractionPattern origTupleType, + CanTupleType substTupleType) + : origTupleType(origTupleType), substTupleType(substTupleType) { + assert(origTupleType.isTuple()); + assert(origTupleType.matchesTuple(substTupleType)); + + origTupleTypeIsOpaque = origTupleType.isOpaqueTuple(); + numOrigElts = origTupleType.getNumTupleElements(); + + if (!isFinished()) loadElement(); } void AbstractionPattern::forEachExpandedTupleElement(CanTupleType substType, @@ -2196,28 +2186,19 @@ class SubstFunctionTypePatternVisitor CanType visitTupleType(CanTupleType tuple, AbstractionPattern pattern) { assert(pattern.isTuple()); - // It's pretty weird for us to end up in this case with an - // open-coded tuple pattern, but it happens with opaque derivative - // functions in autodiff. - CanTupleType origTupleTypeForLabels = pattern.getAs(); - if (!origTupleTypeForLabels) origTupleTypeForLabels = tuple; - SmallVector tupleElts; - pattern.forEachTupleElement(tuple, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, CanType substEltType) { - auto eltTy = visit(substEltType, origEltType); - auto &origElt = origTupleTypeForLabels->getElement(origEltIndex); - tupleElts.push_back(origElt.getWithType(eltTy)); - }, [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { - CanType candidateSubstType; - if (!substEltTypes.empty()) - candidateSubstType = substEltTypes[0]; - auto eltTy = handlePackExpansion(origExpansionType, candidateSubstType); - auto &origElt = origTupleTypeForLabels->getElement(origEltIndex); - tupleElts.push_back(origElt.getWithType(eltTy)); + pattern.forEachTupleElement(tuple, [&](TupleElementGenerator &elt) { + auto substEltTypes = elt.getSubstTypes(); + CanType eltTy; + if (!elt.isOrigPackExpansion()) { + eltTy = visit(substEltTypes[0], elt.getOrigType()); + } else { + CanType candidateSubstType; + if (!substEltTypes.empty()) + candidateSubstType = substEltTypes[0]; + eltTy = handlePackExpansion(elt.getOrigType(), candidateSubstType); + } + tupleElts.push_back(elt.getOrigElement().getWithType(eltTy)); }); return CanType(TupleType::get(tupleElts, TC.Context)); diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index bea349cc85f6f..749b4cff0ed09 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -1274,21 +1274,21 @@ class DestructureResults { if (origType.isTuple()) { auto substTupleType = cast(substType); origType.forEachTupleElement(substTupleType, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, CanType substEltType) { + [&](TupleElementGenerator &elt) { // If the original element type is not a pack expansion, just // pull off the next substituted element type. - destructure(origEltType, substEltType); + if (!elt.isOrigPackExpansion()) { + destructure(elt.getOrigType(), elt.getSubstTypes()[0]); + return; + } - }, [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { // If the original element type is a pack expansion, build a // lowered pack type for the substituted components it expands to. + auto origExpansionType = elt.getOrigType(); bool indirect = origExpansionType.arePackElementsPassedIndirectly(TC); SmallVector packElts; - for (auto substEltType : substEltTypes) { + for (auto substEltType : elt.getSubstTypes()) { auto origComponentType = origExpansionType.getPackExpansionComponentType(substEltType); CanType loweredEltTy = @@ -1690,16 +1690,17 @@ class DestructureInputs { assert(ownership != ValueOwnership::InOut); assert(origType.isTuple()); - origType.forEachTupleElement(substType, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, CanType substEltType) { - visit(ownership, forSelf, origEltType, substEltType, - isNonDifferentiable); - }, [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { + origType.forEachTupleElement(substType, [&](TupleElementGenerator &elt) { + if (!elt.isOrigPackExpansion()) { + visit(ownership, forSelf, elt.getOrigType(), elt.getSubstTypes()[0], + isNonDifferentiable); + return; + } + + auto origExpansionType = elt.getOrigType(); + SmallVector packElts; - for (auto substEltType : substEltTypes) { + for (auto substEltType : elt.getSubstTypes()) { auto origComponentType = origExpansionType.getPackExpansionComponentType(substEltType); auto loweredEltTy = diff --git a/lib/SILGen/ResultPlan.cpp b/lib/SILGen/ResultPlan.cpp index 5f0e10b2ea774..99875cfc01f7c 100644 --- a/lib/SILGen/ResultPlan.cpp +++ b/lib/SILGen/ResultPlan.cpp @@ -18,6 +18,7 @@ #include "RValue.h" #include "SILGenFunction.h" #include "swift/AST/GenericEnvironment.h" +#include "swift/SIL/AbstractionPatternGenerators.h" using namespace swift; using namespace Lowering; @@ -617,19 +618,20 @@ class TupleInitializationResultPlan final : public ResultPlan { eltPlans.reserve(origType.getNumTupleElements()); origType.forEachTupleElement(substType, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, - CanType substEltType) { - Initialization *eltInit = eltInits[substEltIndex].get(); - eltPlans.push_back(builder.build(eltInit, origEltType, substEltType)); - }, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { - auto componentInits = eltInits.slice(substEltIndex, substEltTypes.size()); - eltPlans.push_back(builder.buildForPackExpansion(componentInits, - origExpansionType, - substEltTypes)); + [&](TupleElementGenerator &elt) { + auto origEltType = elt.getOrigType(); + auto substEltTypes = elt.getSubstTypes(); + if (!elt.isOrigPackExpansion()) { + Initialization *eltInit = eltInits[elt.getSubstIndex()].get(); + eltPlans.push_back(builder.build(eltInit, origEltType, + substEltTypes[0])); + } else { + auto componentInits = + eltInits.slice(elt.getSubstIndex(), substEltTypes.size()); + eltPlans.push_back(builder.buildForPackExpansion(componentInits, + origEltType, + substEltTypes)); + } }); } diff --git a/lib/SILGen/SILGenApply.cpp b/lib/SILGen/SILGenApply.cpp index 5fc3d2caa5ed5..cff8bf88d5e04 100644 --- a/lib/SILGen/SILGenApply.cpp +++ b/lib/SILGen/SILGenApply.cpp @@ -42,6 +42,7 @@ #include "swift/Basic/SourceManager.h" #include "swift/Basic/Unicode.h" #include "swift/SIL/PrettyStackTrace.h" +#include "swift/SIL/AbstractionPatternGenerators.h" #include "swift/SIL/SILArgument.h" #include "clang/AST/DeclCXX.h" #include "clang/AST/DeclObjC.h" @@ -2203,19 +2204,16 @@ static unsigned getFlattenedValueCount(AbstractionPattern origType, // Otherwise, add up the elements. unsigned count = 0; - origType.forEachTupleElement(substTuple, - [&](unsigned origEltIndex, - unsigned substEltIndex, - AbstractionPattern origEltType, - CanType substEltType) { - // Recursively expand scalar components. - count += getFlattenedValueCount(origEltType, substEltType); - }, [&](unsigned origEltIndex, - unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { + origType.forEachTupleElement(substTuple, [&](TupleElementGenerator &elt) { // Expansion components turn into a single parameter. - count++; + if (elt.isOrigPackExpansion()) { + count++; + + // Recursively expand scalar components. + } else { + count += getFlattenedValueCount(elt.getOrigType(), + elt.getSubstTypes()[0]); + } }); return count; } @@ -3392,19 +3390,19 @@ class ArgEmitter { auto substTupleType = cast(e->getType()->getCanonicalType()); origParamType.forEachTupleElement(substTupleType, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, CanType substEltType) { - emit(tuple->getElement(substEltIndex), origEltType); - }, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { + [&](TupleElementGenerator &elt) { + if (!elt.isOrigPackExpansion()) { + emit(tuple->getElement(elt.getSubstIndex()), elt.getOrigType()); + return; + } + + auto substEltTypes = elt.getSubstTypes(); SmallVector eltArgs; eltArgs.reserve(substEltTypes.size()); - for (auto i : range(substEltIndex, substEltTypes.size())) { + for (auto i : range(elt.getSubstIndex(), substEltTypes.size())) { eltArgs.emplace_back(tuple->getElement(i)); } - emitPackArg(eltArgs, origExpansionType); + emitPackArg(eltArgs, elt.getOrigType()); }); return; } diff --git a/lib/SILGen/SILGenProlog.cpp b/lib/SILGen/SILGenProlog.cpp index 34b3da9674438..e10791acb7598 100644 --- a/lib/SILGen/SILGenProlog.cpp +++ b/lib/SILGen/SILGenProlog.cpp @@ -361,24 +361,24 @@ class EmitBBArguments : public CanTypeVisitor elements; - orig.forEachTupleElement(t, - [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origEltType, - CanType substEltType) { - auto elt = visit(substEltType, origEltType, - init ? eltInits[substEltIndex].get() : nullptr); - assert((init != nullptr) == (elt.isInContext())); - if (!elt.isInContext()) - elements.push_back(elt); - - if (elt.hasCleanup()) - canBeGuaranteed = false; - }, [&](unsigned origEltIndex, unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { - assert(init); - expandPack(origExpansionType, substEltTypes, substEltIndex, - eltInits, elements); + orig.forEachTupleElement(t, [&](TupleElementGenerator &elt) { + auto origEltType = elt.getOrigType(); + auto substEltTypes = elt.getSubstTypes(); + if (!elt.isOrigPackExpansion()) { + auto eltValue = + visit(substEltTypes[0], origEltType, + init ? eltInits[elt.getSubstIndex()].get() : nullptr); + assert((init != nullptr) == (eltValue.isInContext())); + if (!eltValue.isInContext()) + elements.push_back(eltValue); + + if (eltValue.hasCleanup()) + canBeGuaranteed = false; + } else { + assert(init); + expandPack(origEltType, substEltTypes, elt.getSubstIndex(), + eltInits, elements); + } }); // If we emitted into a context, we're done. diff --git a/lib/SILGen/SILGenStmt.cpp b/lib/SILGen/SILGenStmt.cpp index 49d0ecb27bfc3..69c9cc0ad6393 100644 --- a/lib/SILGen/SILGenStmt.cpp +++ b/lib/SILGen/SILGenStmt.cpp @@ -24,6 +24,7 @@ #include "swift/AST/DiagnosticsSIL.h" #include "swift/Basic/ProfileCounter.h" #include "swift/SIL/BasicBlockUtils.h" +#include "swift/SIL/AbstractionPatternGenerators.h" #include "swift/SIL/SILArgument.h" #include "llvm/Support/SaveAndRestore.h" @@ -581,31 +582,28 @@ prepareIndirectResultInit(SILGenFunction &SGF, SILLocation loc, tupleInit->SubInitializations.reserve(resultTupleType->getNumElements()); origResultType.forEachTupleElement(resultTupleType, - [&](unsigned origEltIndex, - unsigned substEltIndex, - AbstractionPattern origEltType, - CanType substEltType) { - auto eltInit = prepareIndirectResultInit(SGF, loc, fnTypeForResults, - origEltType, substEltType, - allResults, - directResults, - indirectResultAddrs, cleanups); - tupleInit->SubInitializations.push_back(std::move(eltInit)); - }, - [&](unsigned origEltIndex, - unsigned substEltIndex, - AbstractionPattern origExpansionType, - CanTupleEltTypeArrayRef substEltTypes) { - assert(allResults[0].isPack()); - assert(SGF.silConv.isSILIndirect(allResults[0])); - allResults = allResults.slice(1); - - auto packAddr = indirectResultAddrs[0]; - indirectResultAddrs = indirectResultAddrs.slice(1); - - preparePackResultInit(SGF, loc, origExpansionType, substEltTypes, - packAddr, - cleanups, tupleInit->SubInitializations); + [&](TupleElementGenerator &elt) { + if (!elt.isOrigPackExpansion()) { + auto eltInit = prepareIndirectResultInit(SGF, loc, fnTypeForResults, + elt.getOrigType(), + elt.getSubstTypes()[0], + allResults, + directResults, + indirectResultAddrs, + cleanups); + tupleInit->SubInitializations.push_back(std::move(eltInit)); + } else { + assert(allResults[0].isPack()); + assert(SGF.silConv.isSILIndirect(allResults[0])); + allResults = allResults.slice(1); + + auto packAddr = indirectResultAddrs[0]; + indirectResultAddrs = indirectResultAddrs.slice(1); + + preparePackResultInit(SGF, loc, elt.getOrigType(), elt.getSubstTypes(), + packAddr, + cleanups, tupleInit->SubInitializations); + } }); return InitializationPtr(tupleInit); From eaec51122f20188e7349880e322bb7cffb8a19f6 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 22:02:01 -0400 Subject: [PATCH 10/30] [NFC] Rename Generator::getCurrent() to Generator::get() Just for brevity's sake. --- include/swift/Basic/Generators.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/include/swift/Basic/Generators.h b/include/swift/Basic/Generators.h index df6051162b926..748711b0eec7f 100644 --- a/include/swift/Basic/Generators.h +++ b/include/swift/Basic/Generators.h @@ -31,6 +31,10 @@ // concept SimpleGenerator : Generator { // type reference; // +// // Get the current value. +// reference get(); +// +// // Get the current value and then advance the generator. // reference claimNext(); // } // @@ -103,7 +107,7 @@ class ArrayRefGenerator { } /// Return the current element of the array. - reference getCurrent() const { + reference get() const { assert(!isFinished()); return values.front(); } @@ -111,7 +115,7 @@ class ArrayRefGenerator { /// Claim the current element of the array and advance past it. reference claimNext() { assert(!isFinished()); - reference result = getCurrent(); + reference result = get(); advance(); return result; } From 239bcd8c004bed8b8a4fbbe3694aabe1ba2b9929 Mon Sep 17 00:00:00 2001 From: John McCall Date: Wed, 22 Mar 2023 22:03:18 -0400 Subject: [PATCH 11/30] Fix memberwise initializers for structs with variadic-tuple fields --- lib/SILGen/SILGenConstructor.cpp | 190 ++++++++++++++++++++-- test/SILGen/variadic-generic-tuples.swift | 38 ++++- 2 files changed, 208 insertions(+), 20 deletions(-) diff --git a/lib/SILGen/SILGenConstructor.cpp b/lib/SILGen/SILGenConstructor.cpp index 2c7a81ea65dea..aa3210cdc4f6c 100644 --- a/lib/SILGen/SILGenConstructor.cpp +++ b/lib/SILGen/SILGenConstructor.cpp @@ -25,6 +25,7 @@ #include "swift/AST/ParameterList.h" #include "swift/AST/PropertyWrappers.h" #include "swift/Basic/Defer.h" +#include "swift/Basic/Generators.h" #include "swift/SIL/SILArgument.h" #include "swift/SIL/SILInstruction.h" #include "swift/SIL/SILUndef.h" @@ -33,6 +34,61 @@ using namespace swift; using namespace Lowering; +namespace { + +class LoweredParamsInContextGenerator { + SILGenFunction &SGF; + ArrayRefGenerator> loweredParams; + +public: + LoweredParamsInContextGenerator(SILGenFunction &SGF) + : SGF(SGF), + loweredParams(SGF.F.getLoweredFunctionType()->getParameters()) { + } + + using reference = SILType; + + /// Get the original (unsubstituted into context) lowered parameter + /// type information. + SILParameterInfo getOrigInfo() const { + return loweredParams.get(); + } + + SILType get() const { + return SGF.getSILTypeInContext(loweredParams.get(), + SGF.F.getLoweredFunctionType()); + } + + SILType claimNext() { + auto param = get(); + advance(); + return param; + } + + bool isFinished() const { + return loweredParams.isFinished(); + } + + void advance() { + loweredParams.advance(); + } + + void finish() { + loweredParams.finish(); + } +}; + +} // end anonymous namespace + +static ManagedValue emitManagedParameter(SILGenFunction &SGF, + SILValue value, bool isOwned) { + if (isOwned) { + return SGF.emitManagedRValueWithCleanup(value); + } else { + return ManagedValue::forUnmanaged(value); + } +} + static SILValue emitConstructorMetatypeArg(SILGenFunction &SGF, ValueDecl *ctor) { // In addition to the declared arguments, the constructor implicitly takes @@ -63,14 +119,65 @@ static SILValue emitConstructorMetatypeArg(SILGenFunction &SGF, static RValue emitImplicitValueConstructorArg(SILGenFunction &SGF, SILLocation loc, CanType interfaceType, - DeclContext *DC) { + DeclContext *DC, + LoweredParamsInContextGenerator &loweredParamTypes, + Initialization *argInit = nullptr) { auto type = DC->mapTypeIntoContext(interfaceType)->getCanonicalType(); // Restructure tuple arguments. - if (auto tupleTy = dyn_cast(interfaceType)) { + if (auto tupleIfaceTy = dyn_cast(interfaceType)) { + // If we don't have a context to emit into, but we have a tuple + // that contains pack expansions, create a temporary. + TemporaryInitializationPtr tempInit; + if (!argInit && tupleIfaceTy.containsPackExpansionType()) { + tempInit = SGF.emitTemporary(loc, SGF.getTypeLowering(type)); + argInit = tempInit.get(); + } + + // Split the initialization into element initializations if we have + // one. We should never have to deal with an initialization that + // can't be split here. + assert(!argInit || argInit->canSplitIntoTupleElements()); + SmallVector initsBuf; + MutableArrayRef eltInits; + if (argInit) { + eltInits = argInit->splitIntoTupleElements(SGF, loc, type, initsBuf); + assert(eltInits.size() == tupleIfaceTy->getNumElements()); + } + RValue tuple(type); - for (auto fieldType : tupleTy.getElementTypes()) - tuple.addElement(emitImplicitValueConstructorArg(SGF, loc, fieldType, DC)); + + for (auto eltIndex : range(tupleIfaceTy->getNumElements())) { + auto eltIfaceType = tupleIfaceTy.getElementType(eltIndex); + auto eltInit = (argInit ? eltInits[eltIndex].get() : nullptr); + RValue element = emitImplicitValueConstructorArg(SGF, loc, eltIfaceType, + DC, loweredParamTypes, + eltInit); + if (argInit) { + assert(element.isInContext()); + } else { + tuple.addElement(std::move(element)); + } + } + + // If we created a temporary initializer above, finish it and claim + // the managed buffer. + if (tempInit) { + tempInit->finishInitialization(SGF); + + auto tupleValue = tempInit->getManagedAddress(); + if (tupleValue.getType().isLoadable(SGF.F)) { + tupleValue = SGF.B.createLoadTake(loc, tupleValue); + } + + return RValue(SGF, loc, type, tupleValue); + + // Otherwise, if we have an emitInto, return forInContext(). + } else if (argInit) { + argInit->finishInitialization(SGF); + return RValue::forInContext(); + } + return tuple; } @@ -83,13 +190,51 @@ static RValue emitImplicitValueConstructorArg(SILGenFunction &SGF, VD->setSpecifier(ParamSpecifier::Default); VD->setInterfaceType(interfaceType); - auto argType = SGF.getLoweredTypeForFunctionArgument(type); + auto origParamInfo = loweredParamTypes.getOrigInfo(); + auto argType = loweredParamTypes.claimNext(); + auto *arg = SGF.F.begin()->createFunctionArgument(argType, VD); - ManagedValue mvArg; - if (arg->getArgumentConvention().isOwnedConvention()) { - mvArg = SGF.emitManagedRValueWithCleanup(arg); - } else { - mvArg = ManagedValue::forUnmanaged(arg); + bool argIsConsumed = origParamInfo.isConsumed(); + + // If the lowered parameter is a pack expansion, copy/move the pack + // into the initialization, which we assume is there. + if (auto packTy = argType.getAs()) { + assert(isa(interfaceType)); + assert(packTy->getNumElements() == 1); + assert(argInit); + assert(argInit->canPerformPackExpansionInitialization()); + + auto expansionTy = packTy->getSILElementType(0); + auto openedEnvAndEltTy = + SGF.createOpenedElementValueEnvironment(expansionTy); + auto openedEnv = openedEnvAndEltTy.first; + auto eltTy = openedEnvAndEltTy.second; + auto formalPackType = CanPackType::get(SGF.getASTContext(), {type}); + + SGF.emitDynamicPackLoop(loc, formalPackType, /*component*/0, openedEnv, + [&](SILValue indexWithinComponent, + SILValue packExpansionIndex, + SILValue packIndex) { + argInit->performPackExpansionInitialization(SGF, loc, + indexWithinComponent, + [&](Initialization *eltInit) { + auto eltAddr = + SGF.B.createPackElementGet(loc, packIndex, arg, eltTy); + ManagedValue eltMV = emitManagedParameter(SGF, eltAddr, argIsConsumed); + eltInit->copyOrInitValueInto(SGF, loc, eltMV, argIsConsumed); + eltInit->finishInitialization(SGF); + }); + }); + argInit->finishInitialization(SGF); + return RValue::forInContext(); + } + + ManagedValue mvArg = emitManagedParameter(SGF, arg, argIsConsumed); + + if (argInit) { + argInit->copyOrInitValueInto(SGF, loc, mvArg, argIsConsumed); + argInit->finishInitialization(SGF); + return RValue::forInContext(); } // This can happen if the value is resilient in the calling convention @@ -164,15 +309,19 @@ static void emitImplicitValueConstructor(SILGenFunction &SGF, AssertingManualScope functionLevelScope(SGF.Cleanups, CleanupLocation(Loc)); + auto loweredFunctionTy = SGF.F.getLoweredFunctionType(); + // FIXME: Handle 'self' along with the other arguments. + assert(loweredFunctionTy->getNumResults() == 1); + auto selfResultInfo = loweredFunctionTy->getResults()[0]; auto *paramList = ctor->getParameters(); auto *selfDecl = ctor->getImplicitSelfDecl(); auto selfIfaceTy = selfDecl->getInterfaceType(); - SILType selfTy = SGF.getLoweredTypeForFunctionArgument(selfDecl->getType()); + SILType selfTy = SGF.getSILTypeInContext(selfResultInfo, loweredFunctionTy); // Emit the indirect return argument, if any. SILValue resultSlot; - if (SILModuleConventions::isReturnedIndirectlyInSIL(selfTy, SGF.SGM.M)) { + if (selfTy.isAddress()) { auto &AC = SGF.getASTContext(); auto VD = new (AC) ParamDecl(SourceLoc(), SourceLoc(), AC.getIdentifier("$return_value"), @@ -181,10 +330,11 @@ static void emitImplicitValueConstructor(SILGenFunction &SGF, ctor); VD->setSpecifier(ParamSpecifier::InOut); VD->setInterfaceType(selfIfaceTy); - resultSlot = - SGF.F.begin()->createFunctionArgument(selfTy.getAddressType(), VD); + resultSlot = SGF.F.begin()->createFunctionArgument(selfTy, VD); } + LoweredParamsInContextGenerator loweredParams(SGF); + // Emit the elementwise arguments. SmallVector elements; for (size_t i = 0, size = paramList->size(); i < size; ++i) { @@ -192,10 +342,13 @@ static void emitImplicitValueConstructor(SILGenFunction &SGF, elements.push_back( emitImplicitValueConstructorArg( - SGF, Loc, param->getInterfaceType()->getCanonicalType(), ctor)); + SGF, Loc, param->getInterfaceType()->getCanonicalType(), ctor, + loweredParams)); } emitConstructorMetatypeArg(SGF, ctor); + (void) loweredParams.claimNext(); + loweredParams.finish(); auto *decl = selfTy.getStructOrBoundGenericStruct(); assert(decl && "not a struct?!"); @@ -601,16 +754,21 @@ void SILGenFunction::emitEnumConstructor(EnumElementDecl *element) { Scope scope(Cleanups, CleanupLoc); + LoweredParamsInContextGenerator loweredParams(*this); + // Emit the exploded constructor argument. ArgumentSource payload; if (element->hasAssociatedValues()) { auto eltArgTy = element->getArgumentInterfaceType()->getCanonicalType(); - RValue arg = emitImplicitValueConstructorArg(*this, Loc, eltArgTy, element); + RValue arg = emitImplicitValueConstructorArg(*this, Loc, eltArgTy, element, + loweredParams); payload = ArgumentSource(Loc, std::move(arg)); } // Emit the metatype argument. emitConstructorMetatypeArg(*this, element); + (void) loweredParams.claimNext(); + loweredParams.finish(); // If possible, emit the enum directly into the indirect return. SGFContext C = (dest ? SGFContext(dest.get()) : SGFContext()); diff --git a/test/SILGen/variadic-generic-tuples.swift b/test/SILGen/variadic-generic-tuples.swift index fbcfbe6ec51e7..b1585c5804b98 100644 --- a/test/SILGen/variadic-generic-tuples.swift +++ b/test/SILGen/variadic-generic-tuples.swift @@ -248,8 +248,38 @@ func takesConcreteTupleHolderFactory(factory: () -> TupleHolder) { let holder = factory() } -/* We still crash with memberwise initializers -func generateConcreteMemberTuple() -> TupleHolder { - return HasMemberTuple(content: (0, "hello")) +struct MemberwiseTupleHolder { + var content: (repeat each T) +} + +// Memberwise initializer. +// TODO: initialize directly into the fields +// CHECK-LABEL: sil{{.*}} @$s4main21MemberwiseTupleHolderV7contentACyxxQp_QPGxxQp_t_tcfC +// CHECK-SAME: $@convention(method) (@pack_owned Pack{repeat each T}, @thin MemberwiseTupleHolder.Type) -> @out MemberwiseTupleHolder { +// CHECK: [[TEMP:%.*]] = alloc_stack $(repeat each T) +// CHECK-NEXT: [[ZERO:%.*]] = integer_literal $Builtin.Word, 0 +// CHECK-NEXT: [[ONE:%.*]] = integer_literal $Builtin.Word, 1 +// CHECK-NEXT: [[LEN:%.*]] = pack_length $Pack{repeat each T} +// CHECK-NEXT: br bb1([[ZERO]] : $Builtin.Word) +// CHECK: bb1([[IDX:%.*]] : $Builtin.Word) +// CHECK-NEXT: [[IDX_EQ_LEN:%.*]] = builtin "cmp_eq_Word"([[IDX]] : $Builtin.Word, [[LEN]] : $Builtin.Word) : $Builtin.Int1 +// CHECK-NEXT: cond_br [[IDX_EQ_LEN]], bb3, bb2 +// CHECK: bb2: +// CHECK-NEXT: [[INDEX:%.*]] = dynamic_pack_index [[IDX]] of $Pack{repeat each T} +// CHECK-NEXT: open_pack_element [[INDEX]] of at , shape $T, uuid [[UUID:".*"]] +// CHECK-NEXT: [[TUPLE_ELT_ADDR:%.*]] = tuple_pack_element_addr [[INDEX]] of [[TEMP]] : $*(repeat each T) as $*@pack_element([[UUID]]) T +// CHECK-NEXT: [[PACK_ELT_ADDR:%.*]] = pack_element_get [[INDEX]] of %1 : $*Pack{repeat each T} as $*@pack_element([[UUID]]) T +// CHECK-NEXT: copy_addr [take] [[PACK_ELT_ADDR]] to [init] [[TUPLE_ELT_ADDR]] +// CHECK-NEXT: [[NEXT_IDX:%.*]] = builtin "add_Word"([[IDX]] : $Builtin.Word, [[ONE]] : $Builtin.Word) : $Builtin.Word +// CHECK-NEXT: br bb1([[NEXT_IDX]] : $Builtin.Word) +// CHECK: bb3: +// CHECK-NEXT: [[CONTENTS_ADDR:%.*]] = struct_element_addr %0 : $*MemberwiseTupleHolder, #MemberwiseTupleHolder.content +// CHECK-NEXT: copy_addr [take] [[TEMP]] to [init] [[CONTENTS_ADDR]] +// CHECK-NEXT: tuple () +// CHECK-NEXT: dealloc_stack [[TEMP]] +// CHECK-NEXT: return + + +func callVariadicMemberwiseInit() -> MemberwiseTupleHolder { + return MemberwiseTupleHolder(content: (0, "hello")) } - */ From dd9ae1d2927ea1879099725f7e11da374aa3076d Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 23 Mar 2023 12:06:38 -0400 Subject: [PATCH 12/30] [NFC] Thread a common type through all the AST substitution code. There are a lot of problems caused by our highly-abstract substitution subsystem. Most of them would be solved by a more semantic / holistic understanding of the active transformation, but that's difficult to do because we just pass around function_refs. The first step in fixing that is to pass around a better currency type. For now, it can just hold the function_refs (and the SubstOptions). I've set it up so that the places that just apply SubstitutionMaps are constructing the IFS in a standard way; that should make it easy to change those places in the future. --- include/swift/AST/InFlightSubstitution.h | 88 +++++++++++++++++ include/swift/AST/PackConformance.h | 7 ++ include/swift/AST/ProtocolConformance.h | 7 ++ include/swift/AST/ProtocolConformanceRef.h | 7 ++ include/swift/AST/SubstitutionMap.h | 15 +++ include/swift/AST/Type.h | 7 ++ lib/AST/PackConformance.cpp | 50 +++++----- lib/AST/ProtocolConformance.cpp | 30 +++--- lib/AST/ProtocolConformanceRef.cpp | 24 +++-- lib/AST/SubstitutionMap.cpp | 35 ++++--- lib/AST/Type.cpp | 106 ++++++++++----------- 11 files changed, 258 insertions(+), 118 deletions(-) create mode 100644 include/swift/AST/InFlightSubstitution.h diff --git a/include/swift/AST/InFlightSubstitution.h b/include/swift/AST/InFlightSubstitution.h new file mode 100644 index 0000000000000..a137929d4c9d9 --- /dev/null +++ b/include/swift/AST/InFlightSubstitution.h @@ -0,0 +1,88 @@ +//===--- InFlightSubstitution.h - In-flight substitution data ---*- C++ -*-===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file defines the InFlightSubstitution structure, which captures +// all the information about a type substitution that's currently in +// progress. For now, this is meant to be an internal implementation +// detail of the substitution system, and other systems should not use +// it (unless they are part of the extended substitution system, such as +// the SIL type substituter) +// +//===----------------------------------------------------------------------===// + +#ifndef SWIFT_AST_INFLIGHTSUBSTITUTION_H +#define SWIFT_AST_INFLIGHTSUBSTITUTION_H + +#include "swift/AST/SubstitutionMap.h" + +namespace swift { +class SubstitutionMap; + +class InFlightSubstitution { + SubstOptions Options; + TypeSubstitutionFn BaselineSubstType; + LookupConformanceFn BaselineLookupConformance; + +public: + InFlightSubstitution(TypeSubstitutionFn substType, + LookupConformanceFn lookupConformance, + SubstOptions options) + : Options(options), + BaselineSubstType(substType), + BaselineLookupConformance(lookupConformance) {} + + InFlightSubstitution(const InFlightSubstitution &) = delete; + InFlightSubstitution &operator=(const InFlightSubstitution &) = delete; + + Type substType(SubstitutableType *ty) { + return BaselineSubstType(ty); + } + + ProtocolConformanceRef lookupConformance(CanType dependentType, + Type conformingReplacementType, + ProtocolDecl *conformedProtocol) { + return BaselineLookupConformance(dependentType, + conformingReplacementType, + conformedProtocol); + } + + SubstOptions getOptions() const { + return Options; + } + + /// Is the given type invariant to substitution? + bool isInvariant(Type type) const; +}; + +/// A helper classes that provides stable storage for the query +/// functions against a SubstitutionMap. +struct InFlightSubstitutionViaSubMapHelper { + QuerySubstitutionMap QueryType; + LookUpConformanceInSubstitutionMap QueryConformance; + + InFlightSubstitutionViaSubMapHelper(SubstitutionMap subMap) + : QueryType{subMap}, QueryConformance(subMap) {} +}; +class InFlightSubstitutionViaSubMap : + private InFlightSubstitutionViaSubMapHelper, + public InFlightSubstitution { + +public: + InFlightSubstitutionViaSubMap(SubstitutionMap subMap, + SubstOptions options) + : InFlightSubstitutionViaSubMapHelper(subMap), + InFlightSubstitution(QueryType, QueryConformance, options) {} +}; + +} // end namespace swift + +#endif diff --git a/include/swift/AST/PackConformance.h b/include/swift/AST/PackConformance.h index eee457af99ff2..7c722932c3465 100644 --- a/include/swift/AST/PackConformance.h +++ b/include/swift/AST/PackConformance.h @@ -88,6 +88,13 @@ class alignas(1 << DeclAlignInBits) PackConformance final LookupConformanceFn conformances, SubstOptions options=None) const; + /// Apply an in-flight substitution to the conformances in this + /// protocol conformance ref. + /// + /// This function should generally not be used outside of the + /// substitution subsystem. + ProtocolConformanceRef subst(InFlightSubstitution &IFS) const; + SWIFT_DEBUG_DUMP; void dump(llvm::raw_ostream &out, unsigned indent = 0) const; }; diff --git a/include/swift/AST/ProtocolConformance.h b/include/swift/AST/ProtocolConformance.h index dd24b843100b8..6e6348f439c8d 100644 --- a/include/swift/AST/ProtocolConformance.h +++ b/include/swift/AST/ProtocolConformance.h @@ -309,6 +309,13 @@ class alignas(1 << DeclAlignInBits) ProtocolConformance LookupConformanceFn conformances, SubstOptions options=None) const; + /// Substitute the conforming type and produce a ProtocolConformance that + /// applies to the substituted type. + /// + /// This function should generally not be used outside of the substitution + /// subsystem. + ProtocolConformance *subst(InFlightSubstitution &IFS) const; + SWIFT_DEBUG_DUMP; void dump(llvm::raw_ostream &out, unsigned indent = 0) const; }; diff --git a/include/swift/AST/ProtocolConformanceRef.h b/include/swift/AST/ProtocolConformanceRef.h index 67ef66d9f9262..4f1fa2e244bc4 100644 --- a/include/swift/AST/ProtocolConformanceRef.h +++ b/include/swift/AST/ProtocolConformanceRef.h @@ -164,6 +164,13 @@ class ProtocolConformanceRef { LookupConformanceFn conformances, SubstOptions options=None) const; + /// Apply a substitution to the conforming type. + /// + /// This function should generally not be used outside of the substitution + /// subsystem. + ProtocolConformanceRef subst(Type origType, + InFlightSubstitution &IFS) const; + /// Map contextual types to interface types in the conformance. ProtocolConformanceRef mapConformanceOutOfContext() const; diff --git a/include/swift/AST/SubstitutionMap.h b/include/swift/AST/SubstitutionMap.h index 3cd722b316123..3c047451904ba 100644 --- a/include/swift/AST/SubstitutionMap.h +++ b/include/swift/AST/SubstitutionMap.h @@ -121,6 +121,14 @@ class SubstitutionMap { TypeSubstitutionFn subs, LookupConformanceFn lookupConformance); + /// Build a substitution map from the substitutions represented by + /// the given in-flight substitution. + /// + /// This function should generally only be used by the substitution + /// subsystem. + static SubstitutionMap get(GenericSignature genericSig, + InFlightSubstitution &IFS); + /// Retrieve the generic signature describing the environment in which /// substitutions occur. GenericSignature getGenericSignature() const; @@ -182,6 +190,13 @@ class SubstitutionMap { LookupConformanceFn conformances, SubstOptions options=None) const; + /// Apply an in-flight substitution to all replacement types in the map. + /// Does not change keys. + /// + /// This should generally not be used outside of the substitution + /// subsystem. + SubstitutionMap subst(InFlightSubstitution &subs) const; + /// Apply type expansion lowering to all types in the substitution map. Opaque /// archetypes will be lowered to their underlying types if the type expansion /// context allows. diff --git a/include/swift/AST/Type.h b/include/swift/AST/Type.h index 6d59c99e8b48e..7bd6eb586c843 100644 --- a/include/swift/AST/Type.h +++ b/include/swift/AST/Type.h @@ -43,6 +43,7 @@ class ClassDecl; class CanType; class EnumDecl; class GenericSignatureImpl; +class InFlightSubstitution; class ModuleDecl; class NominalTypeDecl; class GenericTypeDecl; @@ -343,6 +344,12 @@ class Type { LookupConformanceFn conformances, SubstOptions options=None) const; + /// Apply an in-flight substitution to this type. + /// + /// This should generally not be used outside of the substitution + /// subsystem. + Type subst(InFlightSubstitution &subs) const; + bool isPrivateStdlibType(bool treatNonBuiltinProtocolsAsPublic = true) const; SWIFT_DEBUG_DUMP; diff --git a/lib/AST/PackConformance.cpp b/lib/AST/PackConformance.cpp index dd634bc643639..e4006e6c8ae7c 100644 --- a/lib/AST/PackConformance.cpp +++ b/lib/AST/PackConformance.cpp @@ -18,6 +18,7 @@ #include "swift/AST/PackConformance.h" #include "swift/AST/ASTContext.h" #include "swift/AST/Decl.h" +#include "swift/AST/InFlightSubstitution.h" #include "swift/AST/Module.h" #include "swift/AST/Types.h" @@ -162,9 +163,8 @@ PackConformance *PackConformance::getAssociatedConformance( ProtocolConformanceRef PackConformance::subst(SubstitutionMap subMap, SubstOptions options) const { - return subst(QuerySubstitutionMap{subMap}, - LookUpConformanceInSubstitutionMap(subMap), - options); + InFlightSubstitutionViaSubMap IFS(subMap, options); + return subst(IFS); } // TODO: Move this elsewhere since it's generally useful @@ -206,14 +206,9 @@ namespace { template class PackExpander { protected: - TypeSubstitutionFn subs; - LookupConformanceFn conformances; - SubstOptions options; + InFlightSubstitution &IFS; - PackExpander(TypeSubstitutionFn subs, - LookupConformanceFn conformances, - SubstOptions options) - : subs(subs), conformances(conformances), options(options) {} + PackExpander(InFlightSubstitution &IFS) : IFS(IFS) {} ImplClass *asImpl() { return static_cast(this); @@ -233,7 +228,7 @@ class PackExpander { // the expanded count pack type. llvm::SmallDenseMap expandedPacks; for (auto origParamType : rootParameterPacks) { - auto substParamType = origParamType.subst(subs, conformances, options); + auto substParamType = origParamType.subst(IFS); if (auto expandedParamType = substParamType->template getAs()) { assert(arePackShapesEqual(expandedParamType, expandedCountType) && @@ -262,7 +257,7 @@ class PackExpander { } // Compute the substituted type using our parent substitutions. - auto substType = Type(type).subst(subs, conformances, options); + auto substType = Type(type).subst(IFS); // If the substituted type is a pack, project the jth element. if (isRootParameterPack(type)) { @@ -277,12 +272,13 @@ class PackExpander { return packType->getElementType(j); } - return subs(type); + return IFS.substType(type); }; auto projectedConformances = [&](CanType origType, Type substType, ProtocolDecl *proto) -> ProtocolConformanceRef { - auto substConformance = conformances(origType, substType, proto); + auto substConformance = + IFS.lookupConformance(origType, substType, proto); // If the substituted conformance is a pack, project the jth element. if (isRootedInParameterPack(origType)) { @@ -294,7 +290,7 @@ class PackExpander { auto origCountElement = expandedCountType->getElementType(j); auto substCountElement = origCountElement.subst( - projectedSubs, projectedConformances, options); + projectedSubs, projectedConformances, IFS.getOptions()); asImpl()->add(origCountElement, substCountElement, i); } @@ -304,7 +300,7 @@ class PackExpander { /// form a new pack expansion. void addUnexpandedExpansion(Type origPatternType, Type substCountType, unsigned i) { - auto substPatternType = origPatternType.subst(subs, conformances, options); + auto substPatternType = origPatternType.subst(IFS); auto substExpansion = PackExpansionType::get(substPatternType, substCountType); asImpl()->add(origPatternType, substExpansion, i); @@ -313,7 +309,7 @@ class PackExpander { /// Scalar elements of the original pack are substituted and added to the /// flattened pack. void addScalar(Type origElement, unsigned i) { - auto substElement = origElement.subst(subs, conformances, options); + auto substElement = origElement.subst(IFS); asImpl()->add(origElement, substElement, i); } @@ -323,7 +319,7 @@ class PackExpander { auto origPatternType = origExpansion->getPatternType(); auto origCountType = origExpansion->getCountType(); - auto substCountType = origCountType.subst(subs, conformances, options); + auto substCountType = origCountType.subst(IFS); // If the substituted count type is a pack, we're expanding the // original element. @@ -358,19 +354,16 @@ class PackConformanceExpander : public PackExpander { ArrayRef origConformances; - PackConformanceExpander(TypeSubstitutionFn subs, - LookupConformanceFn conformances, - SubstOptions options, + PackConformanceExpander(InFlightSubstitution &IFS, ArrayRef origConformances) - : PackExpander(subs, conformances, options), - origConformances(origConformances) {} + : PackExpander(IFS), origConformances(origConformances) {} void add(Type origType, Type substType, unsigned i) { substElements.push_back(substType); // FIXME: Pass down projection callbacks substConformances.push_back(origConformances[i].subst( - origType, subs, conformances, options)); + origType, IFS)); } }; @@ -379,8 +372,13 @@ class PackConformanceExpander : public PackExpander { ProtocolConformanceRef PackConformance::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) const { - PackConformanceExpander expander(subs, conformances, options, - getPatternConformances()); + InFlightSubstitution IFS(subs, conformances, options); + return subst(IFS); +} + +ProtocolConformanceRef +PackConformance::subst(InFlightSubstitution &IFS) const { + PackConformanceExpander expander(IFS, getPatternConformances()); expander.expand(ConformingType); auto &ctx = Protocol->getASTContext(); diff --git a/lib/AST/ProtocolConformance.cpp b/lib/AST/ProtocolConformance.cpp index b8a05c97da57a..bdbbbe1dcbd0c 100644 --- a/lib/AST/ProtocolConformance.cpp +++ b/lib/AST/ProtocolConformance.cpp @@ -22,6 +22,7 @@ #include "swift/AST/DistributedDecl.h" #include "swift/AST/FileUnit.h" #include "swift/AST/GenericEnvironment.h" +#include "swift/AST/InFlightSubstitution.h" #include "swift/AST/LazyResolver.h" #include "swift/AST/Module.h" #include "swift/AST/TypeCheckRequests.h" @@ -920,15 +921,20 @@ bool ProtocolConformance::isVisibleFrom(const DeclContext *dc) const { ProtocolConformance * ProtocolConformance::subst(SubstitutionMap subMap, SubstOptions options) const { - return subst(QuerySubstitutionMap{subMap}, - LookUpConformanceInSubstitutionMap(subMap), - options); + InFlightSubstitutionViaSubMap IFS(subMap, options); + return subst(IFS); } ProtocolConformance * ProtocolConformance::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) const { + InFlightSubstitution IFS(subs, conformances, options); + return subst(IFS); +} + +ProtocolConformance * +ProtocolConformance::subst(InFlightSubstitution &IFS) const { switch (getKind()) { case ProtocolConformanceKind::Normal: { auto origType = getType(); @@ -936,12 +942,11 @@ ProtocolConformance::subst(TypeSubstitutionFn subs, !origType->hasArchetype()) return const_cast(this); - auto substType = origType.subst(subs, conformances, options); + auto substType = origType.subst(IFS); if (substType->isEqual(origType)) return const_cast(this); - auto subMap = SubstitutionMap::get(getGenericSignature(), - subs, conformances); + auto subMap = SubstitutionMap::get(getGenericSignature(), IFS); auto *mutableThis = const_cast(this); return substType->getASTContext() @@ -955,7 +960,7 @@ ProtocolConformance::subst(TypeSubstitutionFn subs, !origType->hasArchetype()) return const_cast(this); - auto substType = origType.subst(subs, conformances, options); + auto substType = origType.subst(IFS); // We do an exact pointer equality check because subst() can // change sugar. @@ -964,7 +969,7 @@ ProtocolConformance::subst(TypeSubstitutionFn subs, SmallVector requirements; for (auto req : getConditionalRequirements()) { - requirements.push_back(req.subst(subs, conformances, options)); + requirements.push_back(req.subst(IFS)); } auto kind = cast(this) @@ -992,11 +997,10 @@ ProtocolConformance::subst(TypeSubstitutionFn subs, if (origBaseType->hasTypeParameter() || origBaseType->hasArchetype()) { // Substitute into the superclass. - inheritedConformance = inheritedConformance->subst(subs, conformances, - options); + inheritedConformance = inheritedConformance->subst(IFS); } - auto substType = origType.subst(subs, conformances, options); + auto substType = origType.subst(IFS); return substType->getASTContext() .getInheritedConformance(substType, inheritedConformance); } @@ -1007,10 +1011,10 @@ ProtocolConformance::subst(TypeSubstitutionFn subs, auto subMap = spec->getSubstitutionMap(); auto origType = getType(); - auto substType = origType.subst(subs, conformances, options); + auto substType = origType.subst(IFS); return substType->getASTContext() .getSpecializedConformance(substType, genericConformance, - subMap.subst(subs, conformances, options)); + subMap.subst(IFS)); } } llvm_unreachable("bad ProtocolConformanceKind"); diff --git a/lib/AST/ProtocolConformanceRef.cpp b/lib/AST/ProtocolConformanceRef.cpp index 109bb40074621..e98bf08a22a21 100644 --- a/lib/AST/ProtocolConformanceRef.cpp +++ b/lib/AST/ProtocolConformanceRef.cpp @@ -19,6 +19,7 @@ #include "swift/AST/ASTContext.h" #include "swift/AST/Availability.h" #include "swift/AST/Decl.h" +#include "swift/AST/InFlightSubstitution.h" #include "swift/AST/Module.h" #include "swift/AST/PackConformance.h" #include "swift/AST/ProtocolConformance.h" @@ -57,10 +58,8 @@ ProtocolConformanceRef ProtocolConformanceRef::subst(Type origType, SubstitutionMap subMap, SubstOptions options) const { - return subst(origType, - QuerySubstitutionMap{subMap}, - LookUpConformanceInSubstitutionMap(subMap), - options); + InFlightSubstitutionViaSubMap IFS(subMap, options); + return subst(origType, IFS); } ProtocolConformanceRef @@ -68,28 +67,33 @@ ProtocolConformanceRef::subst(Type origType, TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) const { + InFlightSubstitution IFS(subs, conformances, options); + return subst(origType, IFS); +} + +ProtocolConformanceRef +ProtocolConformanceRef::subst(Type origType, InFlightSubstitution &IFS) const { if (isInvalid()) return *this; if (isConcrete()) - return ProtocolConformanceRef(getConcrete()->subst(subs, conformances, - options)); + return ProtocolConformanceRef(getConcrete()->subst(IFS)); if (isPack()) - return getPack()->subst(subs, conformances, options); + return getPack()->subst(IFS); // Handle abstract conformances below: // If the type is an opaque archetype, the conformance will remain abstract, // unless we're specifically substituting opaque types. if (auto origArchetype = origType->getAs()) { - if (!options.contains(SubstFlags::SubstituteOpaqueArchetypes) + if (!IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) && isa(origArchetype)) { return *this; } } // Otherwise, compute the substituted type. - auto substType = origType.subst(subs, conformances, options); + auto substType = origType.subst(IFS); auto *proto = getRequirement(); @@ -105,7 +109,7 @@ ProtocolConformanceRef::subst(Type origType, } // Check the conformance map. - return conformances(origType->getCanonicalType(), substType, proto); + return IFS.lookupConformance(origType->getCanonicalType(), substType, proto); } ProtocolConformanceRef ProtocolConformanceRef::mapConformanceOutOfContext() const { diff --git a/lib/AST/SubstitutionMap.cpp b/lib/AST/SubstitutionMap.cpp index a490b51c8522f..d59f1961b5950 100644 --- a/lib/AST/SubstitutionMap.cpp +++ b/lib/AST/SubstitutionMap.cpp @@ -28,6 +28,7 @@ #include "swift/AST/Decl.h" #include "swift/AST/GenericEnvironment.h" #include "swift/AST/GenericParamList.h" +#include "swift/AST/InFlightSubstitution.h" #include "swift/AST/LazyResolver.h" #include "swift/AST/Module.h" #include "swift/AST/ProtocolConformance.h" @@ -209,6 +210,12 @@ SubstitutionMap SubstitutionMap::get(GenericSignature genericSig, SubstitutionMap SubstitutionMap::get(GenericSignature genericSig, TypeSubstitutionFn subs, LookupConformanceFn lookupConformance) { + InFlightSubstitution IFS(subs, lookupConformance, None); + return get(genericSig, IFS); +} + +SubstitutionMap SubstitutionMap::get(GenericSignature genericSig, + InFlightSubstitution &IFS) { if (!genericSig) { return SubstitutionMap(); } @@ -225,7 +232,7 @@ SubstitutionMap SubstitutionMap::get(GenericSignature genericSig, } // Record the replacement. - Type replacement = Type(gp).subst(subs, lookupConformance); + Type replacement = Type(gp).subst(IFS); assert((!replacement || replacement->hasError() || gp->isParameterPack() == replacement->is()) && @@ -240,9 +247,9 @@ SubstitutionMap SubstitutionMap::get(GenericSignature genericSig, if (req.getKind() != RequirementKind::Conformance) continue; CanType depTy = req.getFirstType()->getCanonicalType(); - auto replacement = depTy.subst(subs, lookupConformance); + auto replacement = depTy.subst(IFS); auto *proto = req.getProtocolDecl(); - auto conformance = lookupConformance(depTy, replacement, proto); + auto conformance = IFS.lookupConformance(depTy, replacement, proto); conformances.push_back(conformance); } @@ -440,14 +447,18 @@ SubstitutionMap SubstitutionMap::mapReplacementTypesOutOfContext() const { SubstitutionMap SubstitutionMap::subst(SubstitutionMap subMap, SubstOptions options) const { - return subst(QuerySubstitutionMap{subMap}, - LookUpConformanceInSubstitutionMap(subMap), - options); + InFlightSubstitutionViaSubMap IFS(subMap, options); + return subst(IFS); } SubstitutionMap SubstitutionMap::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, SubstOptions options) const { + InFlightSubstitution IFS(subs, conformances, options); + return subst(IFS); +} + +SubstitutionMap SubstitutionMap::subst(InFlightSubstitution &IFS) const { if (empty()) return SubstitutionMap(); SmallVector newSubs; @@ -457,7 +468,7 @@ SubstitutionMap SubstitutionMap::subst(TypeSubstitutionFn subs, newSubs.push_back(Type()); continue; } - newSubs.push_back(type.subst(subs, conformances, options)); + newSubs.push_back(type.subst(IFS)); assert(type->is() == newSubs.back()->is() && "substitution changed the pack-ness of a replacement type"); } @@ -474,16 +485,14 @@ SubstitutionMap SubstitutionMap::subst(TypeSubstitutionFn subs, // Fast path for concrete case -- we don't need to compute substType // at all. if (conformance.isConcrete() && - !options.contains(SubstFlags::SubstituteOpaqueArchetypes)) { + !IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes)) { newConformances.push_back( - ProtocolConformanceRef( - conformance.getConcrete()->subst(subs, conformances, options))); + ProtocolConformanceRef(conformance.getConcrete()->subst(IFS))); } else { auto origType = req.getFirstType(); - auto substType = origType.subst(*this, options); + auto substType = origType.subst(*this, IFS.getOptions()); - newConformances.push_back( - conformance.subst(substType, subs, conformances, options)); + newConformances.push_back(conformance.subst(substType, IFS)); } oldConformances = oldConformances.slice(1); diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 760191feb73bb..2c965368c254e 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -27,6 +27,7 @@ #include "swift/AST/TypeWalker.h" #include "swift/AST/Decl.h" #include "swift/AST/GenericEnvironment.h" +#include "swift/AST/InFlightSubstitution.h" #include "swift/AST/LazyResolver.h" #include "swift/AST/Module.h" #include "swift/AST/PackConformance.h" @@ -4388,12 +4389,11 @@ CanGenericFunctionType::substGenericArgs(SubstitutionMap subs) const { getPointer()->substGenericArgs(subs)->getCanonicalType()); } -static Type getMemberForBaseType(LookupConformanceFn lookupConformances, +static Type getMemberForBaseType(InFlightSubstitution &IFS, Type origBase, Type substBase, AssociatedTypeDecl *assocType, - Identifier name, - SubstOptions options) { + Identifier name) { // Produce a dependent member type for the given base type. auto getDependentMemberType = [&](Type baseType) { if (assocType) @@ -4441,7 +4441,7 @@ static Type getMemberForBaseType(LookupConformanceFn lookupConformances, auto proto = assocType->getProtocol(); ProtocolConformanceRef conformance = - lookupConformances(origBase->getCanonicalType(), substBase, proto); + IFS.lookupConformance(origBase->getCanonicalType(), substBase, proto); if (conformance.isInvalid()) return failed(); @@ -4456,7 +4456,8 @@ static Type getMemberForBaseType(LookupConformanceFn lookupConformances, assocType->getDeclaredInterfaceType()); } else if (conformance.isConcrete()) { auto witness = - conformance.getConcrete()->getTypeWitnessAndDecl(assocType, options); + conformance.getConcrete()->getTypeWitnessAndDecl(assocType, + IFS.getOptions()); witnessTy = witness.getWitnessType(); if (!witnessTy || witnessTy->hasError()) @@ -4464,7 +4465,7 @@ static Type getMemberForBaseType(LookupConformanceFn lookupConformances, // This is a hacky feature allowing code completion to migrate to // using Type::subst() without changing output. - if (options & SubstFlags::DesugarMemberTypes) { + if (IFS.getOptions() & SubstFlags::DesugarMemberTypes) { if (auto *aliasType = dyn_cast(witnessTy.getPointer())) witnessTy = aliasType->getSinglyDesugaredType(); @@ -4567,8 +4568,9 @@ Type DependentMemberType::substBaseType(Type substBase, substBase->hasTypeParameter()) return this; - return getMemberForBaseType(lookupConformance, getBase(), substBase, - getAssocType(), getName(), None); + InFlightSubstitution IFS(nullptr, lookupConformance, None); + return getMemberForBaseType(IFS, getBase(), substBase, + getAssocType(), getName()); } Type DependentMemberType::substRootParam(Type newRoot, @@ -4585,15 +4587,12 @@ Type DependentMemberType::substRootParam(Type newRoot, } static Type substGenericFunctionType(GenericFunctionType *genericFnType, - TypeSubstitutionFn substitutions, - LookupConformanceFn lookupConformances, - SubstOptions options) { + InFlightSubstitution &IFS) { // Substitute into the function type (without generic signature). auto *bareFnType = FunctionType::get(genericFnType->getParams(), genericFnType->getResult(), genericFnType->getExtInfo()); - Type result = - Type(bareFnType).subst(substitutions, lookupConformances, options); + Type result = Type(bareFnType).subst(IFS); if (!result || result->is()) return result; auto *fnType = result->castTo(); @@ -4601,8 +4600,7 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, bool anySemanticChanges = false; SmallVector genericParams; for (auto param : genericFnType->getGenericParams()) { - Type paramTy = - Type(param).subst(substitutions, lookupConformances, options); + Type paramTy = Type(param).subst(IFS); if (!paramTy) return Type(); @@ -4624,7 +4622,7 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, SmallVector requirements; for (const auto &req : genericFnType->getRequirements()) { // Substitute into the requirement. - auto substReqt = req.subst(substitutions, lookupConformances, options); + auto substReqt = req.subst(IFS); // Did anything change? if (!anySemanticChanges && @@ -4654,27 +4652,27 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, fnType->getResult(), fnType->getExtInfo()); } -static Type substType(Type derivedType, - TypeSubstitutionFn substitutions, - LookupConformanceFn lookupConformances, - SubstOptions options) { +bool InFlightSubstitution::isInvariant(Type derivedType) const { + return !derivedType->hasArchetype() + && !derivedType->hasTypeParameter() + && (!Options.contains(SubstFlags::SubstituteOpaqueArchetypes) + || !derivedType->hasOpaqueArchetype()); +} + +static Type substType(Type derivedType, InFlightSubstitution &IFS) { // Handle substitutions into generic function types. if (auto genericFnType = derivedType->getAs()) { - return substGenericFunctionType(genericFnType, substitutions, - lookupConformances, options); + return substGenericFunctionType(genericFnType, IFS); } // FIXME: Change getTypeOfMember() to not pass GenericFunctionType here - if (!derivedType->hasArchetype() - && !derivedType->hasTypeParameter() - && (!options.contains(SubstFlags::SubstituteOpaqueArchetypes) - || !derivedType->hasOpaqueArchetype())) + if (IFS.isInvariant(derivedType)) return derivedType; return derivedType.transformRec([&](TypeBase *type) -> Optional { // FIXME: Add SIL versions of mapTypeIntoContext() and // mapTypeOutOfContext() and use them appropriately - assert((options.contains(SubstFlags::AllowLoweredTypes) || + assert((IFS.getOptions().contains(SubstFlags::AllowLoweredTypes) || !isa(type)) && "should not be doing AST type-substitution on a lowered SIL type;" "use SILType::subst"); @@ -4683,7 +4681,7 @@ static Type substType(Type derivedType, // we want to structurally substitute the substitutions. if (auto boxTy = dyn_cast(type)) { auto subMap = boxTy->getSubstitutions(); - auto newSubMap = subMap.subst(substitutions, lookupConformances, options); + auto newSubMap = subMap.subst(IFS); return SILBoxType::get(boxTy->getASTContext(), boxTy->getLayout(), @@ -4691,10 +4689,8 @@ static Type substType(Type derivedType, } if (auto packExpansionTy = dyn_cast(type)) { - auto patternTy = substType(packExpansionTy->getPatternType(), - substitutions, lookupConformances, options); - auto countTy = substType(packExpansionTy->getCountType(), - substitutions, lookupConformances, options); + auto patternTy = substType(packExpansionTy->getPatternType(), IFS); + auto countTy = substType(packExpansionTy->getCountType(), IFS); if (auto *archetypeTy = countTy->getAs()) countTy = archetypeTy->getReducedShape(); @@ -4705,11 +4701,11 @@ static Type substType(Type derivedType, if (silFnTy->isPolymorphic()) return None; if (auto subs = silFnTy->getInvocationSubstitutions()) { - auto newSubs = subs.subst(substitutions, lookupConformances, options); + auto newSubs = subs.subst(IFS); return silFnTy->withInvocationSubstitutions(newSubs); } if (auto subs = silFnTy->getPatternSubstitutions()) { - auto newSubs = subs.subst(substitutions, lookupConformances, options); + auto newSubs = subs.subst(IFS); return silFnTy->withPatternSubstitutions(newSubs); } return None; @@ -4719,14 +4715,11 @@ static Type substType(Type derivedType, if (auto aliasTy = dyn_cast(type)) { Type parentTy; if (auto origParentTy = aliasTy->getParent()) - parentTy = substType(origParentTy, - substitutions, lookupConformances, options); - auto underlyingTy = substType(aliasTy->getSinglyDesugaredType(), - substitutions, lookupConformances, options); + parentTy = substType(origParentTy, IFS); + auto underlyingTy = substType(aliasTy->getSinglyDesugaredType(), IFS); if (parentTy && parentTy->isExistentialType()) return underlyingTy; - auto subMap = aliasTy->getSubstitutionMap() - .subst(substitutions, lookupConformances, options); + auto subMap = aliasTy->getSubstitutionMap().subst(IFS); return Type(TypeAliasType::get(aliasTy->getDecl(), parentTy, subMap, underlyingTy)); } @@ -4736,12 +4729,11 @@ static Type substType(Type derivedType, // For dependent member types, we may need to look up the member if the // base is resolved to a non-dependent type. if (auto depMemTy = dyn_cast(type)) { - auto newBase = substType(depMemTy->getBase(), - substitutions, lookupConformances, options); - return getMemberForBaseType(lookupConformances, + auto newBase = substType(depMemTy->getBase(), IFS); + return getMemberForBaseType(IFS, depMemTy->getBase(), newBase, depMemTy->getAssocType(), - depMemTy->getName(), options); + depMemTy->getName()); } auto substOrig = dyn_cast(type); @@ -4750,13 +4742,13 @@ static Type substType(Type derivedType, // Opaque types can't normally be directly substituted unless we // specifically were asked to substitute them. - if (!options.contains(SubstFlags::SubstituteOpaqueArchetypes) + if (!IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) && isa(substOrig)) return None; // If we have a substitution for this type, use it. - if (auto known = substitutions(substOrig)) { - if (options.contains(SubstFlags::SubstituteOpaqueArchetypes) && + if (auto known = IFS.substType(substOrig)) { + if (IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) && isa(substOrig) && known->getCanonicalType() == substOrig->getCanonicalType()) return None; // Recursively process the substitutions of the opaque type @@ -4784,8 +4776,7 @@ static Type substType(Type derivedType, assert(parent && "Not a nested archetype"); // Substitute into the parent type. - Type substParent = substType(parent, substitutions, - lookupConformances, options); + Type substParent = substType(parent, IFS); // If the parent didn't change, we won't change. if (substParent.getPointer() == parent) @@ -4795,23 +4786,26 @@ static Type substType(Type derivedType, AssociatedTypeDecl *assocType = origArchetype->getInterfaceType() ->castTo()->getAssocType(); - return getMemberForBaseType(lookupConformances, parent, substParent, - assocType, assocType->getName(), options); + return getMemberForBaseType(IFS, parent, substParent, + assocType, assocType->getName()); }); } Type Type::subst(SubstitutionMap substitutions, SubstOptions options) const { - return substType(*this, - QuerySubstitutionMap{substitutions}, - LookUpConformanceInSubstitutionMap(substitutions), - options); + InFlightSubstitutionViaSubMap IFS(substitutions, options); + return substType(*this, IFS); } Type Type::subst(TypeSubstitutionFn substitutions, LookupConformanceFn conformances, SubstOptions options) const { - return substType(*this, substitutions, conformances, options); + InFlightSubstitution IFS(substitutions, conformances, options); + return substType(*this, IFS); +} + +Type Type::subst(InFlightSubstitution &IFS) const { + return substType(*this, IFS); } DependentMemberType *TypeBase::findUnresolvedDependentMemberType() { From 9c9b153bd235126f07fdc06031669d3f740fb060 Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 23 Mar 2023 12:20:08 -0400 Subject: [PATCH 13/30] [NFC] Split the SIL type substitution code into its own file --- lib/SIL/IR/CMakeLists.txt | 1 + lib/SIL/IR/SILFunctionType.cpp | 708 ---------------------------- lib/SIL/IR/SILTypeSubstitution.cpp | 734 +++++++++++++++++++++++++++++ 3 files changed, 735 insertions(+), 708 deletions(-) create mode 100644 lib/SIL/IR/SILTypeSubstitution.cpp diff --git a/lib/SIL/IR/CMakeLists.txt b/lib/SIL/IR/CMakeLists.txt index 00b599bcbe050..336529098ab2f 100644 --- a/lib/SIL/IR/CMakeLists.txt +++ b/lib/SIL/IR/CMakeLists.txt @@ -30,6 +30,7 @@ target_sources(swiftSIL PRIVATE SILSuccessor.cpp SILSymbolVisitor.cpp SILType.cpp + SILTypeSubstitution.cpp SILUndef.cpp SILVTable.cpp SILValue.cpp diff --git a/lib/SIL/IR/SILFunctionType.cpp b/lib/SIL/IR/SILFunctionType.cpp index 749b4cff0ed09..b9276b8204175 100644 --- a/lib/SIL/IR/SILFunctionType.cpp +++ b/lib/SIL/IR/SILFunctionType.cpp @@ -26,8 +26,6 @@ #include "swift/AST/GenericEnvironment.h" #include "swift/AST/Module.h" #include "swift/AST/ModuleLoader.h" -#include "swift/AST/PackConformance.h" -#include "swift/AST/ProtocolConformance.h" #include "swift/AST/TypeCheckRequests.h" #include "swift/ClangImporter/ClangImporter.h" #include "swift/SIL/SILModule.h" @@ -4273,712 +4271,6 @@ TypeConverter::getConstantOverrideInfo(TypeExpansionContext context, return *result; } -namespace { - -/// Given a lowered SIL type, apply a substitution to it to produce another -/// lowered SIL type which uses the same abstraction conventions. -class SILTypeSubstituter : - public CanTypeVisitor { - TypeConverter &TC; - TypeSubstitutionFn Subst; - LookupConformanceFn Conformances; - // The signature for the original type. - // - // Replacement types are lowered with respect to the current - // context signature. - CanGenericSignature Sig; - - struct PackExpansion { - /// The shape class of pack parameters that are expanded by this - /// expansion. Set during construction and not changed. - CanType OrigShapeClass; - - /// The count type of the pack expansion in the current lane of - /// expansion, if any. Pack elements in this lane should be - /// expansions with this shape. - CanType SubstPackExpansionCount; - - /// The index of the current lane of expansion. Basic - /// substitution of pack parameters with the same shape as - /// OrigShapeClass should yield a pack, and lanewise - /// substitution should produce this element of that pack. - unsigned Index; - - PackExpansion(CanType origShapeClass) - : OrigShapeClass(origShapeClass), Index(0) {} - }; - SmallVector ActivePackExpansions; - - TypeExpansionContext typeExpansionContext; - - bool shouldSubstituteOpaqueArchetypes; - -public: - SILTypeSubstituter(TypeConverter &TC, - TypeExpansionContext context, - TypeSubstitutionFn Subst, - LookupConformanceFn Conformances, - CanGenericSignature Sig, - bool shouldSubstituteOpaqueArchetypes) - : TC(TC), - Subst(Subst), - Conformances(Conformances), - Sig(Sig), - typeExpansionContext(context), - shouldSubstituteOpaqueArchetypes(shouldSubstituteOpaqueArchetypes) - {} - - // SIL type lowering only does special things to tuples and functions. - - // When a function appears inside of another type, we only perform - // substitutions if it is not polymorphic. - CanSILFunctionType visitSILFunctionType(CanSILFunctionType origType) { - return substSILFunctionType(origType, false); - } - - SubstitutionMap substOpaqueTypes(SubstitutionMap subs) { - if (!typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) - return subs; - - return subs.subst([&](SubstitutableType *s) -> Type { - return substOpaqueTypesWithUnderlyingTypes(s->getCanonicalType(), - typeExpansionContext); - }, [&](CanType dependentType, - Type conformingReplacementType, - ProtocolDecl *conformedProtocol) -> ProtocolConformanceRef { - return substOpaqueTypesWithUnderlyingTypes( - ProtocolConformanceRef(conformedProtocol), - conformingReplacementType->getCanonicalType(), - typeExpansionContext); - }, SubstFlags::SubstituteOpaqueArchetypes); - } - - // Substitute a function type. - CanSILFunctionType substSILFunctionType(CanSILFunctionType origType, - bool isGenericApplication) { - assert((!isGenericApplication || origType->isPolymorphic()) && - "generic application without invocation signature or with " - "existing arguments"); - assert((!isGenericApplication || !shouldSubstituteOpaqueArchetypes) && - "generic application while substituting opaque archetypes"); - - // The general substitution rule is that we should only substitute - // into the free components of the type, i.e. the components that - // aren't inside a generic signature. That rule would say: - // - // - If there are invocation substitutions, just substitute those; - // the other components are necessarily inside the invocation - // generic signature. - // - // - Otherwise, if there's an invocation generic signature, - // substitute nothing. If we are applying generic arguments, - // add the appropriate invocation substitutions. - // - // - Otherwise, if there are pattern substitutions, just substitute - // those; the other components are inside the pattern generic - // signature. - // - // - Otherwise, substitute the basic components. - // - // There are two caveats here. The first is that we haven't yet - // written all the code that would be necessary in order to handle - // invocation substitutions everywhere, and so we never build those. - // Instead, we substitute into the pattern substitutions if present, - // or the components if not, and build a type with no invocation - // signature. As a special case, when substituting a coroutine type, - // we build pattern substitutions instead of substituting the - // component types in order to preserve the original yield structure, - // which factors into the continuation function ABI. - // - // The second is that this function is also used when substituting - // opaque archetypes. In this case, we may need to substitute - // into component types even within generic signatures. This is - // safe because the substitutions used in this case don't change - // generics, they just narrowly look through certain opaque archetypes. - // If substitutions are present, we still don't substitute into - // the basic components, in order to maintain the information about - // what was abstracted there. - - auto patternSubs = origType->getPatternSubstitutions(); - - // If we have an invocation signature, we generally shouldn't - // substitute into the pattern substitutions and component types. - if (auto sig = origType->getInvocationGenericSignature()) { - // Substitute the invocation substitutions if present. - if (auto invocationSubs = origType->getInvocationSubstitutions()) { - assert(!isGenericApplication); - invocationSubs = substSubstitutions(invocationSubs); - auto substType = - origType->withInvocationSubstitutions(invocationSubs); - - // Also do opaque-type substitutions on the pattern substitutions - // if requested and applicable. - if (patternSubs) { - patternSubs = substOpaqueTypes(patternSubs); - substType = substType->withPatternSubstitutions(patternSubs); - } - - return substType; - } - - // Otherwise, we shouldn't substitute any components except - // when substituting opaque archetypes. - - // If we're doing a generic application, and there are pattern - // substitutions, substitute into the pattern substitutions; or if - // it's a coroutine, build pattern substitutions; or else, fall - // through to substitute the component types as discussed above. - if (isGenericApplication) { - if (patternSubs || origType->isCoroutine()) { - CanSILFunctionType substType = origType; - if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { - substType = - origType->substituteOpaqueArchetypes(TC, typeExpansionContext); - } - - SubstitutionMap subs; - if (patternSubs) { - subs = substSubstitutions(patternSubs); - } else { - subs = SubstitutionMap::get(sig, Subst, Conformances); - } - auto witnessConformance = substWitnessConformance(origType); - substType = substType->withPatternSpecialization(nullptr, subs, - witnessConformance); - if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { - substType = - substType->substituteOpaqueArchetypes(TC, typeExpansionContext); - } - return substType; - } - // else fall down to component substitution - - // If we're substituting opaque archetypes, and there are pattern - // substitutions present, just substitute those and preserve the - // basic structure in the component types. Otherwise, fall through - // to substitute the component types. - } else if (shouldSubstituteOpaqueArchetypes) { - if (patternSubs) { - patternSubs = substOpaqueTypes(patternSubs); - auto witnessConformance = substWitnessConformance(origType); - return origType->withPatternSpecialization(sig, patternSubs, - witnessConformance); - } - // else fall down to component substitution - - // Otherwise, don't try to substitute bound components. - } else { - auto substType = origType; - if (patternSubs) { - patternSubs = substOpaqueTypes(patternSubs); - auto witnessConformance = substWitnessConformance(origType); - substType = substType->withPatternSpecialization(sig, patternSubs, - witnessConformance); - } - return substType; - } - - // Otherwise, if there are pattern substitutions, just substitute - // into those and don't touch the component types. - } else if (patternSubs) { - patternSubs = substSubstitutions(patternSubs); - auto witnessConformance = substWitnessConformance(origType); - return origType->withPatternSpecialization(nullptr, patternSubs, - witnessConformance); - } - - // Otherwise, we need to substitute component types. - - SmallVector substResults; - substResults.reserve(origType->getNumResults()); - for (auto origResult : origType->getResults()) { - substResults.push_back(substInterface(origResult)); - } - - auto substErrorResult = origType->getOptionalErrorResult(); - assert(!substErrorResult || - (!substErrorResult->getInterfaceType()->hasTypeParameter() && - !substErrorResult->getInterfaceType()->hasArchetype())); - - SmallVector substParams; - substParams.reserve(origType->getParameters().size()); - for (auto &origParam : origType->getParameters()) { - substParams.push_back(substInterface(origParam)); - } - - SmallVector substYields; - substYields.reserve(origType->getYields().size()); - for (auto &origYield : origType->getYields()) { - substYields.push_back(substInterface(origYield)); - } - - auto witnessMethodConformance = substWitnessConformance(origType); - - // The substituted type is no longer generic, so it'd never be - // pseudogeneric. - auto extInfo = origType->getExtInfo(); - if (!shouldSubstituteOpaqueArchetypes) - extInfo = extInfo.intoBuilder().withIsPseudogeneric(false).build(); - - auto genericSig = shouldSubstituteOpaqueArchetypes - ? origType->getInvocationGenericSignature() - : nullptr; - - return SILFunctionType::get(genericSig, extInfo, - origType->getCoroutineKind(), - origType->getCalleeConvention(), substParams, - substYields, substResults, substErrorResult, - SubstitutionMap(), SubstitutionMap(), - TC.Context, witnessMethodConformance); - } - - ProtocolConformanceRef substWitnessConformance(CanSILFunctionType origType) { - auto conformance = origType->getWitnessMethodConformanceOrInvalid(); - if (!conformance) return conformance; - - assert(origType->getExtInfo().hasSelfParam()); - auto selfType = origType->getSelfParameter().getInterfaceType(); - - // The Self type can be nested in a few layers of metatypes (etc.). - while (auto metatypeType = dyn_cast(selfType)) { - auto next = metatypeType.getInstanceType(); - if (next == selfType) - break; - selfType = next; - } - - auto substConformance = - conformance.subst(selfType, Subst, Conformances); - - // Substitute the underlying conformance of opaque type archetypes if we - // should look through opaque archetypes. - if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { - SubstOptions substOptions(None); - auto substType = selfType.subst(Subst, Conformances, substOptions) - ->getCanonicalType(); - if (substType->hasOpaqueArchetype()) { - substConformance = substOpaqueTypesWithUnderlyingTypes( - substConformance, substType, typeExpansionContext); - } - } - - return substConformance; - } - - SILType subst(SILType type) { - return SILType::getPrimitiveType(visit(type.getASTType()), - type.getCategory()); - } - - SILResultInfo substInterface(SILResultInfo orig) { - return SILResultInfo(visit(orig.getInterfaceType()), orig.getConvention()); - } - - SILYieldInfo substInterface(SILYieldInfo orig) { - return SILYieldInfo(visit(orig.getInterfaceType()), orig.getConvention()); - } - - SILParameterInfo substInterface(SILParameterInfo orig) { - return SILParameterInfo(visit(orig.getInterfaceType()), - orig.getConvention(), orig.getDifferentiability()); - } - - CanType visitSILPackType(CanSILPackType origType) { - // Fast-path the empty pack. - if (origType->getNumElements() == 0) return origType; - - SmallVector substEltTypes; - - substEltTypes.reserve(origType->getNumElements()); - - for (CanType origEltType : origType->getElementTypes()) { - if (auto origExpansionType = dyn_cast(origEltType)) { - substPackExpansion(origExpansionType, [&](CanType substExpandedType) { - substEltTypes.push_back(substExpandedType); - }); - } else { - auto substEltType = visit(origEltType); - substEltTypes.push_back(substEltType); - } - } - return SILPackType::get(TC.Context, origType->getExtInfo(), substEltTypes); - } - - CanType visitPackType(CanPackType origType) { - llvm_unreachable("CanPackType shouldn't show in lowered types"); - } - - CanType visitPackExpansionType(CanPackExpansionType origType) { - CanType patternType = visit(origType.getPatternType()); - CanType countType = substASTType(origType.getCountType()); - - return CanType(PackExpansionType::get(patternType, countType)); - } - - void substPackExpansion(CanPackExpansionType origType, - llvm::function_ref addExpandedType) { - CanType origCountType = origType.getCountType(); - CanType origPatternType = origType.getPatternType(); - - // Substitute the count type (as an AST type). - CanType substCountType = substASTType(origCountType); - - // If that produces a pack type, expand the pattern element-wise. - if (auto substCountPackType = dyn_cast(substCountType)) { - // Set up for element-wise expansion. - ActivePackExpansions.emplace_back(origCountType); - - for (CanType substCountEltType : substCountPackType.getElementTypes()) { - auto expansionType = dyn_cast(substCountEltType); - ActivePackExpansions.back().SubstPackExpansionCount = - (expansionType ? expansionType.getCountType() : CanType()); - - // Expand the pattern type in the element-wise context. - CanType expandedType = visit(origPatternType); - - // Turn that into a pack expansion if appropriate for the - // count element. - if (expansionType) { - expandedType = - CanPackExpansionType::get(expandedType, - expansionType.getCountType()); - } - - addExpandedType(expandedType); - - // Move to the next element. - ActivePackExpansions.back().Index++; - } - - // Leave the element-wise context. - ActivePackExpansions.pop_back(); - return; - } - - // Otherwise, transform the pattern type abstractly and just add a - // type expansion. - CanType substPatternType = visit(origPatternType); - - CanType expandedType; - if (substCountType == origCountType && substPatternType == origPatternType) - expandedType = origType; - else - expandedType = - CanPackExpansionType::get(substPatternType, substCountType); - addExpandedType(expandedType); - } - - /// Tuples need to have their component types substituted by these - /// same rules. - CanType visitTupleType(CanTupleType origType) { - // Fast-path the empty tuple. - if (origType->getNumElements() == 0) return origType; - - SmallVector substElts; - substElts.reserve(origType->getNumElements()); - for (auto &origElt : origType->getElements()) { - CanType origEltType = CanType(origElt.getType()); - if (auto origExpansion = dyn_cast(origEltType)) { - bool first = true; - substPackExpansion(origExpansion, [&](CanType substEltType) { - auto substElt = origElt.getWithType(substEltType); - if (first) { - first = false; - } else { - substElt = substElt.getWithoutName(); - } - substElts.push_back(substElt); - }); - } else { - auto substEltType = visit(origEltType); - substElts.push_back(origElt.getWithType(substEltType)); - } - } - - // Turn unlabeled singleton scalar tuples into their underlying types. - // The AST type substituter doesn't actually implement this rule yet, - // but we need to implement it in SIL in order to support testing, - // since the type parser can't parse a singleton tuple. - // - // For compatibility with previous behavior, don't do this if the - // original tuple type was singleton. AutoDiff apparently really - // likes making singleton tuples. - if (isParenType(substElts) && !isParenType(origType->getElements())) - return CanType(substElts[0].getType()); - - return CanType(TupleType::get(substElts, TC.Context)); - } - - static bool isParenType(ArrayRef elts) { - return (elts.size() == 1 && - !elts[0].hasName() && - !isa(CanType(elts[0].getType()))); - } - - // Block storage types need to substitute their capture type by these same - // rules. - CanType visitSILBlockStorageType(CanSILBlockStorageType origType) { - auto substCaptureType = visit(origType->getCaptureType()); - return SILBlockStorageType::get(substCaptureType); - } - - /// Optionals need to have their object types substituted by these rules. - CanType visitBoundGenericEnumType(CanBoundGenericEnumType origType) { - // Only use a special rule if it's Optional. - if (!origType->getDecl()->isOptionalDecl()) { - return visitType(origType); - } - - CanType origObjectType = origType.getGenericArgs()[0]; - CanType substObjectType = visit(origObjectType); - return CanType(BoundGenericType::get(origType->getDecl(), Type(), - substObjectType)); - } - - /// Any other type would be a valid type in the AST. Just apply the - /// substitution on the AST level and then lower that. - CanType visitType(CanType origType) { - assert(!isa(origType)); - assert(!isa(origType) && !isa(origType)); - - CanType substType = substASTType(origType); - - // If the substitution didn't change anything, we know that the - // original type was a lowered type, so we're good. - if (origType == substType) { - return origType; - } - - // We've looked through all the top-level structure in the orig - // type that's affected by type lowering. If substitution has - // given us a type with top-level structure that's affected by - // type lowering, it must be because the orig type was a type - // variable of some sort, and we should lower using an opaque - // abstraction pattern. If substitution hasn't given us such a - // type, it doesn't matter what abstraction pattern we use, - // lowering will just come back with substType. So we can just - // use an opaque abstraction pattern here and not put any effort - // into computing a more "honest" abstraction pattern. - AbstractionPattern abstraction = AbstractionPattern::getOpaque(); - return TC.getLoweredRValueType(typeExpansionContext, abstraction, - substType); - } - - struct SubstRespectingExpansions { - SILTypeSubstituter *_this; - SubstRespectingExpansions(SILTypeSubstituter *_this) : _this(_this) {} - - Type operator()(SubstitutableType *origType) const { - auto substType = _this->Subst(origType); - if (!substType) return substType; - auto substPackType = dyn_cast(substType->getCanonicalType()); - if (!substPackType) return substType; - auto activeExpansion = _this->getActivePackExpansion(CanType(origType)); - if (!activeExpansion) return substType; - auto substEltType = - substPackType.getElementType(activeExpansion->Index); - auto substExpansion = dyn_cast(substEltType); - assert((bool) substExpansion == - (bool) activeExpansion->SubstPackExpansionCount); - if (substExpansion) { - assert(_this->hasSameShape(substExpansion.getCountType(), - activeExpansion->SubstPackExpansionCount)); - return substExpansion.getPatternType(); - } - return substEltType; - } - }; - - struct SubstConformanceRespectingExpansions { - SILTypeSubstituter *_this; - SubstConformanceRespectingExpansions(SILTypeSubstituter *_this) - : _this(_this) {} - - ProtocolConformanceRef operator()(CanType dependentType, - Type conformingReplacementType, - ProtocolDecl *conformingProtocol) const { - auto conformance = _this->Conformances(dependentType, - conformingReplacementType, - conformingProtocol); - if (!conformance || !conformance.isPack()) return conformance; - auto activeExpansion = _this->getActivePackExpansion(dependentType); - if (!activeExpansion) return conformance; - auto pack = conformance.getPack(); - auto substEltConf = - pack->getPatternConformances()[activeExpansion->Index]; - // There isn't currently a ProtocolConformanceExpansion that - // we would need to look through here. - return substEltConf; - }; - }; - - CanType substASTType(CanType origType) { - SubstOptions substOptions(None); - if (shouldSubstituteOpaqueArchetypes) - substOptions = SubstFlags::SubstituteOpaqueArchetypes | - SubstFlags::AllowLoweredTypes; - - if (ActivePackExpansions.empty()) - return origType.subst(Subst, Conformances, substOptions) - ->getCanonicalType(); - - return origType.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - substOptions)->getCanonicalType(); - } - - SubstitutionMap substSubstitutions(SubstitutionMap subs) { - // Substitute the substitutions. - SubstOptions options = None; - if (shouldSubstituteOpaqueArchetypes) - options |= SubstFlags::SubstituteOpaqueArchetypes; - - // Expand substituted type according to the expansion context. - SubstitutionMap newSubs; - - if (ActivePackExpansions.empty()) - newSubs = subs.subst(Subst, Conformances, options); - else - newSubs = subs.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - options); - - // If we need to look through opaque types in this context, re-substitute - // according to the expansion context. - newSubs = substOpaqueTypes(newSubs); - - return newSubs; - } - - PackExpansion *getActivePackExpansion(CanType dependentType) { - // We push new expansions onto the end of this vector, and we - // want to honor the innermost expansion, so we have to traverse - // in it reverse. - for (auto &entry : reverse(ActivePackExpansions)) { - if (hasSameShape(dependentType, entry.OrigShapeClass)) - return &entry; - } - return nullptr; - } - - bool hasSameShape(CanType lhs, CanType rhs) { - if (lhs->isTypeParameter() && rhs->isTypeParameter()) { - assert(Sig); - return Sig->haveSameShape(lhs, rhs); - } - - auto lhsArchetype = cast(lhs); - auto rhsArchetype = cast(rhs); - return lhsArchetype->getReducedShape() == rhsArchetype->getReducedShape(); - } -}; - -} // end anonymous namespace - -SILType SILType::subst(TypeConverter &tc, TypeSubstitutionFn subs, - LookupConformanceFn conformances, - CanGenericSignature genericSig, - bool shouldSubstituteOpaqueArchetypes) const { - if (!hasArchetype() && !hasTypeParameter() && - (!shouldSubstituteOpaqueArchetypes || - !getASTType()->hasOpaqueArchetype())) - return *this; - - SILTypeSubstituter STST(tc, TypeExpansionContext::minimal(), subs, - conformances, genericSig, - shouldSubstituteOpaqueArchetypes); - return STST.subst(*this); -} - -SILType SILType::subst(SILModule &M, TypeSubstitutionFn subs, - LookupConformanceFn conformances, - CanGenericSignature genericSig, - bool shouldSubstituteOpaqueArchetypes) const { - return subst(M.Types, subs, conformances, genericSig, - shouldSubstituteOpaqueArchetypes); -} - -SILType SILType::subst(TypeConverter &tc, SubstitutionMap subs) const { - auto sig = subs.getGenericSignature(); - return subst(tc, QuerySubstitutionMap{subs}, - LookUpConformanceInSubstitutionMap(subs), - sig.getCanonicalSignature()); -} -SILType SILType::subst(SILModule &M, SubstitutionMap subs) const{ - return subst(M.Types, subs); -} - -SILType SILType::subst(SILModule &M, SubstitutionMap subs, - TypeExpansionContext context) const { - if (!hasArchetype() && !hasTypeParameter() && - !getASTType()->hasOpaqueArchetype()) - return *this; - - // Pass the TypeSubstitutionFn and LookupConformanceFn as arguments so that - // the llvm::function_ref value's scope spans the STST.subst call since - // SILTypeSubstituter captures these functions. - auto result = [&](TypeSubstitutionFn subsFn, - LookupConformanceFn conformancesFn) -> SILType { - SILTypeSubstituter STST(M.Types, context, subsFn, conformancesFn, - subs.getGenericSignature().getCanonicalSignature(), - false); - return STST.subst(*this); - }(QuerySubstitutionMap{subs}, LookUpConformanceInSubstitutionMap(subs)); - return result; -} - -/// Apply a substitution to this polymorphic SILFunctionType so that -/// it has the form of the normal SILFunctionType for the substituted -/// type, except using the original conventions. -CanSILFunctionType -SILFunctionType::substGenericArgs(SILModule &silModule, SubstitutionMap subs, - TypeExpansionContext context) { - if (!isPolymorphic()) { - return CanSILFunctionType(this); - } - - if (subs.empty()) { - return CanSILFunctionType(this); - } - - return substGenericArgs(silModule, - QuerySubstitutionMap{subs}, - LookUpConformanceInSubstitutionMap(subs), - context); -} - -CanSILFunctionType -SILFunctionType::substGenericArgs(SILModule &silModule, - TypeSubstitutionFn subs, - LookupConformanceFn conformances, - TypeExpansionContext context) { - if (!isPolymorphic()) return CanSILFunctionType(this); - SILTypeSubstituter substituter(silModule.Types, context, subs, conformances, - getSubstGenericSignature(), - /*shouldSubstituteOpaqueTypes*/ false); - return substituter.substSILFunctionType(CanSILFunctionType(this), true); -} - -CanSILFunctionType -SILFunctionType::substituteOpaqueArchetypes(TypeConverter &TC, - TypeExpansionContext context) { - if (!hasOpaqueArchetype() || - !context.shouldLookThroughOpaqueTypeArchetypes()) - return CanSILFunctionType(this); - - ReplaceOpaqueTypesWithUnderlyingTypes replacer( - context.getContext(), context.getResilienceExpansion(), - context.isWholeModuleContext()); - - SILTypeSubstituter substituter(TC, context, replacer, replacer, - getSubstGenericSignature(), - /*shouldSubstituteOpaqueTypes*/ true); - auto resTy = - substituter.substSILFunctionType(CanSILFunctionType(this), false); - - return resTy; -} - /// Fast path for bridging types in a function type without uncurrying. CanAnyFunctionType TypeConverter::getBridgedFunctionType( AbstractionPattern pattern, CanAnyFunctionType t, Bridgeability bridging, diff --git a/lib/SIL/IR/SILTypeSubstitution.cpp b/lib/SIL/IR/SILTypeSubstitution.cpp new file mode 100644 index 0000000000000..4a5743405282f --- /dev/null +++ b/lib/SIL/IR/SILTypeSubstitution.cpp @@ -0,0 +1,734 @@ +//===--- SILTypeSubstitution.cpp - Apply substitutions to SIL types -------===// +// +// This source file is part of the Swift.org open source project +// +// Copyright (c) 2014 - 2023 Apple Inc. and the Swift project authors +// Licensed under Apache License v2.0 with Runtime Library Exception +// +// See https://swift.org/LICENSE.txt for license information +// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors +// +//===----------------------------------------------------------------------===// +// +// This file defines the core operations that apply substitutions to +// the lowered types used for SIL values. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "libsil" + +#include "swift/SIL/SILModule.h" +#include "swift/SIL/SILType.h" +#include "swift/AST/InFlightSubstitution.h" +#include "swift/AST/PackConformance.h" +#include "swift/AST/ProtocolConformance.h" +#include "swift/AST/CanTypeVisitor.h" + +using namespace swift; +using namespace Lowering; + +namespace { + +/// Given a lowered SIL type, apply a substitution to it to produce another +/// lowered SIL type which uses the same abstraction conventions. +class SILTypeSubstituter : + public CanTypeVisitor { + TypeConverter &TC; + TypeSubstitutionFn Subst; + LookupConformanceFn Conformances; + // The signature for the original type. + // + // Replacement types are lowered with respect to the current + // context signature. + CanGenericSignature Sig; + + struct PackExpansion { + /// The shape class of pack parameters that are expanded by this + /// expansion. Set during construction and not changed. + CanType OrigShapeClass; + + /// The count type of the pack expansion in the current lane of + /// expansion, if any. Pack elements in this lane should be + /// expansions with this shape. + CanType SubstPackExpansionCount; + + /// The index of the current lane of expansion. Basic + /// substitution of pack parameters with the same shape as + /// OrigShapeClass should yield a pack, and lanewise + /// substitution should produce this element of that pack. + unsigned Index; + + PackExpansion(CanType origShapeClass) + : OrigShapeClass(origShapeClass), Index(0) {} + }; + SmallVector ActivePackExpansions; + + TypeExpansionContext typeExpansionContext; + + bool shouldSubstituteOpaqueArchetypes; + +public: + SILTypeSubstituter(TypeConverter &TC, + TypeExpansionContext context, + TypeSubstitutionFn Subst, + LookupConformanceFn Conformances, + CanGenericSignature Sig, + bool shouldSubstituteOpaqueArchetypes) + : TC(TC), + Subst(Subst), + Conformances(Conformances), + Sig(Sig), + typeExpansionContext(context), + shouldSubstituteOpaqueArchetypes(shouldSubstituteOpaqueArchetypes) + {} + + // SIL type lowering only does special things to tuples and functions. + + // When a function appears inside of another type, we only perform + // substitutions if it is not polymorphic. + CanSILFunctionType visitSILFunctionType(CanSILFunctionType origType) { + return substSILFunctionType(origType, false); + } + + SubstitutionMap substOpaqueTypes(SubstitutionMap subs) { + if (!typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) + return subs; + + return subs.subst([&](SubstitutableType *s) -> Type { + return substOpaqueTypesWithUnderlyingTypes(s->getCanonicalType(), + typeExpansionContext); + }, [&](CanType dependentType, + Type conformingReplacementType, + ProtocolDecl *conformedProtocol) -> ProtocolConformanceRef { + return substOpaqueTypesWithUnderlyingTypes( + ProtocolConformanceRef(conformedProtocol), + conformingReplacementType->getCanonicalType(), + typeExpansionContext); + }, SubstFlags::SubstituteOpaqueArchetypes); + } + + // Substitute a function type. + CanSILFunctionType substSILFunctionType(CanSILFunctionType origType, + bool isGenericApplication) { + assert((!isGenericApplication || origType->isPolymorphic()) && + "generic application without invocation signature or with " + "existing arguments"); + assert((!isGenericApplication || !shouldSubstituteOpaqueArchetypes) && + "generic application while substituting opaque archetypes"); + + // The general substitution rule is that we should only substitute + // into the free components of the type, i.e. the components that + // aren't inside a generic signature. That rule would say: + // + // - If there are invocation substitutions, just substitute those; + // the other components are necessarily inside the invocation + // generic signature. + // + // - Otherwise, if there's an invocation generic signature, + // substitute nothing. If we are applying generic arguments, + // add the appropriate invocation substitutions. + // + // - Otherwise, if there are pattern substitutions, just substitute + // those; the other components are inside the pattern generic + // signature. + // + // - Otherwise, substitute the basic components. + // + // There are two caveats here. The first is that we haven't yet + // written all the code that would be necessary in order to handle + // invocation substitutions everywhere, and so we never build those. + // Instead, we substitute into the pattern substitutions if present, + // or the components if not, and build a type with no invocation + // signature. As a special case, when substituting a coroutine type, + // we build pattern substitutions instead of substituting the + // component types in order to preserve the original yield structure, + // which factors into the continuation function ABI. + // + // The second is that this function is also used when substituting + // opaque archetypes. In this case, we may need to substitute + // into component types even within generic signatures. This is + // safe because the substitutions used in this case don't change + // generics, they just narrowly look through certain opaque archetypes. + // If substitutions are present, we still don't substitute into + // the basic components, in order to maintain the information about + // what was abstracted there. + + auto patternSubs = origType->getPatternSubstitutions(); + + // If we have an invocation signature, we generally shouldn't + // substitute into the pattern substitutions and component types. + if (auto sig = origType->getInvocationGenericSignature()) { + // Substitute the invocation substitutions if present. + if (auto invocationSubs = origType->getInvocationSubstitutions()) { + assert(!isGenericApplication); + invocationSubs = substSubstitutions(invocationSubs); + auto substType = + origType->withInvocationSubstitutions(invocationSubs); + + // Also do opaque-type substitutions on the pattern substitutions + // if requested and applicable. + if (patternSubs) { + patternSubs = substOpaqueTypes(patternSubs); + substType = substType->withPatternSubstitutions(patternSubs); + } + + return substType; + } + + // Otherwise, we shouldn't substitute any components except + // when substituting opaque archetypes. + + // If we're doing a generic application, and there are pattern + // substitutions, substitute into the pattern substitutions; or if + // it's a coroutine, build pattern substitutions; or else, fall + // through to substitute the component types as discussed above. + if (isGenericApplication) { + if (patternSubs || origType->isCoroutine()) { + CanSILFunctionType substType = origType; + if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { + substType = + origType->substituteOpaqueArchetypes(TC, typeExpansionContext); + } + + SubstitutionMap subs; + if (patternSubs) { + subs = substSubstitutions(patternSubs); + } else { + subs = SubstitutionMap::get(sig, Subst, Conformances); + } + auto witnessConformance = substWitnessConformance(origType); + substType = substType->withPatternSpecialization(nullptr, subs, + witnessConformance); + if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { + substType = + substType->substituteOpaqueArchetypes(TC, typeExpansionContext); + } + return substType; + } + // else fall down to component substitution + + // If we're substituting opaque archetypes, and there are pattern + // substitutions present, just substitute those and preserve the + // basic structure in the component types. Otherwise, fall through + // to substitute the component types. + } else if (shouldSubstituteOpaqueArchetypes) { + if (patternSubs) { + patternSubs = substOpaqueTypes(patternSubs); + auto witnessConformance = substWitnessConformance(origType); + return origType->withPatternSpecialization(sig, patternSubs, + witnessConformance); + } + // else fall down to component substitution + + // Otherwise, don't try to substitute bound components. + } else { + auto substType = origType; + if (patternSubs) { + patternSubs = substOpaqueTypes(patternSubs); + auto witnessConformance = substWitnessConformance(origType); + substType = substType->withPatternSpecialization(sig, patternSubs, + witnessConformance); + } + return substType; + } + + // Otherwise, if there are pattern substitutions, just substitute + // into those and don't touch the component types. + } else if (patternSubs) { + patternSubs = substSubstitutions(patternSubs); + auto witnessConformance = substWitnessConformance(origType); + return origType->withPatternSpecialization(nullptr, patternSubs, + witnessConformance); + } + + // Otherwise, we need to substitute component types. + + SmallVector substResults; + substResults.reserve(origType->getNumResults()); + for (auto origResult : origType->getResults()) { + substResults.push_back(substInterface(origResult)); + } + + auto substErrorResult = origType->getOptionalErrorResult(); + assert(!substErrorResult || + (!substErrorResult->getInterfaceType()->hasTypeParameter() && + !substErrorResult->getInterfaceType()->hasArchetype())); + + SmallVector substParams; + substParams.reserve(origType->getParameters().size()); + for (auto &origParam : origType->getParameters()) { + substParams.push_back(substInterface(origParam)); + } + + SmallVector substYields; + substYields.reserve(origType->getYields().size()); + for (auto &origYield : origType->getYields()) { + substYields.push_back(substInterface(origYield)); + } + + auto witnessMethodConformance = substWitnessConformance(origType); + + // The substituted type is no longer generic, so it'd never be + // pseudogeneric. + auto extInfo = origType->getExtInfo(); + if (!shouldSubstituteOpaqueArchetypes) + extInfo = extInfo.intoBuilder().withIsPseudogeneric(false).build(); + + auto genericSig = shouldSubstituteOpaqueArchetypes + ? origType->getInvocationGenericSignature() + : nullptr; + + return SILFunctionType::get(genericSig, extInfo, + origType->getCoroutineKind(), + origType->getCalleeConvention(), substParams, + substYields, substResults, substErrorResult, + SubstitutionMap(), SubstitutionMap(), + TC.Context, witnessMethodConformance); + } + + ProtocolConformanceRef substWitnessConformance(CanSILFunctionType origType) { + auto conformance = origType->getWitnessMethodConformanceOrInvalid(); + if (!conformance) return conformance; + + assert(origType->getExtInfo().hasSelfParam()); + auto selfType = origType->getSelfParameter().getInterfaceType(); + + // The Self type can be nested in a few layers of metatypes (etc.). + while (auto metatypeType = dyn_cast(selfType)) { + auto next = metatypeType.getInstanceType(); + if (next == selfType) + break; + selfType = next; + } + + auto substConformance = + conformance.subst(selfType, Subst, Conformances); + + // Substitute the underlying conformance of opaque type archetypes if we + // should look through opaque archetypes. + if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { + SubstOptions substOptions(None); + auto substType = selfType.subst(Subst, Conformances, substOptions) + ->getCanonicalType(); + if (substType->hasOpaqueArchetype()) { + substConformance = substOpaqueTypesWithUnderlyingTypes( + substConformance, substType, typeExpansionContext); + } + } + + return substConformance; + } + + SILType subst(SILType type) { + return SILType::getPrimitiveType(visit(type.getASTType()), + type.getCategory()); + } + + SILResultInfo substInterface(SILResultInfo orig) { + return SILResultInfo(visit(orig.getInterfaceType()), orig.getConvention()); + } + + SILYieldInfo substInterface(SILYieldInfo orig) { + return SILYieldInfo(visit(orig.getInterfaceType()), orig.getConvention()); + } + + SILParameterInfo substInterface(SILParameterInfo orig) { + return SILParameterInfo(visit(orig.getInterfaceType()), + orig.getConvention(), orig.getDifferentiability()); + } + + CanType visitSILPackType(CanSILPackType origType) { + // Fast-path the empty pack. + if (origType->getNumElements() == 0) return origType; + + SmallVector substEltTypes; + + substEltTypes.reserve(origType->getNumElements()); + + for (CanType origEltType : origType->getElementTypes()) { + if (auto origExpansionType = dyn_cast(origEltType)) { + substPackExpansion(origExpansionType, [&](CanType substExpandedType) { + substEltTypes.push_back(substExpandedType); + }); + } else { + auto substEltType = visit(origEltType); + substEltTypes.push_back(substEltType); + } + } + return SILPackType::get(TC.Context, origType->getExtInfo(), substEltTypes); + } + + CanType visitPackType(CanPackType origType) { + llvm_unreachable("CanPackType shouldn't show in lowered types"); + } + + CanType visitPackExpansionType(CanPackExpansionType origType) { + CanType patternType = visit(origType.getPatternType()); + CanType countType = substASTType(origType.getCountType()); + + return CanType(PackExpansionType::get(patternType, countType)); + } + + void substPackExpansion(CanPackExpansionType origType, + llvm::function_ref addExpandedType) { + CanType origCountType = origType.getCountType(); + CanType origPatternType = origType.getPatternType(); + + // Substitute the count type (as an AST type). + CanType substCountType = substASTType(origCountType); + + // If that produces a pack type, expand the pattern element-wise. + if (auto substCountPackType = dyn_cast(substCountType)) { + // Set up for element-wise expansion. + ActivePackExpansions.emplace_back(origCountType); + + for (CanType substCountEltType : substCountPackType.getElementTypes()) { + auto expansionType = dyn_cast(substCountEltType); + ActivePackExpansions.back().SubstPackExpansionCount = + (expansionType ? expansionType.getCountType() : CanType()); + + // Expand the pattern type in the element-wise context. + CanType expandedType = visit(origPatternType); + + // Turn that into a pack expansion if appropriate for the + // count element. + if (expansionType) { + expandedType = + CanPackExpansionType::get(expandedType, + expansionType.getCountType()); + } + + addExpandedType(expandedType); + + // Move to the next element. + ActivePackExpansions.back().Index++; + } + + // Leave the element-wise context. + ActivePackExpansions.pop_back(); + return; + } + + // Otherwise, transform the pattern type abstractly and just add a + // type expansion. + CanType substPatternType = visit(origPatternType); + + CanType expandedType; + if (substCountType == origCountType && substPatternType == origPatternType) + expandedType = origType; + else + expandedType = + CanPackExpansionType::get(substPatternType, substCountType); + addExpandedType(expandedType); + } + + /// Tuples need to have their component types substituted by these + /// same rules. + CanType visitTupleType(CanTupleType origType) { + // Fast-path the empty tuple. + if (origType->getNumElements() == 0) return origType; + + SmallVector substElts; + substElts.reserve(origType->getNumElements()); + for (auto &origElt : origType->getElements()) { + CanType origEltType = CanType(origElt.getType()); + if (auto origExpansion = dyn_cast(origEltType)) { + bool first = true; + substPackExpansion(origExpansion, [&](CanType substEltType) { + auto substElt = origElt.getWithType(substEltType); + if (first) { + first = false; + } else { + substElt = substElt.getWithoutName(); + } + substElts.push_back(substElt); + }); + } else { + auto substEltType = visit(origEltType); + substElts.push_back(origElt.getWithType(substEltType)); + } + } + + // Turn unlabeled singleton scalar tuples into their underlying types. + // The AST type substituter doesn't actually implement this rule yet, + // but we need to implement it in SIL in order to support testing, + // since the type parser can't parse a singleton tuple. + // + // For compatibility with previous behavior, don't do this if the + // original tuple type was singleton. AutoDiff apparently really + // likes making singleton tuples. + if (isParenType(substElts) && !isParenType(origType->getElements())) + return CanType(substElts[0].getType()); + + return CanType(TupleType::get(substElts, TC.Context)); + } + + static bool isParenType(ArrayRef elts) { + return (elts.size() == 1 && + !elts[0].hasName() && + !isa(CanType(elts[0].getType()))); + } + + // Block storage types need to substitute their capture type by these same + // rules. + CanType visitSILBlockStorageType(CanSILBlockStorageType origType) { + auto substCaptureType = visit(origType->getCaptureType()); + return SILBlockStorageType::get(substCaptureType); + } + + /// Optionals need to have their object types substituted by these rules. + CanType visitBoundGenericEnumType(CanBoundGenericEnumType origType) { + // Only use a special rule if it's Optional. + if (!origType->getDecl()->isOptionalDecl()) { + return visitType(origType); + } + + CanType origObjectType = origType.getGenericArgs()[0]; + CanType substObjectType = visit(origObjectType); + return CanType(BoundGenericType::get(origType->getDecl(), Type(), + substObjectType)); + } + + /// Any other type would be a valid type in the AST. Just apply the + /// substitution on the AST level and then lower that. + CanType visitType(CanType origType) { + assert(!isa(origType)); + assert(!isa(origType) && !isa(origType)); + + CanType substType = substASTType(origType); + + // If the substitution didn't change anything, we know that the + // original type was a lowered type, so we're good. + if (origType == substType) { + return origType; + } + + // We've looked through all the top-level structure in the orig + // type that's affected by type lowering. If substitution has + // given us a type with top-level structure that's affected by + // type lowering, it must be because the orig type was a type + // variable of some sort, and we should lower using an opaque + // abstraction pattern. If substitution hasn't given us such a + // type, it doesn't matter what abstraction pattern we use, + // lowering will just come back with substType. So we can just + // use an opaque abstraction pattern here and not put any effort + // into computing a more "honest" abstraction pattern. + AbstractionPattern abstraction = AbstractionPattern::getOpaque(); + return TC.getLoweredRValueType(typeExpansionContext, abstraction, + substType); + } + + struct SubstRespectingExpansions { + SILTypeSubstituter *_this; + SubstRespectingExpansions(SILTypeSubstituter *_this) : _this(_this) {} + + Type operator()(SubstitutableType *origType) const { + auto substType = _this->Subst(origType); + if (!substType) return substType; + auto substPackType = dyn_cast(substType->getCanonicalType()); + if (!substPackType) return substType; + auto activeExpansion = _this->getActivePackExpansion(CanType(origType)); + if (!activeExpansion) return substType; + auto substEltType = + substPackType.getElementType(activeExpansion->Index); + auto substExpansion = dyn_cast(substEltType); + assert((bool) substExpansion == + (bool) activeExpansion->SubstPackExpansionCount); + if (substExpansion) { + assert(_this->hasSameShape(substExpansion.getCountType(), + activeExpansion->SubstPackExpansionCount)); + return substExpansion.getPatternType(); + } + return substEltType; + } + }; + + struct SubstConformanceRespectingExpansions { + SILTypeSubstituter *_this; + SubstConformanceRespectingExpansions(SILTypeSubstituter *_this) + : _this(_this) {} + + ProtocolConformanceRef operator()(CanType dependentType, + Type conformingReplacementType, + ProtocolDecl *conformingProtocol) const { + auto conformance = _this->Conformances(dependentType, + conformingReplacementType, + conformingProtocol); + if (!conformance || !conformance.isPack()) return conformance; + auto activeExpansion = _this->getActivePackExpansion(dependentType); + if (!activeExpansion) return conformance; + auto pack = conformance.getPack(); + auto substEltConf = + pack->getPatternConformances()[activeExpansion->Index]; + // There isn't currently a ProtocolConformanceExpansion that + // we would need to look through here. + return substEltConf; + }; + }; + + CanType substASTType(CanType origType) { + SubstOptions substOptions(None); + if (shouldSubstituteOpaqueArchetypes) + substOptions = SubstFlags::SubstituteOpaqueArchetypes | + SubstFlags::AllowLoweredTypes; + + if (ActivePackExpansions.empty()) + return origType.subst(Subst, Conformances, substOptions) + ->getCanonicalType(); + + return origType.subst(SubstRespectingExpansions(this), + SubstConformanceRespectingExpansions(this), + substOptions)->getCanonicalType(); + } + + SubstitutionMap substSubstitutions(SubstitutionMap subs) { + // Substitute the substitutions. + SubstOptions options = None; + if (shouldSubstituteOpaqueArchetypes) + options |= SubstFlags::SubstituteOpaqueArchetypes; + + // Expand substituted type according to the expansion context. + SubstitutionMap newSubs; + + if (ActivePackExpansions.empty()) + newSubs = subs.subst(Subst, Conformances, options); + else + newSubs = subs.subst(SubstRespectingExpansions(this), + SubstConformanceRespectingExpansions(this), + options); + + // If we need to look through opaque types in this context, re-substitute + // according to the expansion context. + newSubs = substOpaqueTypes(newSubs); + + return newSubs; + } + + PackExpansion *getActivePackExpansion(CanType dependentType) { + // We push new expansions onto the end of this vector, and we + // want to honor the innermost expansion, so we have to traverse + // in it reverse. + for (auto &entry : reverse(ActivePackExpansions)) { + if (hasSameShape(dependentType, entry.OrigShapeClass)) + return &entry; + } + return nullptr; + } + + bool hasSameShape(CanType lhs, CanType rhs) { + if (lhs->isTypeParameter() && rhs->isTypeParameter()) { + assert(Sig); + return Sig->haveSameShape(lhs, rhs); + } + + auto lhsArchetype = cast(lhs); + auto rhsArchetype = cast(rhs); + return lhsArchetype->getReducedShape() == rhsArchetype->getReducedShape(); + } +}; + +} // end anonymous namespace + +SILType SILType::subst(TypeConverter &tc, TypeSubstitutionFn subs, + LookupConformanceFn conformances, + CanGenericSignature genericSig, + bool shouldSubstituteOpaqueArchetypes) const { + if (!hasArchetype() && !hasTypeParameter() && + (!shouldSubstituteOpaqueArchetypes || + !getASTType()->hasOpaqueArchetype())) + return *this; + + SILTypeSubstituter STST(tc, TypeExpansionContext::minimal(), subs, + conformances, genericSig, + shouldSubstituteOpaqueArchetypes); + return STST.subst(*this); +} + +SILType SILType::subst(SILModule &M, TypeSubstitutionFn subs, + LookupConformanceFn conformances, + CanGenericSignature genericSig, + bool shouldSubstituteOpaqueArchetypes) const { + return subst(M.Types, subs, conformances, genericSig, + shouldSubstituteOpaqueArchetypes); +} + +SILType SILType::subst(TypeConverter &tc, SubstitutionMap subs) const { + auto sig = subs.getGenericSignature(); + return subst(tc, QuerySubstitutionMap{subs}, + LookUpConformanceInSubstitutionMap(subs), + sig.getCanonicalSignature()); +} +SILType SILType::subst(SILModule &M, SubstitutionMap subs) const{ + return subst(M.Types, subs); +} + +SILType SILType::subst(SILModule &M, SubstitutionMap subs, + TypeExpansionContext context) const { + if (!hasArchetype() && !hasTypeParameter() && + !getASTType()->hasOpaqueArchetype()) + return *this; + + // Pass the TypeSubstitutionFn and LookupConformanceFn as arguments so that + // the llvm::function_ref value's scope spans the STST.subst call since + // SILTypeSubstituter captures these functions. + auto result = [&](TypeSubstitutionFn subsFn, + LookupConformanceFn conformancesFn) -> SILType { + SILTypeSubstituter STST(M.Types, context, subsFn, conformancesFn, + subs.getGenericSignature().getCanonicalSignature(), + false); + return STST.subst(*this); + }(QuerySubstitutionMap{subs}, LookUpConformanceInSubstitutionMap(subs)); + return result; +} + +/// Apply a substitution to this polymorphic SILFunctionType so that +/// it has the form of the normal SILFunctionType for the substituted +/// type, except using the original conventions. +CanSILFunctionType +SILFunctionType::substGenericArgs(SILModule &silModule, SubstitutionMap subs, + TypeExpansionContext context) { + if (!isPolymorphic()) { + return CanSILFunctionType(this); + } + + if (subs.empty()) { + return CanSILFunctionType(this); + } + + return substGenericArgs(silModule, + QuerySubstitutionMap{subs}, + LookUpConformanceInSubstitutionMap(subs), + context); +} + +CanSILFunctionType +SILFunctionType::substGenericArgs(SILModule &silModule, + TypeSubstitutionFn subs, + LookupConformanceFn conformances, + TypeExpansionContext context) { + if (!isPolymorphic()) return CanSILFunctionType(this); + SILTypeSubstituter substituter(silModule.Types, context, subs, conformances, + getSubstGenericSignature(), + /*shouldSubstituteOpaqueTypes*/ false); + return substituter.substSILFunctionType(CanSILFunctionType(this), true); +} + +CanSILFunctionType +SILFunctionType::substituteOpaqueArchetypes(TypeConverter &TC, + TypeExpansionContext context) { + if (!hasOpaqueArchetype() || + !context.shouldLookThroughOpaqueTypeArchetypes()) + return CanSILFunctionType(this); + + ReplaceOpaqueTypesWithUnderlyingTypes replacer( + context.getContext(), context.getResilienceExpansion(), + context.isWholeModuleContext()); + + SILTypeSubstituter substituter(TC, context, replacer, replacer, + getSubstGenericSignature(), + /*shouldSubstituteOpaqueTypes*/ true); + auto resTy = + substituter.substSILFunctionType(CanSILFunctionType(this), false); + + return resTy; +} From cbbdc83f13e55534adba4291c8c3d94e8636d5e4 Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 23 Mar 2023 12:50:25 -0400 Subject: [PATCH 14/30] [NFC] Add a convenience function to IFS --- include/swift/AST/InFlightSubstitution.h | 4 ++++ lib/AST/ProtocolConformanceRef.cpp | 2 +- lib/AST/SubstitutionMap.cpp | 2 +- lib/AST/Type.cpp | 6 +++--- 4 files changed, 9 insertions(+), 5 deletions(-) diff --git a/include/swift/AST/InFlightSubstitution.h b/include/swift/AST/InFlightSubstitution.h index a137929d4c9d9..15302efc3f527 100644 --- a/include/swift/AST/InFlightSubstitution.h +++ b/include/swift/AST/InFlightSubstitution.h @@ -59,6 +59,10 @@ class InFlightSubstitution { return Options; } + bool shouldSubstituteOpaqueArchetypes() const { + return Options.contains(SubstFlags::SubstituteOpaqueArchetypes); + } + /// Is the given type invariant to substitution? bool isInvariant(Type type) const; }; diff --git a/lib/AST/ProtocolConformanceRef.cpp b/lib/AST/ProtocolConformanceRef.cpp index e98bf08a22a21..ccde9d8d81f45 100644 --- a/lib/AST/ProtocolConformanceRef.cpp +++ b/lib/AST/ProtocolConformanceRef.cpp @@ -86,7 +86,7 @@ ProtocolConformanceRef::subst(Type origType, InFlightSubstitution &IFS) const { // If the type is an opaque archetype, the conformance will remain abstract, // unless we're specifically substituting opaque types. if (auto origArchetype = origType->getAs()) { - if (!IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) + if (!IFS.shouldSubstituteOpaqueArchetypes() && isa(origArchetype)) { return *this; } diff --git a/lib/AST/SubstitutionMap.cpp b/lib/AST/SubstitutionMap.cpp index d59f1961b5950..3d0214d77fc36 100644 --- a/lib/AST/SubstitutionMap.cpp +++ b/lib/AST/SubstitutionMap.cpp @@ -485,7 +485,7 @@ SubstitutionMap SubstitutionMap::subst(InFlightSubstitution &IFS) const { // Fast path for concrete case -- we don't need to compute substType // at all. if (conformance.isConcrete() && - !IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes)) { + !IFS.shouldSubstituteOpaqueArchetypes()) { newConformances.push_back( ProtocolConformanceRef(conformance.getConcrete()->subst(IFS))); } else { diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 2c965368c254e..796a416f136ae 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4655,7 +4655,7 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, bool InFlightSubstitution::isInvariant(Type derivedType) const { return !derivedType->hasArchetype() && !derivedType->hasTypeParameter() - && (!Options.contains(SubstFlags::SubstituteOpaqueArchetypes) + && (!shouldSubstituteOpaqueArchetypes() || !derivedType->hasOpaqueArchetype()); } @@ -4742,13 +4742,13 @@ static Type substType(Type derivedType, InFlightSubstitution &IFS) { // Opaque types can't normally be directly substituted unless we // specifically were asked to substitute them. - if (!IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) + if (!IFS.shouldSubstituteOpaqueArchetypes() && isa(substOrig)) return None; // If we have a substitution for this type, use it. if (auto known = IFS.substType(substOrig)) { - if (IFS.getOptions().contains(SubstFlags::SubstituteOpaqueArchetypes) && + if (IFS.shouldSubstituteOpaqueArchetypes() && isa(substOrig) && known->getCanonicalType() == substOrig->getCanonicalType()) return None; // Recursively process the substitutions of the opaque type From aa713a8908ef4df325954b95d59f7a950dfb574d Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 23 Mar 2023 13:42:49 -0400 Subject: [PATCH 15/30] [NFC] Add an operation to change the current options on NFC --- include/swift/AST/InFlightSubstitution.h | 25 ++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/include/swift/AST/InFlightSubstitution.h b/include/swift/AST/InFlightSubstitution.h index 15302efc3f527..1317b1508cecb 100644 --- a/include/swift/AST/InFlightSubstitution.h +++ b/include/swift/AST/InFlightSubstitution.h @@ -55,6 +55,31 @@ class InFlightSubstitution { conformedProtocol); } + class OptionsAdjustmentScope { + InFlightSubstitution &IFS; + SubstOptions SavedOptions; + + public: + OptionsAdjustmentScope(InFlightSubstitution &IFS, SubstOptions newOptions) + : IFS(IFS), SavedOptions(IFS.Options) { + IFS.Options = newOptions; + } + + OptionsAdjustmentScope(const OptionsAdjustmentScope &) = delete; + OptionsAdjustmentScope &operator=(const OptionsAdjustmentScope &) = delete; + + ~OptionsAdjustmentScope() { + IFS.Options = SavedOptions; + } + }; + + template + auto withNewOptions(SubstOptions options, Fn &&fn) + -> decltype(std::forward(fn)()) { + OptionsAdjustmentScope scope(*this, options); + return std::forward(fn)(); + } + SubstOptions getOptions() const { return Options; } From ad091a4e16d857388a839b083188de6e482e8fb1 Mon Sep 17 00:00:00 2001 From: John McCall Date: Thu, 23 Mar 2023 13:43:20 -0400 Subject: [PATCH 16/30] [NFC] Use InFlightSubstitution in the SIL type substituter --- include/swift/AST/Types.h | 3 + include/swift/SIL/SILType.h | 4 + lib/SIL/IR/SILTypeSubstitution.cpp | 154 +++++++++++++++-------------- 3 files changed, 87 insertions(+), 74 deletions(-) diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index f60c0bc181adc..8130c06f51682 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -5168,6 +5168,9 @@ class SILFunctionType final TypeSubstitutionFn subs, LookupConformanceFn conformances, TypeExpansionContext context); + CanSILFunctionType substGenericArgs(SILModule &silModule, + InFlightSubstitution &IFS, + TypeExpansionContext context); CanSILFunctionType substituteOpaqueArchetypes(Lowering::TypeConverter &TC, TypeExpansionContext context); diff --git a/include/swift/SIL/SILType.h b/include/swift/SIL/SILType.h index a0331c5ead282..358956b88aa73 100644 --- a/include/swift/SIL/SILType.h +++ b/include/swift/SIL/SILType.h @@ -651,6 +651,10 @@ class SILType { CanGenericSignature genericSig = CanGenericSignature(), bool shouldSubstituteOpaqueArchetypes = false) const; + SILType subst(Lowering::TypeConverter &tc, + InFlightSubstitution &IFS, + CanGenericSignature genericSig) const; + SILType subst(Lowering::TypeConverter &tc, SubstitutionMap subs) const; SILType subst(SILModule &M, SubstitutionMap subs) const; diff --git a/lib/SIL/IR/SILTypeSubstitution.cpp b/lib/SIL/IR/SILTypeSubstitution.cpp index 4a5743405282f..51e8d242f6174 100644 --- a/lib/SIL/IR/SILTypeSubstitution.cpp +++ b/lib/SIL/IR/SILTypeSubstitution.cpp @@ -34,8 +34,8 @@ namespace { class SILTypeSubstituter : public CanTypeVisitor { TypeConverter &TC; - TypeSubstitutionFn Subst; - LookupConformanceFn Conformances; + InFlightSubstitution &IFS; + // The signature for the original type. // // Replacement types are lowered with respect to the current @@ -65,21 +65,15 @@ class SILTypeSubstituter : TypeExpansionContext typeExpansionContext; - bool shouldSubstituteOpaqueArchetypes; - public: SILTypeSubstituter(TypeConverter &TC, TypeExpansionContext context, - TypeSubstitutionFn Subst, - LookupConformanceFn Conformances, - CanGenericSignature Sig, - bool shouldSubstituteOpaqueArchetypes) + InFlightSubstitution &IFS, + CanGenericSignature Sig) : TC(TC), - Subst(Subst), - Conformances(Conformances), + IFS(IFS), Sig(Sig), - typeExpansionContext(context), - shouldSubstituteOpaqueArchetypes(shouldSubstituteOpaqueArchetypes) + typeExpansionContext(context) {} // SIL type lowering only does special things to tuples and functions. @@ -113,7 +107,7 @@ class SILTypeSubstituter : assert((!isGenericApplication || origType->isPolymorphic()) && "generic application without invocation signature or with " "existing arguments"); - assert((!isGenericApplication || !shouldSubstituteOpaqueArchetypes) && + assert((!isGenericApplication || !IFS.shouldSubstituteOpaqueArchetypes()) && "generic application while substituting opaque archetypes"); // The general substitution rule is that we should only substitute @@ -194,7 +188,7 @@ class SILTypeSubstituter : if (patternSubs) { subs = substSubstitutions(patternSubs); } else { - subs = SubstitutionMap::get(sig, Subst, Conformances); + subs = SubstitutionMap::get(sig, IFS); } auto witnessConformance = substWitnessConformance(origType); substType = substType->withPatternSpecialization(nullptr, subs, @@ -211,7 +205,7 @@ class SILTypeSubstituter : // substitutions present, just substitute those and preserve the // basic structure in the component types. Otherwise, fall through // to substitute the component types. - } else if (shouldSubstituteOpaqueArchetypes) { + } else if (IFS.shouldSubstituteOpaqueArchetypes()) { if (patternSubs) { patternSubs = substOpaqueTypes(patternSubs); auto witnessConformance = substWitnessConformance(origType); @@ -271,10 +265,10 @@ class SILTypeSubstituter : // The substituted type is no longer generic, so it'd never be // pseudogeneric. auto extInfo = origType->getExtInfo(); - if (!shouldSubstituteOpaqueArchetypes) + if (!IFS.shouldSubstituteOpaqueArchetypes()) extInfo = extInfo.intoBuilder().withIsPseudogeneric(false).build(); - auto genericSig = shouldSubstituteOpaqueArchetypes + auto genericSig = IFS.shouldSubstituteOpaqueArchetypes() ? origType->getInvocationGenericSignature() : nullptr; @@ -301,15 +295,14 @@ class SILTypeSubstituter : selfType = next; } - auto substConformance = - conformance.subst(selfType, Subst, Conformances); + auto substConformance = conformance.subst(selfType, IFS); // Substitute the underlying conformance of opaque type archetypes if we // should look through opaque archetypes. if (typeExpansionContext.shouldLookThroughOpaqueTypeArchetypes()) { - SubstOptions substOptions(None); - auto substType = selfType.subst(Subst, Conformances, substOptions) - ->getCanonicalType(); + auto substType = IFS.withNewOptions(None, [&] { + return selfType.subst(IFS)->getCanonicalType(); + }); if (substType->hasOpaqueArchetype()) { substConformance = substOpaqueTypesWithUnderlyingTypes( substConformance, substType, typeExpansionContext); @@ -523,7 +516,7 @@ class SILTypeSubstituter : SubstRespectingExpansions(SILTypeSubstituter *_this) : _this(_this) {} Type operator()(SubstitutableType *origType) const { - auto substType = _this->Subst(origType); + auto substType = _this->IFS.substType(origType); if (!substType) return substType; auto substPackType = dyn_cast(substType->getCanonicalType()); if (!substPackType) return substType; @@ -551,9 +544,10 @@ class SILTypeSubstituter : ProtocolConformanceRef operator()(CanType dependentType, Type conformingReplacementType, ProtocolDecl *conformingProtocol) const { - auto conformance = _this->Conformances(dependentType, - conformingReplacementType, - conformingProtocol); + auto conformance = + _this->IFS.lookupConformance(dependentType, + conformingReplacementType, + conformingProtocol); if (!conformance || !conformance.isPack()) return conformance; auto activeExpansion = _this->getActivePackExpansion(dependentType); if (!activeExpansion) return conformance; @@ -567,35 +561,23 @@ class SILTypeSubstituter : }; CanType substASTType(CanType origType) { - SubstOptions substOptions(None); - if (shouldSubstituteOpaqueArchetypes) - substOptions = SubstFlags::SubstituteOpaqueArchetypes | - SubstFlags::AllowLoweredTypes; - if (ActivePackExpansions.empty()) - return origType.subst(Subst, Conformances, substOptions) - ->getCanonicalType(); + return origType.subst(IFS)->getCanonicalType(); return origType.subst(SubstRespectingExpansions(this), SubstConformanceRespectingExpansions(this), - substOptions)->getCanonicalType(); + IFS.getOptions())->getCanonicalType(); } SubstitutionMap substSubstitutions(SubstitutionMap subs) { - // Substitute the substitutions. - SubstOptions options = None; - if (shouldSubstituteOpaqueArchetypes) - options |= SubstFlags::SubstituteOpaqueArchetypes; - - // Expand substituted type according to the expansion context. SubstitutionMap newSubs; if (ActivePackExpansions.empty()) - newSubs = subs.subst(Subst, Conformances, options); + newSubs = subs.subst(IFS); else newSubs = subs.subst(SubstRespectingExpansions(this), SubstConformanceRespectingExpansions(this), - options); + IFS.getOptions()); // If we need to look through opaque types in this context, re-substitute // according to the expansion context. @@ -629,18 +611,39 @@ class SILTypeSubstituter : } // end anonymous namespace +static bool isSubstitutionInvariant(SILType ty, + bool shouldSubstituteOpaqueArchetypes) { + return (!ty.hasArchetype() && + !ty.hasTypeParameter() && + (!shouldSubstituteOpaqueArchetypes || + !ty.getRawASTType()->hasOpaqueArchetype())); +} + SILType SILType::subst(TypeConverter &tc, TypeSubstitutionFn subs, LookupConformanceFn conformances, CanGenericSignature genericSig, bool shouldSubstituteOpaqueArchetypes) const { - if (!hasArchetype() && !hasTypeParameter() && - (!shouldSubstituteOpaqueArchetypes || - !getASTType()->hasOpaqueArchetype())) + if (isSubstitutionInvariant(*this, shouldSubstituteOpaqueArchetypes)) return *this; - SILTypeSubstituter STST(tc, TypeExpansionContext::minimal(), subs, - conformances, genericSig, - shouldSubstituteOpaqueArchetypes); + auto substOptions = + (shouldSubstituteOpaqueArchetypes + ? SubstOptions(SubstFlags::SubstituteOpaqueArchetypes) + : SubstOptions(None)); + InFlightSubstitution IFS(subs, conformances, substOptions); + + SILTypeSubstituter STST(tc, TypeExpansionContext::minimal(), IFS, + genericSig); + return STST.subst(*this); +} + +SILType SILType::subst(TypeConverter &tc, InFlightSubstitution &IFS, + CanGenericSignature genericSig) const { + if (isSubstitutionInvariant(*this, IFS.shouldSubstituteOpaqueArchetypes())) + return *this; + + SILTypeSubstituter STST(tc, TypeExpansionContext::minimal(), IFS, + genericSig); return STST.subst(*this); } @@ -654,9 +657,9 @@ SILType SILType::subst(SILModule &M, TypeSubstitutionFn subs, SILType SILType::subst(TypeConverter &tc, SubstitutionMap subs) const { auto sig = subs.getGenericSignature(); - return subst(tc, QuerySubstitutionMap{subs}, - LookUpConformanceInSubstitutionMap(subs), - sig.getCanonicalSignature()); + + InFlightSubstitutionViaSubMap IFS(subs, None); + return subst(tc, IFS, sig.getCanonicalSignature()); } SILType SILType::subst(SILModule &M, SubstitutionMap subs) const{ return subst(M.Types, subs); @@ -664,21 +667,14 @@ SILType SILType::subst(SILModule &M, SubstitutionMap subs) const{ SILType SILType::subst(SILModule &M, SubstitutionMap subs, TypeExpansionContext context) const { - if (!hasArchetype() && !hasTypeParameter() && - !getASTType()->hasOpaqueArchetype()) + if (isSubstitutionInvariant(*this, false)) return *this; - // Pass the TypeSubstitutionFn and LookupConformanceFn as arguments so that - // the llvm::function_ref value's scope spans the STST.subst call since - // SILTypeSubstituter captures these functions. - auto result = [&](TypeSubstitutionFn subsFn, - LookupConformanceFn conformancesFn) -> SILType { - SILTypeSubstituter STST(M.Types, context, subsFn, conformancesFn, - subs.getGenericSignature().getCanonicalSignature(), - false); - return STST.subst(*this); - }(QuerySubstitutionMap{subs}, LookUpConformanceInSubstitutionMap(subs)); - return result; + InFlightSubstitutionViaSubMap IFS(subs, None); + + SILTypeSubstituter STST(M.Types, context, IFS, + subs.getGenericSignature().getCanonicalSignature()); + return STST.subst(*this); } /// Apply a substitution to this polymorphic SILFunctionType so that @@ -695,10 +691,9 @@ SILFunctionType::substGenericArgs(SILModule &silModule, SubstitutionMap subs, return CanSILFunctionType(this); } - return substGenericArgs(silModule, - QuerySubstitutionMap{subs}, - LookUpConformanceInSubstitutionMap(subs), - context); + InFlightSubstitutionViaSubMap IFS(subs, None); + + return substGenericArgs(silModule, IFS, context); } CanSILFunctionType @@ -707,9 +702,19 @@ SILFunctionType::substGenericArgs(SILModule &silModule, LookupConformanceFn conformances, TypeExpansionContext context) { if (!isPolymorphic()) return CanSILFunctionType(this); - SILTypeSubstituter substituter(silModule.Types, context, subs, conformances, - getSubstGenericSignature(), - /*shouldSubstituteOpaqueTypes*/ false); + + InFlightSubstitution IFS(subs, conformances, None); + return substGenericArgs(silModule, IFS, context); +} + +CanSILFunctionType +SILFunctionType::substGenericArgs(SILModule &silModule, + InFlightSubstitution &IFS, + TypeExpansionContext context) { + if (!isPolymorphic()) return CanSILFunctionType(this); + + SILTypeSubstituter substituter(silModule.Types, context, IFS, + getSubstGenericSignature()); return substituter.substSILFunctionType(CanSILFunctionType(this), true); } @@ -724,9 +729,10 @@ SILFunctionType::substituteOpaqueArchetypes(TypeConverter &TC, context.getContext(), context.getResilienceExpansion(), context.isWholeModuleContext()); - SILTypeSubstituter substituter(TC, context, replacer, replacer, - getSubstGenericSignature(), - /*shouldSubstituteOpaqueTypes*/ true); + InFlightSubstitution IFS(replacer, replacer, + SubstFlags::SubstituteOpaqueArchetypes); + + SILTypeSubstituter substituter(TC, context, IFS, getSubstGenericSignature()); auto resTy = substituter.substSILFunctionType(CanSILFunctionType(this), false); From 6cb5dbcd8279f22281d63e47c2ea2ecca775f1d5 Mon Sep 17 00:00:00 2001 From: John McCall Date: Sat, 25 Mar 2023 18:39:44 -0400 Subject: [PATCH 17/30] The canonical type of a pack expansion has a reduced shape. --- lib/AST/ASTContext.cpp | 13 ++++++++++++- lib/AST/Type.cpp | 2 ++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index d0a9c8022d67a..09b1fcb27c627 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -3273,8 +3273,19 @@ PackExpansionType *PackExpansionType::get(Type patternType, Type countType) { .PackExpansionTypes.FindNodeOrInsertPos(id, insertPos)) return expType; + // The canonical pack expansion type uses the canonical shape. + // For interface types, we'd need a signature to do this properly, + // but for archetypes we can do it directly. + bool countIsCanonical = countType->isCanonical(); + if (countIsCanonical) { + if (auto archetype = dyn_cast(countType.getPointer())) { + auto reducedShape = archetype->getReducedShape(); + countIsCanonical = (reducedShape.getPointer() == archetype); + } + } + const ASTContext *canCtx = - (patternType->isCanonical() && countType->isCanonical()) + (patternType->isCanonical() && countIsCanonical) ? &context : nullptr; PackExpansionType *expansionType = new (context, arena) PackExpansionType(patternType, countType, properties, diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index 796a416f136ae..a43c332ebcf28 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -1673,6 +1673,8 @@ CanType TypeBase::computeCanonicalType() { auto *expansion = cast(this); auto patternType = expansion->getPatternType()->getCanonicalType(); auto countType = expansion->getCountType()->getCanonicalType(); + if (auto packArchetype = dyn_cast(countType)) + countType = packArchetype->getReducedShape(); Result = PackExpansionType::get(patternType, countType); break; } From 532000fe6fac6875e28aa421fe2a2b8546c525e5 Mon Sep 17 00:00:00 2001 From: John McCall Date: Sat, 25 Mar 2023 18:46:29 -0400 Subject: [PATCH 18/30] Perform component-wise substitution of pack expansions immediately. Substitution of a pack expansion type may now produce a pack type. We immediately expand that pack when transforming a tuple, a function parameter, or a pack. I had to duplicate the component-wise transformation logic in the simplifyType transform, which I'm not pleased about, but a little code duplication seemed a lot better than trying to unify the code in two very different places. I think we're very close to being able to assert that pack expansion shapes are either pack archetypes or pack parameters; unfortunately, the pack matchers intentionally produce expansions of packs, and I didn't want to add that to an already-large patch. --- include/swift/AST/InFlightSubstitution.h | 78 +++++++- lib/AST/ASTContext.cpp | 6 + lib/AST/PackConformance.cpp | 224 +++++------------------ lib/AST/ParameterPack.cpp | 6 +- lib/AST/Type.cpp | 173 +++++++++++++++-- lib/SIL/IR/SILTypeSubstitution.cpp | 170 ++--------------- lib/Sema/CSApply.cpp | 4 +- lib/Sema/ConstraintSystem.cpp | 106 +++++++++-- lib/Sema/TypeCheckType.cpp | 19 +- 9 files changed, 399 insertions(+), 387 deletions(-) diff --git a/include/swift/AST/InFlightSubstitution.h b/include/swift/AST/InFlightSubstitution.h index 1317b1508cecb..bda9dd5d56118 100644 --- a/include/swift/AST/InFlightSubstitution.h +++ b/include/swift/AST/InFlightSubstitution.h @@ -32,6 +32,12 @@ class InFlightSubstitution { TypeSubstitutionFn BaselineSubstType; LookupConformanceFn BaselineLookupConformance; + struct ActivePackExpansion { + bool isSubstExpansion = false; + unsigned expansionIndex = 0; + }; + SmallVector ActivePackExpansions; + public: InFlightSubstitution(TypeSubstitutionFn substType, LookupConformanceFn lookupConformance, @@ -43,16 +49,72 @@ class InFlightSubstitution { InFlightSubstitution(const InFlightSubstitution &) = delete; InFlightSubstitution &operator=(const InFlightSubstitution &) = delete; - Type substType(SubstitutableType *ty) { - return BaselineSubstType(ty); - } - + // TODO: when we add PackElementType, we should recognize it during + // substitution and either call different methods on this class or + // pass an extra argument for the pack-expansion depth D. We should + // be able to rely on that to mark a pack-element reference instead + // of checking whether the original type was a pack. Substitution + // should use the D'th entry from the end of ActivePackExpansions to + // guide the element substitution: + // - project the given index of the pack substitution + // - wrap it in a PackElementType if it's a subst expansion + // - the depth of that PackElementType is the number of subst + // expansions between the depth entry and the end of + // ActivePackExpansions + + /// Perform primitive substitution on the given type. Returns Type() + /// if the type should not be substituted as a whole. + Type substType(SubstitutableType *origType); + + /// Perform primitive conformance lookup on the given type. ProtocolConformanceRef lookupConformance(CanType dependentType, Type conformingReplacementType, - ProtocolDecl *conformedProtocol) { - return BaselineLookupConformance(dependentType, - conformingReplacementType, - conformedProtocol); + ProtocolDecl *conformedProtocol); + + /// Given the shape type of a pack expansion, invoke the given callback + /// for each expanded component of it. If the substituted component + /// is an expansion component, the desired shape of that expansion + /// is passed as the argument; otherwise, the argument is Type(). + /// In either case, an active expansion is entered on this IFS for + /// the duration of the call to handleComponent, and subsequent + /// pack-element type references will substitute to the corresponding + /// element of the substitution of the pack. + void expandPackExpansionShape(Type origShape, + llvm::function_ref handleComponent); + + /// Call the given function for each expanded component type of the + /// given pack expansion type. The function will be invoked with the + /// active expansion still active. + void expandPackExpansionType(PackExpansionType *origExpansionType, + llvm::function_ref handleComponentType) { + expandPackExpansionShape(origExpansionType->getCountType(), + [&](Type substComponentShape) { + auto origPatternType = origExpansionType->getPatternType(); + auto substEltType = origPatternType.subst(*this); + + auto substComponentType = + (substComponentShape + ? PackExpansionType::get(substEltType, substComponentShape) + : substEltType); + handleComponentType(substComponentType); + }); + } + + /// Return a list of component types that the pack expansion expands to. + SmallVector + expandPackExpansionType(PackExpansionType *origExpansionType) { + SmallVector substComponentTypes; + expandPackExpansionType(origExpansionType, substComponentTypes); + return substComponentTypes; + } + + /// Expand the list of component types that the pack expansion expands + /// to into the given array. + void expandPackExpansionType(PackExpansionType *origExpansionType, + SmallVectorImpl &substComponentTypes) { + expandPackExpansionType(origExpansionType, [&](Type substComponentType) { + substComponentTypes.push_back(substComponentType); + }); } class OptionsAdjustmentScope { diff --git a/lib/AST/ASTContext.cpp b/lib/AST/ASTContext.cpp index 09b1fcb27c627..cd98397277c63 100644 --- a/lib/AST/ASTContext.cpp +++ b/lib/AST/ASTContext.cpp @@ -3258,6 +3258,12 @@ CanPackExpansionType::get(CanType patternType, CanType countType) { } PackExpansionType *PackExpansionType::get(Type patternType, Type countType) { + assert(!patternType->is()); + assert(!countType->is()); + // FIXME: stop doing this deliberately in PackExpansionMatcher + //assert(!patternType->is()); + //assert(!countType->is()); + auto properties = patternType->getRecursiveProperties(); properties |= countType->getRecursiveProperties(); diff --git a/lib/AST/PackConformance.cpp b/lib/AST/PackConformance.cpp index e4006e6c8ae7c..93662faea8eac 100644 --- a/lib/AST/PackConformance.cpp +++ b/lib/AST/PackConformance.cpp @@ -167,207 +167,65 @@ ProtocolConformanceRef PackConformance::subst(SubstitutionMap subMap, return subst(IFS); } -// TODO: Move this elsewhere since it's generally useful -static bool arePackShapesEqual(PackType *lhs, PackType *rhs) { - if (lhs->getNumElements() != rhs->getNumElements()) - return false; - - for (unsigned i = 0, e = lhs->getNumElements(); i < e; ++i) { - auto lhsElt = lhs->getElementType(i); - auto rhsElt = rhs->getElementType(i); - - if (lhsElt->is() != rhsElt->is()) - return false; - } - - return true; -} - -static bool isRootParameterPack(Type t) { - if (auto *paramTy = t->getAs()) { - return paramTy->isParameterPack(); - } else if (auto *archetypeTy = t->getAs()) { - return archetypeTy->isRoot(); - } - - return false; -} - -static bool isRootedInParameterPack(Type t) { - if (auto *archetypeTy = t->getAs()) { - return true; - } - - return t->getRootGenericParam()->isParameterPack(); -} - namespace { -template -class PackExpander { -protected: +struct PackConformanceExpander { InFlightSubstitution &IFS; + ArrayRef origConformances; - PackExpander(InFlightSubstitution &IFS) : IFS(IFS) {} - - ImplClass *asImpl() { - return static_cast(this); - } - - /// We're replacing a pack expansion type with a pack -- flatten the pack - /// using the pack expansion's pattern. - void addExpandedExpansion(Type origPatternType, PackType *expandedCountType, - unsigned i) { - - // Get all pack parameters referenced from the pattern. - SmallVector rootParameterPacks; - origPatternType->getTypeParameterPacks(rootParameterPacks); - - // Each pack parameter referenced from the pattern must be replaced - // with a pack type, and all pack types must have the same shape as - // the expanded count pack type. - llvm::SmallDenseMap expandedPacks; - for (auto origParamType : rootParameterPacks) { - auto substParamType = origParamType.subst(IFS); - - if (auto expandedParamType = substParamType->template getAs()) { - assert(arePackShapesEqual(expandedParamType, expandedCountType) && - "TODO: Return an invalid conformance if this fails"); - - auto inserted = expandedPacks.insert( - std::make_pair(origParamType->getCanonicalType(), - expandedParamType)).second; - assert(inserted && - "getTypeParameterPacks() should not return duplicates"); - } else { - assert(false && - "TODO: Return an invalid conformance if this fails"); - } - } - - // For each element of the expanded count, compute the substituted - // pattern type. - for (unsigned j = 0, ee = expandedCountType->getNumElements(); j < ee; ++j) { - auto projectedSubs = [&](SubstitutableType *type) -> Type { - // Nested sequence archetypes get passed in here, but we must - // handle them via the standard nested type path. - if (auto *archetypeType = dyn_cast(type)) { - if (!archetypeType->isRoot()) - return Type(); - } - - // Compute the substituted type using our parent substitutions. - auto substType = Type(type).subst(IFS); - - // If the substituted type is a pack, project the jth element. - if (isRootParameterPack(type)) { - // FIXME: What if you have something like G... where G<> is - // variadic? - assert(substType->template is() && - "TODO: Return an invalid conformance if this fails"); - auto *packType = substType->template castTo(); - assert(arePackShapesEqual(packType, expandedCountType) && - "TODO: Return an invalid conformance if this fails"); - - return packType->getElementType(j); - } - - return IFS.substType(type); - }; - - auto projectedConformances = [&](CanType origType, Type substType, - ProtocolDecl *proto) -> ProtocolConformanceRef { - auto substConformance = - IFS.lookupConformance(origType, substType, proto); - - // If the substituted conformance is a pack, project the jth element. - if (isRootedInParameterPack(origType)) { - return substConformance.getPack()->getPatternConformances()[j]; - } - - return substConformance; - }; - - auto origCountElement = expandedCountType->getElementType(j); - auto substCountElement = origCountElement.subst( - projectedSubs, projectedConformances, IFS.getOptions()); - - asImpl()->add(origCountElement, substCountElement, i); - } - } - - /// A pack expansion remains unexpanded, so we substitute the pattern and - /// form a new pack expansion. - void addUnexpandedExpansion(Type origPatternType, Type substCountType, - unsigned i) { - auto substPatternType = origPatternType.subst(IFS); - auto substExpansion = PackExpansionType::get(substPatternType, substCountType); +public: + // Results built up by the expansion. + SmallVector substElementTypes; + SmallVector substConformances; - asImpl()->add(origPatternType, substExpansion, i); - } + PackConformanceExpander(InFlightSubstitution &IFS, + ArrayRef origConformances) + : IFS(IFS), origConformances(origConformances) {} - /// Scalar elements of the original pack are substituted and added to the - /// flattened pack. - void addScalar(Type origElement, unsigned i) { - auto substElement = origElement.subst(IFS); +private: + /// Substitute a scalar element of the original pack. + void substScalar(Type origElementType, + ProtocolConformanceRef origConformance) { + auto substElementType = origElementType.subst(IFS); + auto substConformance = origConformance.subst(origElementType, IFS); - asImpl()->add(origElement, substElement, i); + substElementTypes.push_back(substElementType); + substConformances.push_back(substConformance); } - /// Potentially expand an element of the original pack. - void maybeExpandExpansion(PackExpansionType *origExpansion, unsigned i) { - auto origPatternType = origExpansion->getPatternType(); - auto origCountType = origExpansion->getCountType(); - - auto substCountType = origCountType.subst(IFS); - - // If the substituted count type is a pack, we're expanding the - // original element. - if (auto *expandedCountType = substCountType->template getAs()) { - addExpandedExpansion(origPatternType, expandedCountType, i); - return; - } - - addUnexpandedExpansion(origPatternType, substCountType, i); + /// Substitute and expand an expansion element of the original pack. + void substExpansion(PackExpansionType *origExpansionType, + ProtocolConformanceRef origConformance) { + IFS.expandPackExpansionType(origExpansionType, + [&](Type substComponentType) { + auto origPatternType = origExpansionType->getPatternType(); + + // Just substitute the conformance. We don't directly represent + // pack expansion conformances here; it's sort of implicit in the + // corresponding pack element type. + auto substConformance = origConformance.subst(origPatternType, IFS); + + substElementTypes.push_back(substComponentType); + substConformances.push_back(substConformance); + }); } public: void expand(PackType *origPackType) { - for (unsigned i = 0, e = origPackType->getNumElements(); i < e; ++i) { - auto origElement = origPackType->getElementType(i); + assert(origPackType->getNumElements() == origConformances.size()); - // Check if the original element is potentially being expanded. - if (auto *origExpansion = origElement->getAs()) { - maybeExpandExpansion(origExpansion, i); - continue; + for (auto i : range(origPackType->getNumElements())) { + auto origElementType = origPackType->getElementType(i); + if (auto *origExpansion = origElementType->getAs()) { + substExpansion(origExpansion, origConformances[i]); + } else { + substScalar(origElementType, origConformances[i]); } - - addScalar(origElement, i); } } }; -class PackConformanceExpander : public PackExpander { -public: - SmallVector substElements; - SmallVector substConformances; - - ArrayRef origConformances; - - PackConformanceExpander(InFlightSubstitution &IFS, - ArrayRef origConformances) - : PackExpander(IFS), origConformances(origConformances) {} - - void add(Type origType, Type substType, unsigned i) { - substElements.push_back(substType); - - // FIXME: Pass down projection callbacks - substConformances.push_back(origConformances[i].subst( - origType, IFS)); - } -}; - -} +} // end anonymous namespace ProtocolConformanceRef PackConformance::subst(TypeSubstitutionFn subs, LookupConformanceFn conformances, @@ -382,7 +240,7 @@ PackConformance::subst(InFlightSubstitution &IFS) const { expander.expand(ConformingType); auto &ctx = Protocol->getASTContext(); - auto *substConformingType = PackType::get(ctx, expander.substElements); + auto *substConformingType = PackType::get(ctx, expander.substElementTypes); auto substConformance = PackConformance::get(substConformingType, Protocol, expander.substConformances); diff --git a/lib/AST/ParameterPack.cpp b/lib/AST/ParameterPack.cpp index d3957e718d41a..fb2e93b80219d 100644 --- a/lib/AST/ParameterPack.cpp +++ b/lib/AST/ParameterPack.cpp @@ -449,15 +449,15 @@ PackType *PackType::get(const ASTContext &C, auto arg = args[i]; if (params[i]->isParameterPack()) { - wrappedArgs.push_back(PackExpansionType::get( - arg, arg->getReducedShape())); + auto argPackElements = arg->castTo()->getElementTypes(); + wrappedArgs.append(argPackElements.begin(), argPackElements.end()); continue; } wrappedArgs.push_back(arg); } - return get(C, wrappedArgs)->flattenPackTypes(); + return get(C, wrappedArgs); } PackType *PackType::getSingletonPackExpansion(Type param) { diff --git a/lib/AST/Type.cpp b/lib/AST/Type.cpp index a43c332ebcf28..90d9612b333ea 100644 --- a/lib/AST/Type.cpp +++ b/lib/AST/Type.cpp @@ -4654,6 +4654,84 @@ static Type substGenericFunctionType(GenericFunctionType *genericFnType, fnType->getResult(), fnType->getExtInfo()); } +void InFlightSubstitution::expandPackExpansionShape(Type origShape, + llvm::function_ref handleComponent) { + + // Substitute the shape using the baseline substitutions, not the + // current elementwise projections. + auto substShape = origShape.subst(BaselineSubstType, + BaselineLookupConformance, + Options); + + auto substPackShape = substShape->getAs(); + if (!substPackShape) { + ActivePackExpansions.push_back({/*is subst expansion*/true, 0}); + handleComponent(substShape); + ActivePackExpansions.pop_back(); + return; + } + + ActivePackExpansions.push_back({false, 0}); + for (auto substElt : substPackShape->getElementTypes()) { + auto substExpansion = substElt->getAs(); + auto substExpansionShape = + (substExpansion ? substExpansion->getCountType() : Type()); + + ActivePackExpansions.back().isSubstExpansion = + (substExpansion != nullptr); + handleComponent(substExpansionShape); + ActivePackExpansions.back().expansionIndex++; + } + ActivePackExpansions.pop_back(); +} + +Type InFlightSubstitution::substType(SubstitutableType *origType) { + auto substType = BaselineSubstType(origType); + if (!substType || ActivePackExpansions.empty()) + return substType; + + auto substPackType = substType->getAs(); + if (!substPackType) + return substType; + + auto &activeExpansion = ActivePackExpansions.back(); + auto index = activeExpansion.expansionIndex; + assert(index < substPackType->getNumElements() && + "replacement for pack parameter did not have the right " + "size for expansion"); + auto substEltType = substPackType->getElementType(index); + if (activeExpansion.isSubstExpansion) { + assert(substEltType->is() && + "substituted shape mismatch: expected an expansion component"); + substEltType = substEltType->castTo()->getPatternType(); + } else { + assert(!substEltType->is() && + "substituted shape mismatch: expected a scalar component"); + } + return substEltType; +} + +ProtocolConformanceRef +InFlightSubstitution::lookupConformance(CanType dependentType, + Type conformingReplacementType, + ProtocolDecl *conformedProtocol) { + auto substConfRef = BaselineLookupConformance(dependentType, + conformingReplacementType, + conformedProtocol); + if (!substConfRef || + ActivePackExpansions.empty() || + !substConfRef.isPack()) + return substConfRef; + + auto substPackConf = substConfRef.getPack(); + auto substPackPatterns = substPackConf->getPatternConformances(); + auto index = ActivePackExpansions.back().expansionIndex; + assert(index < substPackPatterns.size() && + "replacement for pack parameter did not have the right " + "size for expansion"); + return substPackPatterns[index]; +} + bool InFlightSubstitution::isInvariant(Type derivedType) const { return !derivedType->hasArchetype() && !derivedType->hasTypeParameter() @@ -4691,12 +4769,9 @@ static Type substType(Type derivedType, InFlightSubstitution &IFS) { } if (auto packExpansionTy = dyn_cast(type)) { - auto patternTy = substType(packExpansionTy->getPatternType(), IFS); - auto countTy = substType(packExpansionTy->getCountType(), IFS); - if (auto *archetypeTy = countTy->getAs()) - countTy = archetypeTy->getReducedShape(); - - return Type(PackExpansionType::get(patternTy, countTy)->expand()); + auto eltTys = IFS.expandPackExpansionType(packExpansionTy); + if (eltTys.size() == 1) return eltTys[0]; + return Type(PackType::get(packExpansionTy->getASTContext(), eltTys)); } if (auto silFnTy = dyn_cast(type)) { @@ -5194,6 +5269,23 @@ Type Type::transform(llvm::function_ref fn) const { }); } +static PackType *getTransformedPack(Type substType) { + if (auto pack = substType->getAs()) { + return pack; + } + + // The pack matchers like to make expansions out of packs, and + // these types then propagate out into transforms. Make sure we + // flatten them exactly if they were the underlying pack. + // FIXME: stop doing this and make PackExpansionType::get assert + // that we never construct these types + if (auto expansion = substType->getAs()) { + return expansion->getPatternType()->getAs(); + } + + return nullptr; +} + Type Type::transformRec( llvm::function_ref(TypeBase *)> fn) const { return transformWithPosition(TypePosition::Invariant, @@ -5631,13 +5723,19 @@ case TypeKind::Id: anyChanged = true; } - elements.push_back(transformedEltTy); + // If the transformed type is a pack, immediately expand it. + if (auto eltPack = getTransformedPack(transformedEltTy)) { + auto eltElements = eltPack->getElementTypes(); + elements.append(eltElements.begin(), eltElements.end()); + } else { + elements.push_back(transformedEltTy); + } } if (!anyChanged) return *this; - return PackType::get(Ptr->getASTContext(), elements)->flattenPackTypes(); + return PackType::get(Ptr->getASTContext(), elements); } case TypeKind::SILPack: { @@ -5689,6 +5787,8 @@ case TypeKind::Id: case TypeKind::PackExpansion: { auto expand = cast(base); + // Substitution completely replaces this. + Type transformedPat = expand->getPatternType().transformWithPosition(pos, fn); if (!transformedPat) @@ -5703,7 +5803,14 @@ case TypeKind::Id: transformedCount.getPointer() == expand->getCountType().getPointer()) return *this; - return PackExpansionType::get(transformedPat, transformedCount)->expand(); + // // If we transform the count to a pack type, expand the pattern. + // // This is necessary because of how we piece together types in + // // the constraint system. + // if (auto countPack = transformedCount->getAs()) { + // return PackExpansionType::expand(transformedPat, countPack); + // } + + return PackExpansionType::get(transformedPat, transformedCount); } case TypeKind::Tuple: { @@ -5734,13 +5841,35 @@ case TypeKind::Id: } // Add the new tuple element, with the transformed type. - elements.push_back(elt.getWithType(transformedEltTy)); + // Expand packs immediately. + if (auto eltPack = getTransformedPack(transformedEltTy)) { + bool first = true; + for (auto eltElement : eltPack->getElementTypes()) { + if (first) { + elements.push_back(elt.getWithType(eltElement)); + first = false; + } else { + elements.push_back(TupleTypeElt(eltElement)); + } + } + } else { + elements.push_back(elt.getWithType(transformedEltTy)); + } } if (!anyChanged) return *this; - return TupleType::get(elements, Ptr->getASTContext())->flattenPackTypes(); + // If the transform would yield a singleton tuple, and we didn't + // start with one, flatten to produce the element type. + if (elements.size() == 1 && + !elements[0].getType()->is() && + !(tuple->getNumElements() == 1 && + !tuple->getElementType(0)->is())) { + return elements[0].getType(); + } + + return TupleType::get(elements, Ptr->getASTContext()); } @@ -5794,7 +5923,21 @@ case TypeKind::Id: flags = flags.withInOut(true); } - substParams.emplace_back(substType, label, flags, internalLabel); + if (auto substPack = getTransformedPack(substType)) { + bool first = true; + for (auto substEltType : substPack->getElementTypes()) { + if (first) { + substParams.emplace_back(substEltType, label, flags, + internalLabel); + first = false; + } else { + substParams.emplace_back(substEltType, Identifier(), flags, + Identifier()); + } + } + } else { + substParams.emplace_back(substType, label, flags, internalLabel); + } } // Transform result type. @@ -5836,8 +5979,7 @@ case TypeKind::Id: return GenericFunctionType::get(genericSig, substParams, resultTy); return GenericFunctionType::get(genericSig, substParams, resultTy, function->getExtInfo() - .withGlobalActor(globalActorType)) - ->flattenPackTypes(); + .withGlobalActor(globalActorType)); } if (isUnchanged) return *this; @@ -5846,8 +5988,7 @@ case TypeKind::Id: return FunctionType::get(substParams, resultTy); return FunctionType::get(substParams, resultTy, function->getExtInfo() - .withGlobalActor(globalActorType)) - ->flattenPackTypes(); + .withGlobalActor(globalActorType)); } case TypeKind::ArraySlice: { diff --git a/lib/SIL/IR/SILTypeSubstitution.cpp b/lib/SIL/IR/SILTypeSubstitution.cpp index 51e8d242f6174..b74c754cf6feb 100644 --- a/lib/SIL/IR/SILTypeSubstitution.cpp +++ b/lib/SIL/IR/SILTypeSubstitution.cpp @@ -42,27 +42,6 @@ class SILTypeSubstituter : // context signature. CanGenericSignature Sig; - struct PackExpansion { - /// The shape class of pack parameters that are expanded by this - /// expansion. Set during construction and not changed. - CanType OrigShapeClass; - - /// The count type of the pack expansion in the current lane of - /// expansion, if any. Pack elements in this lane should be - /// expansions with this shape. - CanType SubstPackExpansionCount; - - /// The index of the current lane of expansion. Basic - /// substitution of pack parameters with the same shape as - /// OrigShapeClass should yield a pack, and lanewise - /// substitution should produce this element of that pack. - unsigned Index; - - PackExpansion(CanType origShapeClass) - : OrigShapeClass(origShapeClass), Index(0) {} - }; - SmallVector ActivePackExpansions; - TypeExpansionContext typeExpansionContext; public: @@ -356,63 +335,21 @@ class SILTypeSubstituter : } CanType visitPackExpansionType(CanPackExpansionType origType) { - CanType patternType = visit(origType.getPatternType()); - CanType countType = substASTType(origType.getCountType()); - - return CanType(PackExpansionType::get(patternType, countType)); + llvm_unreachable("shouldn't substitute an independent lowered pack " + "expansion type"); } void substPackExpansion(CanPackExpansionType origType, llvm::function_ref addExpandedType) { - CanType origCountType = origType.getCountType(); - CanType origPatternType = origType.getPatternType(); - - // Substitute the count type (as an AST type). - CanType substCountType = substASTType(origCountType); - - // If that produces a pack type, expand the pattern element-wise. - if (auto substCountPackType = dyn_cast(substCountType)) { - // Set up for element-wise expansion. - ActivePackExpansions.emplace_back(origCountType); - - for (CanType substCountEltType : substCountPackType.getElementTypes()) { - auto expansionType = dyn_cast(substCountEltType); - ActivePackExpansions.back().SubstPackExpansionCount = - (expansionType ? expansionType.getCountType() : CanType()); - - // Expand the pattern type in the element-wise context. - CanType expandedType = visit(origPatternType); - - // Turn that into a pack expansion if appropriate for the - // count element. - if (expansionType) { - expandedType = - CanPackExpansionType::get(expandedType, - expansionType.getCountType()); - } - - addExpandedType(expandedType); - - // Move to the next element. - ActivePackExpansions.back().Index++; + IFS.expandPackExpansionShape(origType.getCountType(), + [&](Type substExpansionShape) { + CanType substComponentType = visit(origType.getPatternType()); + if (substExpansionShape) { + substComponentType = CanPackExpansionType::get(substComponentType, + substExpansionShape->getCanonicalType()); } - - // Leave the element-wise context. - ActivePackExpansions.pop_back(); - return; - } - - // Otherwise, transform the pattern type abstractly and just add a - // type expansion. - CanType substPatternType = visit(origPatternType); - - CanType expandedType; - if (substCountType == origCountType && substPatternType == origPatternType) - expandedType = origType; - else - expandedType = - CanPackExpansionType::get(substPatternType, substCountType); - addExpandedType(expandedType); + addExpandedType(substComponentType); + }); } /// Tuples need to have their component types substituted by these @@ -511,73 +448,12 @@ class SILTypeSubstituter : substType); } - struct SubstRespectingExpansions { - SILTypeSubstituter *_this; - SubstRespectingExpansions(SILTypeSubstituter *_this) : _this(_this) {} - - Type operator()(SubstitutableType *origType) const { - auto substType = _this->IFS.substType(origType); - if (!substType) return substType; - auto substPackType = dyn_cast(substType->getCanonicalType()); - if (!substPackType) return substType; - auto activeExpansion = _this->getActivePackExpansion(CanType(origType)); - if (!activeExpansion) return substType; - auto substEltType = - substPackType.getElementType(activeExpansion->Index); - auto substExpansion = dyn_cast(substEltType); - assert((bool) substExpansion == - (bool) activeExpansion->SubstPackExpansionCount); - if (substExpansion) { - assert(_this->hasSameShape(substExpansion.getCountType(), - activeExpansion->SubstPackExpansionCount)); - return substExpansion.getPatternType(); - } - return substEltType; - } - }; - - struct SubstConformanceRespectingExpansions { - SILTypeSubstituter *_this; - SubstConformanceRespectingExpansions(SILTypeSubstituter *_this) - : _this(_this) {} - - ProtocolConformanceRef operator()(CanType dependentType, - Type conformingReplacementType, - ProtocolDecl *conformingProtocol) const { - auto conformance = - _this->IFS.lookupConformance(dependentType, - conformingReplacementType, - conformingProtocol); - if (!conformance || !conformance.isPack()) return conformance; - auto activeExpansion = _this->getActivePackExpansion(dependentType); - if (!activeExpansion) return conformance; - auto pack = conformance.getPack(); - auto substEltConf = - pack->getPatternConformances()[activeExpansion->Index]; - // There isn't currently a ProtocolConformanceExpansion that - // we would need to look through here. - return substEltConf; - }; - }; - CanType substASTType(CanType origType) { - if (ActivePackExpansions.empty()) - return origType.subst(IFS)->getCanonicalType(); - - return origType.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - IFS.getOptions())->getCanonicalType(); + return origType.subst(IFS)->getCanonicalType(); } SubstitutionMap substSubstitutions(SubstitutionMap subs) { - SubstitutionMap newSubs; - - if (ActivePackExpansions.empty()) - newSubs = subs.subst(IFS); - else - newSubs = subs.subst(SubstRespectingExpansions(this), - SubstConformanceRespectingExpansions(this), - IFS.getOptions()); + SubstitutionMap newSubs = subs.subst(IFS); // If we need to look through opaque types in this context, re-substitute // according to the expansion context. @@ -585,28 +461,6 @@ class SILTypeSubstituter : return newSubs; } - - PackExpansion *getActivePackExpansion(CanType dependentType) { - // We push new expansions onto the end of this vector, and we - // want to honor the innermost expansion, so we have to traverse - // in it reverse. - for (auto &entry : reverse(ActivePackExpansions)) { - if (hasSameShape(dependentType, entry.OrigShapeClass)) - return &entry; - } - return nullptr; - } - - bool hasSameShape(CanType lhs, CanType rhs) { - if (lhs->isTypeParameter() && rhs->isTypeParameter()) { - assert(Sig); - return Sig->haveSameShape(lhs, rhs); - } - - auto lhsArchetype = cast(lhs); - auto rhsArchetype = cast(rhs); - return lhsArchetype->getReducedShape() == rhsArchetype->getReducedShape(); - } }; } // end anonymous namespace diff --git a/lib/Sema/CSApply.cpp b/lib/Sema/CSApply.cpp index 19b8cd4d83d7b..ef5d9edd47305 100644 --- a/lib/Sema/CSApply.cpp +++ b/lib/Sema/CSApply.cpp @@ -7125,8 +7125,8 @@ Expr *ExprRewriter::coerceToType(Expr *expr, Type toType, auto *expansion = dyn_cast(expr); auto *elementEnv = expansion->getGenericEnvironment(); - auto toElementType = elementEnv->mapPackTypeIntoElementContext( - toExpansionType->getPatternType()->mapTypeOutOfContext()); + auto toElementType = elementEnv->mapContextualPackTypeIntoElementContext( + toExpansionType->getPatternType()); auto *pattern = coerceToType(expansion->getPatternExpr(), toElementType, locator); diff --git a/lib/Sema/ConstraintSystem.cpp b/lib/Sema/ConstraintSystem.cpp index df47209c617f5..894a7a28f38d7 100644 --- a/lib/Sema/ConstraintSystem.cpp +++ b/lib/Sema/ConstraintSystem.cpp @@ -3713,16 +3713,56 @@ void ConstraintSystem::resolveOverload(ConstraintLocator *locator, } } -Type ConstraintSystem::simplifyTypeImpl(Type type, - llvm::function_ref getFixedTypeFn) const { - return type.transform([&](Type type) -> Type { - if (auto tvt = dyn_cast(type.getPointer())) - return getFixedTypeFn(tvt); +namespace { + +struct TypeSimplifier { + const ConstraintSystem &CS; + llvm::function_ref GetFixedTypeFn; + + struct ActivePackExpansion { + bool isPackExpansion = false; + unsigned index = 0; + }; + SmallVector ActivePackExpansions; + + TypeSimplifier(const ConstraintSystem &CS, + llvm::function_ref getFixedTypeFn) + : CS(CS), GetFixedTypeFn(getFixedTypeFn) {} + + Type operator()(Type type) { + if (auto tvt = dyn_cast(type.getPointer())) { + auto fixedTy = GetFixedTypeFn(tvt); + + // TODO: the following logic should be applied when rewriting + // PackElementType. + if (ActivePackExpansions.empty()) { + return fixedTy; + } + + if (auto fixedPack = fixedTy->getAs()) { + auto &activeExpansion = ActivePackExpansions.back(); + if (activeExpansion.index >= fixedPack->getNumElements()) { + return tvt; + } + + auto fixedElt = fixedPack->getElementType(activeExpansion.index); + auto fixedExpansion = fixedElt->getAs(); + if (activeExpansion.isPackExpansion && fixedExpansion) { + return fixedExpansion->getPatternType(); + } else if (!activeExpansion.isPackExpansion && !fixedExpansion) { + return fixedElt; + } else { + return tvt; + } + } + + return fixedTy; + } if (auto tuple = dyn_cast(type.getPointer())) { if (tuple->getNumElements() == 1) { auto element = tuple->getElement(0); - auto elementType = simplifyTypeImpl(element.getType(), getFixedTypeFn); + auto elementType = element.getType().transform(*this); // Flatten single-element tuples containing type variables that cannot // bind to packs. @@ -3733,14 +3773,47 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, } } + if (auto expansion = dyn_cast(type.getPointer())) { + // Transform the count type, ignoring any active pack expansions. + auto countType = expansion->getCountType().transform( + TypeSimplifier(CS, GetFixedTypeFn)); + + if (auto countPack = countType->getAs()) { + SmallVector elts; + ActivePackExpansions.push_back({false, 0}); + for (auto countElt : countPack->getElementTypes()) { + auto countExpansion = countElt->getAs(); + ActivePackExpansions.back().isPackExpansion = + (countExpansion != nullptr); + + auto elt = expansion->getPatternType().transform(*this); + if (countExpansion) + elt = PackExpansionType::get(elt, countExpansion->getCountType()); + elts.push_back(elt); + + ActivePackExpansions.back().index++; + } + ActivePackExpansions.pop_back(); + + if (elts.size() == 1) + return elts[0]; + return PackType::get(CS.getASTContext(), elts); + } else { + ActivePackExpansions.push_back({true, 0}); + auto patternType = expansion->getPatternType().transform(*this); + ActivePackExpansions.pop_back(); + return PackExpansionType::get(patternType, countType); + } + } + // If this is a dependent member type for which we end up simplifying // the base to a non-type-variable, perform lookup. if (auto depMemTy = dyn_cast(type.getPointer())) { // Simplify the base. - Type newBase = simplifyTypeImpl(depMemTy->getBase(), getFixedTypeFn); + Type newBase = depMemTy->getBase().transform(*this); if (newBase->isPlaceholder()) { - return PlaceholderType::get(getASTContext(), depMemTy); + return PlaceholderType::get(CS.getASTContext(), depMemTy); } // If nothing changed, we're done. @@ -3760,7 +3833,7 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, if (lookupBaseType->mayHaveMembers() || lookupBaseType->is()) { auto *proto = assocType->getProtocol(); - auto conformance = DC->getParentModule()->lookupConformance( + auto conformance = CS.DC->getParentModule()->lookupConformance( lookupBaseType, proto); if (!conformance) { // FIXME: This regresses diagnostics if removed, but really the @@ -3774,9 +3847,9 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, // so the concrete dependent member type is considered a "hole" in // order to continue solving. auto memberTy = DependentMemberType::get(lookupBaseType, assocType); - if (shouldAttemptFixes() && - getPhase() == ConstraintSystemPhase::Solving) { - return PlaceholderType::get(getASTContext(), memberTy); + if (CS.shouldAttemptFixes() && + CS.getPhase() == ConstraintSystemPhase::Solving) { + return PlaceholderType::get(CS.getASTContext(), memberTy); } return memberTy; @@ -3792,7 +3865,14 @@ Type ConstraintSystem::simplifyTypeImpl(Type type, } return type; - }); + } +}; + +} // end anonymous namespace + +Type ConstraintSystem::simplifyTypeImpl(Type type, + llvm::function_ref getFixedTypeFn) const { + return type.transform(TypeSimplifier(*this, getFixedTypeFn)); } Type ConstraintSystem::simplifyType(Type type) const { diff --git a/lib/Sema/TypeCheckType.cpp b/lib/Sema/TypeCheckType.cpp index 08315473e14b1..44cead7af37fa 100644 --- a/lib/Sema/TypeCheckType.cpp +++ b/lib/Sema/TypeCheckType.cpp @@ -918,11 +918,22 @@ static Type applyGenericArguments(Type type, TypeResolution resolution, assert(found != matcher.pairs.end()); auto arg = found->rhs; - if (auto *expansionType = arg->getAs()) - arg = expansionType->getPatternType(); - if (arg->isParameterPack()) - arg = PackType::getSingletonPackExpansion(arg); + // PackMatcher will always produce a PackExpansionType as the + // arg for a pack parameter, if necessary by wrapping a PackType + // in one. (It's a weird representation.) Look for that pattern + // and unwrap the pack. Otherwise, we must have matched with a + // single component which happened to be an expansion; wrap that + // in a PackType. In either case, we always want arg to end up + // a PackType. + if (auto *expansionType = arg->getAs()) { + auto pattern = expansionType->getPatternType(); + if (auto pack = pattern->getAs()) { + arg = pack; + } else { + arg = PackType::get(ctx, {expansionType}); + } + } args.push_back(arg); } From 5229198ec4ad98523cca7766c5d228496e382851 Mon Sep 17 00:00:00 2001 From: John McCall Date: Sat, 25 Mar 2023 18:52:33 -0400 Subject: [PATCH 19/30] Remove some dead code; these operations are now done directly in transformRec. --- include/swift/AST/Types.h | 8 -- lib/AST/ParameterPack.cpp | 179 -------------------------------------- 2 files changed, 187 deletions(-) diff --git a/include/swift/AST/Types.h b/include/swift/AST/Types.h index 8130c06f51682..5600e9be54237 100644 --- a/include/swift/AST/Types.h +++ b/include/swift/AST/Types.h @@ -2449,8 +2449,6 @@ class TupleType final : public TypeBase, public llvm::FoldingSetNode, bool containsPackExpansionType() const; - Type flattenPackTypes(); - private: TupleType(ArrayRef elements, const ASTContext *CanCtx, RecursiveTypeProperties properties) @@ -3520,8 +3518,6 @@ class AnyFunctionType : public TypeBase { static bool containsPackExpansionType(ArrayRef params); - AnyFunctionType *flattenPackTypes(); - static void printParams(ArrayRef Params, raw_ostream &OS, const PrintOptions &PO = PrintOptions()); static void printParams(ArrayRef Params, ASTPrinter &Printer, @@ -6842,8 +6838,6 @@ class PackType final : public TypeBase, public llvm::FoldingSetNode, bool containsPackExpansionType() const; - PackType *flattenPackTypes(); - CanTypeWrapper getReducedShape(); public: @@ -6938,8 +6932,6 @@ class PackExpansionType : public TypeBase, public llvm::FoldingSetNode { /// Retrieves the count type of this pack expansion. Type getCountType() const { return countType; } - PackExpansionType *expand(); - CanType getReducedShape(); public: diff --git a/lib/AST/ParameterPack.cpp b/lib/AST/ParameterPack.cpp index fb2e93b80219d..5a831a534c14f 100644 --- a/lib/AST/ParameterPack.cpp +++ b/lib/AST/ParameterPack.cpp @@ -94,60 +94,6 @@ PackType *TypeBase::getPackSubstitutionAsPackType() { } } -/// G<{X1, ..., Xn}, {Y1, ..., Yn}>... => {G, ..., G}... -PackExpansionType *PackExpansionType::expand() { - auto countType = getCountType(); - auto *countPack = countType->getAs(); - if (countPack == nullptr) - return this; - - auto patternType = getPatternType(); - if (patternType->is()) - return this; - - unsigned j = 0; - SmallVector expandedTypes; - for (auto type : countPack->getElementTypes()) { - Type expandedCount; - if (auto *expansion = type->getAs()) - expandedCount = expansion->getCountType(); - - auto expandedPattern = patternType.transformRec( - [&](Type t) -> Optional { - if (t->is()) - return t; - - if (auto *nestedPack = t->getAs()) { - auto nestedPackElts = nestedPack->getElementTypes(); - if (j < nestedPackElts.size()) { - if (expandedCount) { - if (auto *expansion = nestedPackElts[j]->getAs()) - return expansion->getPatternType(); - } else { - return nestedPackElts[j]; - } - } - - return ErrorType::get(t->getASTContext()); - } - - return None; - }); - - if (expandedCount) { - expandedTypes.push_back(PackExpansionType::get(expandedPattern, - expandedCount)); - } else { - expandedTypes.push_back(expandedPattern); - } - - ++j; - } - - auto *packType = PackType::get(getASTContext(), expandedTypes); - return PackExpansionType::get(packType, countType); -} - CanType PackExpansionType::getReducedShape() { auto reducedShape = countType->getReducedShape(); if (reducedShape == getASTContext().TheEmptyTupleType) @@ -174,54 +120,6 @@ bool CanTupleType::containsPackExpansionTypeImpl(CanTupleType tuple) { return false; } -/// (W, {X, Y}..., Z) => (W, X, Y, Z) -Type TupleType::flattenPackTypes() { - bool anyChanged = false; - SmallVector elts; - - for (unsigned i = 0, e = getNumElements(); i < e; ++i) { - auto elt = getElement(i); - - if (auto *expansionType = elt.getType()->getAs()) { - if (auto *packType = expansionType->getPatternType()->getAs()) { - if (!anyChanged) { - elts.append(getElements().begin(), getElements().begin() + i); - anyChanged = true; - } - - bool first = true; - for (auto packElt : packType->getElementTypes()) { - if (first) { - elts.push_back(TupleTypeElt(packElt, elt.getName())); - first = false; - continue; - } - elts.push_back(TupleTypeElt(packElt)); - } - - continue; - } - } - - if (anyChanged) - elts.push_back(elt); - } - - if (!anyChanged) - return this; - - // If pack substitution yields a single-element tuple, the tuple - // structure is flattened to produce the element type. - if (elts.size() == 1) { - auto type = elts.front().getType(); - if (!type->is() && !type->is()) { - return type; - } - } - - return TupleType::get(elts, getASTContext()); -} - bool AnyFunctionType::containsPackExpansionType(ArrayRef params) { for (auto param : params) { if (param.getPlainType()->is()) @@ -231,50 +129,6 @@ bool AnyFunctionType::containsPackExpansionType(ArrayRef params) { return false; } -/// (W, {X, Y}..., Z) -> T => (W, X, Y, Z) -> T -AnyFunctionType *AnyFunctionType::flattenPackTypes() { - bool anyChanged = false; - SmallVector params; - - for (unsigned i = 0, e = getParams().size(); i < e; ++i) { - auto param = getParams()[i]; - - if (auto *expansionType = param.getPlainType()->getAs()) { - if (auto *packType = expansionType->getPatternType()->getAs()) { - if (!anyChanged) { - params.append(getParams().begin(), getParams().begin() + i); - anyChanged = true; - } - - bool first = true; - for (auto packElt : packType->getElementTypes()) { - if (first) { - params.push_back(param.withType(packElt)); - first = false; - continue; - } - params.push_back(param.withType(packElt).getWithoutLabels()); - } - - continue; - } - } - - if (anyChanged) - params.push_back(param); - } - - if (!anyChanged) - return this; - - if (auto *genericFuncType = getAs()) { - return GenericFunctionType::get(genericFuncType->getGenericSignature(), - params, getResult(), getExtInfo()); - } else { - return FunctionType::get(params, getResult(), getExtInfo()); - } -} - bool PackType::containsPackExpansionType() const { for (auto type : getElementTypes()) { if (type->is()) @@ -284,39 +138,6 @@ bool PackType::containsPackExpansionType() const { return false; } -/// {W, {X, Y}..., Z} => {W, X, Y, Z} -PackType *PackType::flattenPackTypes() { - bool anyChanged = false; - SmallVector elts; - - for (unsigned i = 0, e = getNumElements(); i < e; ++i) { - auto elt = getElementType(i); - - if (auto *expansionType = elt->getAs()) { - if (auto *packType = expansionType->getPatternType()->getAs()) { - if (!anyChanged) { - elts.append(getElementTypes().begin(), getElementTypes().begin() + i); - anyChanged = true; - } - - for (auto packElt : packType->getElementTypes()) { - elts.push_back(packElt); - } - - continue; - } - } - - if (anyChanged) - elts.push_back(elt); - } - - if (!anyChanged) - return this; - - return PackType::get(getASTContext(), elts); -} - template static CanPackType getReducedShapeOfPack(const ASTContext &ctx, const T &elementTypes) { From 9f74a1164a173b1cc4c41f4e9d6d1bc85629952d Mon Sep 17 00:00:00 2001 From: John McCall Date: Mon, 27 Mar 2023 17:15:24 -0400 Subject: [PATCH 20/30] Fix the dumping of AbstractionPatterns with substitutions --- lib/SIL/IR/AbstractionPattern.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 6f8e375959366..8ab7eb4790691 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -1373,12 +1373,12 @@ static void printGenerics(raw_ostream &out, const AbstractionPattern &pattern) { // It'd be really nice if we could get these interleaved with the types. if (auto subs = pattern.getGenericSubstitutions()) { out << "@<"; - bool first = false; + bool first = true; for (auto sub : subs.getReplacementTypes()) { if (!first) { out << ","; } else { - first = true; + first = false; } out << sub; } From af752fbaa4a7340478e003c156511d1723b67fac Mon Sep 17 00:00:00 2001 From: John McCall Date: Mon, 27 Mar 2023 17:15:57 -0400 Subject: [PATCH 21/30] Pass down the substitutions of the original pattern when extracting a subst abstraction pattern from a generic nominal type. --- lib/SIL/IR/AbstractionPattern.cpp | 19 +++++++++++-------- test/SILGen/variadic-generic-closures.swift | 6 ++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/lib/SIL/IR/AbstractionPattern.cpp b/lib/SIL/IR/AbstractionPattern.cpp index 8ab7eb4790691..931da4a9a8d9c 100644 --- a/lib/SIL/IR/AbstractionPattern.cpp +++ b/lib/SIL/IR/AbstractionPattern.cpp @@ -1961,8 +1961,11 @@ class SubstFunctionTypePatternVisitor : CanType(CanMetatypeType::get(substInstance)); } - CanType handleGenericNominalType(CanType orig, CanType subst, - CanGenericSignature origSig) { + CanType handleGenericNominalType(AbstractionPattern origPattern, CanType subst) { + CanType orig = origPattern.getType(); + CanGenericSignature origSig = origPattern.getGenericSignatureOrNull(); + auto origPatternSubs = origPattern.getGenericSubstitutions(); + // If there are no loose type parameters in the pattern here, we don't need // to do a recursive visit at all. if (!orig->hasTypeParameter() @@ -2007,6 +2010,7 @@ class SubstFunctionTypePatternVisitor if (differentOrigClass) { orig = subst; origSig = TC.getCurGenericSignature(); + origPatternSubs = SubstitutionMap(); assert((!subst->hasTypeParameter() || origSig) && "lowering mismatched interface types in a context without " "a generic signature"); @@ -2028,7 +2032,8 @@ class SubstFunctionTypePatternVisitor ->getCanonicalType(); replacementTypes[gp->getCanonicalType()->castTo()] - = visit(substParamTy, AbstractionPattern(origSig, origParamTy)); + = visit(substParamTy, + AbstractionPattern(origPatternSubs, origSig, origParamTy)); } auto newSubMap = SubstitutionMap::get(nomGenericSig, @@ -2048,8 +2053,7 @@ class SubstFunctionTypePatternVisitor // If the type is generic (because it's a nested type in a generic context), // process the generic type bindings. if (!isa(nomDecl) && nomDecl->isGenericContext()) { - return handleGenericNominalType(pattern.getType(), nom, - pattern.getGenericSignatureOrNull()); + return handleGenericNominalType(pattern, nom); } // Otherwise, there are no structural type parameters to visit. @@ -2058,8 +2062,7 @@ class SubstFunctionTypePatternVisitor CanType visitBoundGenericType(CanBoundGenericType bgt, AbstractionPattern pattern) { - return handleGenericNominalType(pattern.getType(), bgt, - pattern.getGenericSignatureOrNull()); + return handleGenericNominalType(pattern, bgt); } CanType visitPackType(CanPackType pack, AbstractionPattern pattern) { @@ -2085,7 +2088,7 @@ class SubstFunctionTypePatternVisitor // the pack substitution for that parameter recorded in the pattern. // Remember that we're within an expansion. - // FIXME: when we introduce PackReferenceType we'll need to be clear + // FIXME: when we introduce PackElementType we'll need to be clear // about which pack expansions to treat this way. llvm::SaveAndRestore scope(WithinExpansion, true); diff --git a/test/SILGen/variadic-generic-closures.swift b/test/SILGen/variadic-generic-closures.swift index b9a026f4f1573..16cbb6c847b25 100644 --- a/test/SILGen/variadic-generic-closures.swift +++ b/test/SILGen/variadic-generic-closures.swift @@ -9,3 +9,9 @@ public struct G {} public func caller(fn: (repeat G) -> ()) { fn(repeat G()) } + +// rdar://107108803 +public struct UsesG { + public init(builder: (repeat G) -> E) {} +} +UsesG { a, b, c in 0 } From cf94aa72530cdefc24299e02ef2ef3106f19cb7f Mon Sep 17 00:00:00 2001 From: John McCall Date: Mon, 27 Mar 2023 17:17:18 -0400 Subject: [PATCH 22/30] Assert that we don't produce a SILFunction with a type with opened type parameters. --- lib/SIL/IR/SILFunction.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/SIL/IR/SILFunction.cpp b/lib/SIL/IR/SILFunction.cpp index f3107f50a95e6..eca0c14b3e613 100644 --- a/lib/SIL/IR/SILFunction.cpp +++ b/lib/SIL/IR/SILFunction.cpp @@ -186,6 +186,10 @@ void SILFunction::init( IsExactSelfClass_t isExactSelfClass, IsDistributed_t isDistributed, IsRuntimeAccessible_t isRuntimeAccessible) { setName(Name); + + assert(!LoweredType->hasTypeParameter() && + "function type has open type parameters"); + this->LoweredType = LoweredType; this->GenericEnv = genericEnv; this->SpecializationInfo = nullptr; From 070cfb07960ffcb15bd2e2c0784514b0b260c726 Mon Sep 17 00:00:00 2001 From: John McCall Date: Mon, 27 Mar 2023 17:23:59 -0400 Subject: [PATCH 23/30] Make BlackHoleInitialization support pack initialization Fixes rdar://107151145 --- lib/SILGen/Initialization.h | 9 +++++++++ lib/SILGen/SILGenDecl.cpp | 9 +++++++++ test/SILGen/variadic-generic-tuples.swift | 8 ++++++++ 3 files changed, 26 insertions(+) diff --git a/lib/SILGen/Initialization.h b/lib/SILGen/Initialization.h index 1bced77f1e065..93081d4970c56 100644 --- a/lib/SILGen/Initialization.h +++ b/lib/SILGen/Initialization.h @@ -491,6 +491,15 @@ class BlackHoleInitialization : public Initialization { return buf; } + bool canPerformPackExpansionInitialization() const override { + return true; + } + + void performPackExpansionInitialization(SILGenFunction &SGF, + SILLocation loc, + SILValue indexWithinComponent, + llvm::function_ref fn) override; + void copyOrInitValueInto(SILGenFunction &SGF, SILLocation loc, ManagedValue value, bool isInit) override; diff --git a/lib/SILGen/SILGenDecl.cpp b/lib/SILGen/SILGenDecl.cpp index 6ead74f108574..9db9e9ac92a6e 100644 --- a/lib/SILGen/SILGenDecl.cpp +++ b/lib/SILGen/SILGenDecl.cpp @@ -2172,6 +2172,15 @@ void SILGenFunction::destroyLocalVariable(SILLocation silLoc, VarDecl *vd) { llvm_unreachable("unhandled case"); } +void BlackHoleInitialization::performPackExpansionInitialization( + SILGenFunction &SGF, + SILLocation loc, + SILValue indexWithinComponent, + llvm::function_ref fn) { + BlackHoleInitialization subInit; + fn(&subInit); +} + void BlackHoleInitialization::copyOrInitValueInto(SILGenFunction &SGF, SILLocation loc, ManagedValue value, bool isInit) { // Normally we do not do anything if we have a black hole diff --git a/test/SILGen/variadic-generic-tuples.swift b/test/SILGen/variadic-generic-tuples.swift index b1585c5804b98..bde4743a3762d 100644 --- a/test/SILGen/variadic-generic-tuples.swift +++ b/test/SILGen/variadic-generic-tuples.swift @@ -283,3 +283,11 @@ struct MemberwiseTupleHolder { func callVariadicMemberwiseInit() -> MemberwiseTupleHolder { return MemberwiseTupleHolder(content: (0, "hello")) } + +// rdar://107151145: when we tuple-destructure a black hole +// initialization, the resulting element initializations need to +// handle pack expansion initialization +struct EmptyContainer {} +func f(_: repeat each T) { + let _ = (repeat EmptyContainer()) +} From a4edc3e58de9300bb2dd9b7ae9d716afe31c54cd Mon Sep 17 00:00:00 2001 From: John McCall Date: Tue, 28 Mar 2023 14:32:58 -0400 Subject: [PATCH 24/30] Teach ResultPlan to handle packs correctly when we're not emitting into an Initialization. rdar://107161241 --- lib/SILGen/ResultPlan.cpp | 135 ++++++++++++++-------- lib/SILGen/ResultPlan.h | 10 +- test/SILGen/variadic-generic-tuples.swift | 32 +++++ 3 files changed, 128 insertions(+), 49 deletions(-) diff --git a/lib/SILGen/ResultPlan.cpp b/lib/SILGen/ResultPlan.cpp index 99875cfc01f7c..979ce327e82bf 100644 --- a/lib/SILGen/ResultPlan.cpp +++ b/lib/SILGen/ResultPlan.cpp @@ -27,6 +27,15 @@ using namespace Lowering; // Result Plans //===----------------------------------------------------------------------===// +void ResultPlan::finishAndAddTo(SILGenFunction &SGF, SILLocation loc, + ArrayRef &directResults, + SILValue bridgedForeignError, + RValue &result) { + auto rvalue = finish(SGF, loc, directResults, bridgedForeignError); + assert(!rvalue.isInContext()); + result.addElement(std::move(rvalue)); +} + namespace { /// A result plan for evaluating an indirect result into the address @@ -343,14 +352,16 @@ class ScalarResultPlan final : public ResultPlan { /// using a temporary buffer initialized by a sub-plan. class InitValueFromTemporaryResultPlan final : public ResultPlan { Initialization *init; + CanType substType; ResultPlanPtr subPlan; std::unique_ptr temporary; public: InitValueFromTemporaryResultPlan( - Initialization *init, ResultPlanPtr &&subPlan, + Initialization *init, CanType substType, + ResultPlanPtr &&subPlan, std::unique_ptr &&temporary) - : init(init), subPlan(std::move(subPlan)), + : init(init), substType(substType), subPlan(std::move(subPlan)), temporary(std::move(temporary)) {} RValue finish(SILGenFunction &SGF, SILLocation loc, @@ -362,10 +373,15 @@ class InitValueFromTemporaryResultPlan final : public ResultPlan { (void)subResult; ManagedValue value = temporary->getManagedAddress(); - init->copyOrInitValueInto(SGF, loc, value, /*init*/ true); - init->finishInitialization(SGF); - return RValue::forInContext(); + if (init) { + init->copyOrInitValueInto(SGF, loc, value, /*init*/ true); + init->finishInitialization(SGF); + + return RValue::forInContext(); + } + + return RValue(SGF, loc, substType, value); } void @@ -414,28 +430,30 @@ class PackExpansionResultPlan : public ResultPlan { public: PackExpansionResultPlan(ResultPlanBuilder &builder, SILValue packAddr, - MutableArrayRef inits, + Optional> inits, AbstractionPattern origExpansionType, CanTupleEltTypeArrayRef substEltTypes) : PackAddr(packAddr) { + assert(!inits || inits->size() == substEltTypes.size()); + auto packTy = packAddr->getType().castTo(); auto formalPackType = CanPackType::get(packTy->getASTContext(), substEltTypes); auto origPatternType = origExpansionType.getPackExpansionPatternType(); - ComponentPlans.reserve(inits.size()); - for (auto i : indices(inits)) { - auto &init = inits[i]; + ComponentPlans.reserve(substEltTypes.size()); + for (auto i : indices(substEltTypes)) { + Initialization *init = inits ? (*inits)[i].get() : nullptr; CanType substEltType = substEltTypes[i]; if (isa(substEltType)) { ComponentPlans.emplace_back( builder.buildPackExpansionIntoPack(packAddr, formalPackType, i, - init.get(), origPatternType)); + init, origPatternType)); } else { ComponentPlans.emplace_back( builder.buildScalarIntoPack(packAddr, formalPackType, i, - init.get(), origPatternType)); + init, origPatternType)); } } } @@ -451,6 +469,16 @@ class PackExpansionResultPlan : public ResultPlan { return RValue::forInContext(); } + void finishAndAddTo(SILGenFunction &SGF, SILLocation loc, + ArrayRef &directResults, + SILValue bridgedForeignError, + RValue &result) override { + for (auto &componentPlan : ComponentPlans) { + componentPlan->finishAndAddTo(SGF, loc, directResults, + bridgedForeignError, result); + } + } + void gatherIndirectResultAddrs(SILGenFunction &SGF, SILLocation loc, SmallVectorImpl &outList) const override { outList.push_back(PackAddr); @@ -557,19 +585,27 @@ class PackTransformResultPlan final : public ResultPlan { /// components. class TupleRValueResultPlan final : public ResultPlan { CanTupleType substType; - SmallVector eltPlans; + + SmallVector origEltPlans; public: TupleRValueResultPlan(ResultPlanBuilder &builder, AbstractionPattern origType, CanTupleType substType) : substType(substType) { // Create plans for all the elements. - eltPlans.reserve(substType->getNumElements()); - for (auto i : indices(substType->getElementTypes())) { - AbstractionPattern origEltType = origType.getTupleElementType(i); - CanType substEltType = substType.getElementType(i); - eltPlans.push_back(builder.build(nullptr, origEltType, substEltType)); - } + origEltPlans.reserve(substType->getNumElements()); + origType.forEachTupleElement(substType, + [&](TupleElementGenerator &origElt) { + AbstractionPattern origEltType = origElt.getOrigType(); + auto substEltTypes = origElt.getSubstTypes(); + if (!origElt.isOrigPackExpansion()) { + origEltPlans.push_back( + builder.build(nullptr, origEltType, substEltTypes[0])); + } else { + origEltPlans.push_back( + builder.buildForPackExpansion(None, origEltType, substEltTypes)); + } + }); } RValue finish(SILGenFunction &SGF, SILLocation loc, @@ -578,10 +614,9 @@ class TupleRValueResultPlan final : public ResultPlan { RValue tupleRV(substType); // Finish all the component tuples. - for (auto &plan : eltPlans) { - RValue eltRV = - plan->finish(SGF, loc, directResults, bridgedForeignError); - tupleRV.addElement(std::move(eltRV)); + for (auto &plan : origEltPlans) { + plan->finishAndAddTo(SGF, loc, directResults, bridgedForeignError, + tupleRV); } return tupleRV; @@ -590,8 +625,8 @@ class TupleRValueResultPlan final : public ResultPlan { void gatherIndirectResultAddrs(SILGenFunction &SGF, SILLocation loc, SmallVectorImpl &outList) const override { - for (const auto &eltPlan : eltPlans) { - eltPlan->gatherIndirectResultAddrs(SGF, loc, outList); + for (const auto &plan : origEltPlans) { + plan->gatherIndirectResultAddrs(SGF, loc, outList); } } }; @@ -1143,10 +1178,10 @@ ResultPlanPtr ResultPlanBuilder::buildForScalar(Initialization *init, } ResultPlanPtr ResultPlanBuilder:: - buildForPackExpansion(MutableArrayRef inits, + buildForPackExpansion(Optional> inits, AbstractionPattern origExpansionType, CanTupleEltTypeArrayRef substTypes) { - assert(inits.size() == substTypes.size()); + assert(!inits || inits->size() == substTypes.size()); // Pack expansions in the original result type always turn into // a single @pack_out result. @@ -1155,7 +1190,7 @@ ResultPlanPtr ResultPlanBuilder:: auto packTy = result.getSILStorageType(SGF.SGM.M, calleeTypeInfo.substFnType, SGF.getTypeExpansionContext()); - assert(packTy.castTo()->getNumElements() == inits.size()); + assert(packTy.castTo()->getNumElements() == substTypes.size()); // TODO: try to just forward a single pack @@ -1236,7 +1271,6 @@ ResultPlanBuilder::buildScalarIntoPack(SILValue packAddr, Initialization *init, AbstractionPattern origType) { assert(!origType.isPackExpansion()); - assert(init); auto substType = formalPackType.getElementType(componentIndex); assert(!isa(substType)); @@ -1245,7 +1279,8 @@ ResultPlanBuilder::buildScalarIntoPack(SILValue packAddr, ->getElementType(componentIndex); SILResultInfo resultInfo(loweredEltType, ResultConvention::Indirect); - // Use the normal scalar emission path. + // Use the normal scalar emission path to gather an indirect result + // of that type. auto plan = buildForScalar(init, origType, substType, resultInfo); // Immediately gather the indirect result. @@ -1265,38 +1300,44 @@ ResultPlanBuilder::buildScalarIntoPack(SILValue packAddr, ResultPlanPtr ResultPlanBuilder::buildForTuple(Initialization *init, AbstractionPattern origType, CanTupleType substType) { - // If we don't have an initialization for the tuple, just build the - // individual components. - if (!init) { - return ResultPlanPtr(new TupleRValueResultPlan(*this, origType, substType)); - } - - // Okay, we have an initialization for the tuple that we need to emit into. - - // If we can just split the initialization, do so. - if (init->canSplitIntoTupleElements()) { + // If we have an initialization, and we can split it, do so. + if (init && init->canSplitIntoTupleElements()) { return ResultPlanPtr( new TupleInitializationResultPlan(*this, init, origType, substType)); } - // Otherwise, we're going to have to call copyOrInitValueInto, which only - // takes a single value. - - // If the tuple is address-only, we'll get much better code if we - // emit into a single buffer. + // Otherwise, if the tuple contains a pack expansion, we'll need to + // initialize a single buffer one way or another: either we're giving + // this to RValue (which wants a single value for tuples with pack + // expansions) or we'll have to call copyOrInitValueInto on init + // (which expects a single value). Create a temporary, build into + // that, and then call the initialization. + // + // We also use this path when we have an init and the type is + // address-only, because we'll need to call copyOrInitValueInto and + // we'll get better code by building that up indirectly. But we don't + // do that if we're not using lowered addresses because we prefer to + // build tuples with scalar operations. auto &substTL = SGF.getTypeLowering(substType); + assert(substTL.isAddressOnly() || !substType.containsPackExpansionType()); if (substTL.isAddressOnly() && (substType.containsPackExpansionType() || - SGF.F.getConventions().useLoweredAddresses())) { + (init != nullptr && SGF.F.getConventions().useLoweredAddresses()))) { // Create a temporary. auto temporary = SGF.emitTemporary(loc, substTL); // Build a sub-plan to emit into the temporary. auto subplan = buildForTuple(temporary.get(), origType, substType); - // Make a plan to initialize into that. + // Make a plan to produce the final result from that. return ResultPlanPtr(new InitValueFromTemporaryResultPlan( - init, std::move(subplan), std::move(temporary))); + init, substType, std::move(subplan), std::move(temporary))); + } + + // If we don't have an initialization, just build the individual + // components. + if (!init) { + return ResultPlanPtr(new TupleRValueResultPlan(*this, origType, substType)); } // Build a sub-plan that doesn't know about the initialization. diff --git a/lib/SILGen/ResultPlan.h b/lib/SILGen/ResultPlan.h index 3322cc8ba1cc1..f194fdfc86625 100644 --- a/lib/SILGen/ResultPlan.h +++ b/lib/SILGen/ResultPlan.h @@ -42,6 +42,12 @@ class ResultPlan { virtual RValue finish(SILGenFunction &SGF, SILLocation loc, ArrayRef &directResults, SILValue bridgedForeignError) = 0; + + virtual void finishAndAddTo(SILGenFunction &SGF, SILLocation loc, + ArrayRef &directResults, + SILValue bridgedForeignError, + RValue &result); + virtual ~ResultPlan() = default; /// Defers the emission of the given breadcrumb until \p finish is invoked. @@ -92,8 +98,8 @@ struct ResultPlanBuilder { ResultPlanPtr buildForTuple(Initialization *emitInto, AbstractionPattern origType, CanTupleType substType); - ResultPlanPtr buildForPackExpansion(MutableArrayRef inits, - AbstractionPattern origPatternType, + ResultPlanPtr buildForPackExpansion(Optional> inits, + AbstractionPattern origExpansionType, CanTupleEltTypeArrayRef substTypes); ResultPlanPtr buildPackExpansionIntoPack(SILValue packAddr, CanPackType formalPackType, diff --git a/test/SILGen/variadic-generic-tuples.swift b/test/SILGen/variadic-generic-tuples.swift index bde4743a3762d..d86ebecb0e395 100644 --- a/test/SILGen/variadic-generic-tuples.swift +++ b/test/SILGen/variadic-generic-tuples.swift @@ -291,3 +291,35 @@ struct EmptyContainer {} func f(_: repeat each T) { let _ = (repeat EmptyContainer()) } + +// rdar://107161241: handle receiving tuples that originally contained +// packs that are not emitted into an initialization +struct FancyTuple { + var x: (repeat each T) + + func makeTuple() -> (repeat each T) { + return (repeat each x.element) + } +} + +// CHECK: sil{{.*}} @$s4main23testFancyTuple_concreteyyF : +// Create a pack to receive the results from makeTuple. +// CHECK: [[PACK:%.*]] = alloc_pack $Pack{Int, String, Bool} +// CHECK-NEXT: [[INT_ADDR:%.*]] = alloc_stack $Int +// CHECK-NEXT: [[INT_INDEX:%.*]] = scalar_pack_index 0 of $Pack{Int, String, Bool} +// CHECK-NEXT: pack_element_set [[INT_ADDR]] : $*Int into [[INT_INDEX]] of [[PACK]] : $*Pack{Int, String, Bool} +// CHECK-NEXT: [[STRING_ADDR:%.*]] = alloc_stack $String +// CHECK-NEXT: [[STRING_INDEX:%.*]] = scalar_pack_index 1 of $Pack{Int, String, Bool} +// CHECK-NEXT: pack_element_set [[STRING_ADDR]] : $*String into [[STRING_INDEX]] of [[PACK]] : $*Pack{Int, String, Bool} +// CHECK-NEXT: [[BOOL_ADDR:%.*]] = alloc_stack $Bool +// CHECK-NEXT: [[BOOL_INDEX:%.*]] = scalar_pack_index 2 of $Pack{Int, String, Bool} +// CHECK-NEXT: pack_element_set [[BOOL_ADDR]] : $*Bool into [[BOOL_INDEX]] of [[PACK]] : $*Pack{Int, String, Bool} +// CHECK: [[FN:%.*]] = function_ref @$s4main10FancyTupleV04makeC0xxQp_tyF +// CHECK-NEXT: apply [[FN]]([[PACK]], {{.*}}) +func testFancyTuple_concrete() { + FancyTuple(x: (1, "hi", false)).makeTuple() +} + +func testFancyTuple_pack(values: repeat each T) { + FancyTuple(x: (1, "hi", repeat each values, false)).makeTuple() +} From 6badca63ae3f371bbe6e2da243fbd330ce0e2055 Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Thu, 23 Mar 2023 12:11:16 -0700 Subject: [PATCH 25/30] [PackExpansionMatcher] Adjust ParamPackMatcher to account for labels First step in preparation to unify different matchers which is effectively no-op because function parameters do not have labels. - Common prefix/suffix should account for presence of labels - Labeled parameters cannot appear in the region absorbed by a pack expansion type. --- lib/AST/PackExpansionMatcher.cpp | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/lib/AST/PackExpansionMatcher.cpp b/lib/AST/PackExpansionMatcher.cpp index 3c78e36b7ba58..af9cc99ad0aa3 100644 --- a/lib/AST/PackExpansionMatcher.cpp +++ b/lib/AST/PackExpansionMatcher.cpp @@ -142,6 +142,9 @@ bool ParamPackMatcher::match() { auto lhsParam = lhsParams[lhsIdx]; auto rhsParam = rhsParams[rhsIdx]; + if (lhsParam.getLabel() != rhsParam.getLabel()) + break; + // FIXME: Check flags auto lhsType = lhsParam.getPlainType(); @@ -170,6 +173,9 @@ bool ParamPackMatcher::match() { // FIXME: Check flags + if (lhsParam.getLabel() != rhsParam.getLabel()) + break; + auto lhsType = lhsParam.getPlainType(); auto rhsType = rhsParam.getPlainType(); @@ -203,6 +209,9 @@ bool ParamPackMatcher::match() { SmallVector rhsTypes; for (auto rhsParam : rhsParams) { + if (rhsParam.hasLabel()) + return true; + // FIXME: Check rhs flags rhsTypes.push_back(rhsParam.getPlainType()); } @@ -224,6 +233,9 @@ bool ParamPackMatcher::match() { SmallVector lhsTypes; for (auto lhsParam : lhsParams) { + if (lhsParam.hasLabel()) + return true; + // FIXME: Check lhs flags lhsTypes.push_back(lhsParam.getPlainType()); } From c9b8140de4f79343448609978595f7398ab4dcc4 Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Thu, 23 Mar 2023 13:36:13 -0700 Subject: [PATCH 26/30] [CSSimplify] NFC: Refactor `TupleMatcher` to consolidate pack expansion matching --- lib/Sema/CSSimplify.cpp | 125 ++++++++++++++-------------------------- 1 file changed, 43 insertions(+), 82 deletions(-) diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index 62c8cea74d9a4..2c6824cd1909e 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2069,13 +2069,19 @@ class TupleMatcher { TupleType *tuple2; public: + enum class MatchKind : uint8_t { + Equality, + Subtype, + Conversion, + }; + SmallVector pairs; bool hasLabelMismatch = false; TupleMatcher(TupleType *tuple1, TupleType *tuple2) - : tuple1(tuple1), tuple2(tuple2) {} + : tuple1(tuple1), tuple2(tuple2) {} - bool matchBind() { + bool match(MatchKind kind, ConstraintLocatorBuilder locator) { // FIXME: TuplePackMatcher should completely replace the non-variadic // case too eventually. if (tuple1->containsPackExpansionType() || @@ -2091,42 +2097,32 @@ class TupleMatcher { if (tuple1->getNumElements() != tuple2->getNumElements()) return true; - for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { - const auto &elt1 = tuple1->getElement(i); - const auto &elt2 = tuple2->getElement(i); + switch (kind) { + case MatchKind::Equality: + return matchEquality(isInPatternMatchingContext(locator)); - // If the names don't match, we have a conflict. - if (elt1.getName() != elt2.getName()) - return true; + case MatchKind::Subtype: + return matchSubtype(); - pairs.emplace_back(elt1.getType(), elt2.getType(), i, i); + case MatchKind::Conversion: + return matchConversion(); } - - return false; } - bool matchInPatternMatchingContext() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - - if (tuple1->getNumElements() != tuple2->getNumElements()) - return true; - +private: + bool matchEquality(bool inPatternMatchingContext) { for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { const auto &elt1 = tuple1->getElement(i); const auto &elt2 = tuple2->getElement(i); - if (elt1.hasName() && elt1.getName() != elt2.getName()) - return true; + if (inPatternMatchingContext) { + if (elt1.hasName() && elt1.getName() != elt2.getName()) + return true; + } else { + // If the names don't match, we have a conflict. + if (elt1.getName() != elt2.getName()) + return true; + } pairs.emplace_back(elt1.getType(), elt2.getType(), i, i); } @@ -2135,21 +2131,6 @@ class TupleMatcher { } bool matchSubtype() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - - if (tuple1->getNumElements() != tuple2->getNumElements()) - return true; - for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { const auto &elt1 = tuple1->getElement(i); const auto &elt2 = tuple2->getElement(i); @@ -2173,18 +2154,6 @@ class TupleMatcher { } bool matchConversion() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - SmallVector sources; if (computeTupleShuffle(tuple1, tuple2, sources)) return true; @@ -2208,23 +2177,16 @@ ConstraintSystem::TypeMatchResult ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, ConstraintKind kind, TypeMatchOptions flags, ConstraintLocatorBuilder locator) { - TupleMatcher matcher(tuple1, tuple2); + using TupleMatchKind = TupleMatcher::MatchKind; ConstraintKind subkind; + TupleMatchKind matchKind; switch (kind) { case ConstraintKind::Bind: case ConstraintKind::Equal: { subkind = kind; - - if (isInPatternMatchingContext(locator)) { - if (matcher.matchInPatternMatchingContext()) - return getTypeMatchFailure(locator); - } else { - if (matcher.matchBind()) - return getTypeMatchFailure(locator); - } - + matchKind = TupleMatchKind::Equality; break; } @@ -2234,17 +2196,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::Subtype: case ConstraintKind::BindToPointerType: { subkind = kind; - - if (matcher.matchSubtype()) - return getTypeMatchFailure(locator); - - if (matcher.hasLabelMismatch) { - // If we had a label mismatch, emit a warning. This is something we - // shouldn't permit, as it's more permissive than what a conversion would - // allow. Ideally we'd turn this into an error in Swift 6 mode. - recordFix(AllowTupleLabelMismatch::create( - *this, tuple1, tuple2, getConstraintLocator(locator))); - } + matchKind = TupleMatchKind::Subtype; break; } @@ -2252,11 +2204,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::ArgumentConversion: case ConstraintKind::OperatorArgumentConversion: { subkind = ConstraintKind::Conversion; - - // Compute the element shuffles for conversions. - if (matcher.matchConversion()) - return getTypeMatchFailure(locator); - + matchKind = TupleMatchKind::Conversion; break; } @@ -2296,6 +2244,19 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, llvm_unreachable("Bad constraint kind in matchTupleTypes()"); } + TupleMatcher matcher(tuple1, tuple2); + + if (matcher.match(matchKind, locator)) + return getTypeMatchFailure(locator); + + if (matcher.hasLabelMismatch) { + // If we had a label mismatch, emit a warning. This is something we + // shouldn't permit, as it's more permissive than what a conversion would + // allow. Ideally we'd turn this into an error in Swift 6 mode. + recordFix(AllowTupleLabelMismatch::create(*this, tuple1, tuple2, + getConstraintLocator(locator))); + } + TypeMatchOptions subflags = getDefaultDecompositionOptions(flags); for (auto pair : matcher.pairs) { From eb475b44f6b0ac4d5b8831b7a037e750c6b9e93a Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Thu, 23 Mar 2023 14:17:35 -0700 Subject: [PATCH 27/30] [AST] NFC: Unify implementation for pack expansion matching for type lists --- include/swift/AST/PackExpansionMatcher.h | 73 +++++++--- lib/AST/PackExpansionMatcher.cpp | 172 ++++++++--------------- 2 files changed, 112 insertions(+), 133 deletions(-) diff --git a/include/swift/AST/PackExpansionMatcher.h b/include/swift/AST/PackExpansionMatcher.h index f24fd59e406bd..6d1acac35de48 100644 --- a/include/swift/AST/PackExpansionMatcher.h +++ b/include/swift/AST/PackExpansionMatcher.h @@ -63,49 +63,78 @@ class TuplePackMatcher { bool match(); }; -/// Performs a structural match of two lists of (unlabeled) function -/// parameters. +/// Performs a structural match of two lists of types. /// /// The invariant is that each list must only contain at most one pack /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class ParamPackMatcher { - ArrayRef lhsParams; - ArrayRef rhsParams; +class TypeListPackMatcher { + struct Element { + private: + Identifier label; + Type type; + ParameterTypeFlags flags; + + public: + Element(Identifier label, Type type, + ParameterTypeFlags flags = ParameterTypeFlags()) + : label(label), type(type), flags(flags) {} + + bool hasLabel() const { return !label.empty(); } + Identifier getLabel() const { return label; } + + Type getType() const { return type; } + + static Element from(const TupleTypeElt &tupleElt); + static Element from(const AnyFunctionType::Param &funcParam); + static Element from(Type type); + }; ASTContext &ctx; + SmallVector lhsElements; + SmallVector rhsElements; + +protected: + TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, + ArrayRef rhs); + + TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, + ArrayRef rhs); + + TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, ArrayRef rhs); + public: SmallVector pairs; - ParamPackMatcher(ArrayRef lhsParams, - ArrayRef rhsParams, - ASTContext &ctx); - bool match(); }; -/// Performs a structural match of two lists of types. +/// Performs a structural match of two lists of (unlabeled) function +/// parameters. /// /// The invariant is that each list must only contain at most one pack /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class PackMatcher { - ArrayRef lhsTypes; - ArrayRef rhsTypes; - - ASTContext &ctx; - +class ParamPackMatcher : public TypeListPackMatcher { public: - SmallVector pairs; - - PackMatcher(ArrayRef lhsTypes, - ArrayRef rhsTypes, - ASTContext &ctx); + ParamPackMatcher(ArrayRef lhsParams, + ArrayRef rhsParams, ASTContext &ctx) + : TypeListPackMatcher(ctx, lhsParams, rhsParams) {} +}; - bool match(); +/// Performs a structural match of two lists of types. +/// +/// The invariant is that each list must only contain at most one pack +/// expansion type. After collecting a common prefix and suffix, the +/// pack expansion on either side asborbs the remaining elements on the +/// other side. +class PackMatcher : public TypeListPackMatcher { +public: + PackMatcher(ArrayRef lhsTypes, ArrayRef rhsTypes, ASTContext &ctx) + : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {} }; } // end namespace swift diff --git a/lib/AST/PackExpansionMatcher.cpp b/lib/AST/PackExpansionMatcher.cpp index af9cc99ad0aa3..54b804f0b83de 100644 --- a/lib/AST/PackExpansionMatcher.cpp +++ b/lib/AST/PackExpansionMatcher.cpp @@ -123,13 +123,60 @@ bool TuplePackMatcher::match() { return false; } -ParamPackMatcher::ParamPackMatcher( - ArrayRef lhsParams, - ArrayRef rhsParams, - ASTContext &ctx) - : lhsParams(lhsParams), rhsParams(rhsParams), ctx(ctx) {} +TypeListPackMatcher::Element +TypeListPackMatcher::Element::from(const TupleTypeElt &elt) { + return {elt.getName(), elt.getType()}; +} + +TypeListPackMatcher::Element +TypeListPackMatcher::Element::from(const AnyFunctionType::Param ¶m) { + return {param.getLabel(), param.getPlainType(), param.getParameterFlags()}; +} + +TypeListPackMatcher::Element TypeListPackMatcher::Element::from(Type type) { + return {/*label=*/Identifier(), type}; +} + +TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx, + ArrayRef lhsParams, + ArrayRef rhsParams) + : ctx(ctx) { + llvm::transform(lhsParams, std::back_inserter(lhsElements), + [&](const auto &elt) { return Element::from(elt); }); + llvm::transform(rhsParams, std::back_inserter(rhsElements), + [&](const auto &elt) { return Element::from(elt); }); +} + +TypeListPackMatcher::TypeListPackMatcher( + ASTContext &ctx, ArrayRef lhsParams, + ArrayRef rhsParams) + : ctx(ctx) { + llvm::transform(lhsParams, std::back_inserter(lhsElements), + [&](const auto &elt) { + assert(!elt.hasLabel()); + return Element::from(elt); + }); + llvm::transform(rhsParams, std::back_inserter(rhsElements), + [&](const auto &elt) { + assert(!elt.hasLabel()); + return Element::from(elt); + }); +} + +TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx, + ArrayRef lhsParams, + ArrayRef rhsParams) + : ctx(ctx) { + llvm::transform(lhsParams, std::back_inserter(lhsElements), + [&](const auto &elt) { return Element::from(elt); }); + llvm::transform(rhsParams, std::back_inserter(rhsElements), + [&](const auto &elt) { return Element::from(elt); }); +} + +bool TypeListPackMatcher::match() { + ArrayRef lhsParams(lhsElements); + ArrayRef rhsParams(rhsElements); -bool ParamPackMatcher::match() { unsigned minLength = std::min(lhsParams.size(), rhsParams.size()); // Consume the longest possible prefix where neither type in @@ -147,8 +194,8 @@ bool ParamPackMatcher::match() { // FIXME: Check flags - auto lhsType = lhsParam.getPlainType(); - auto rhsType = rhsParam.getPlainType(); + auto lhsType = lhsParam.getType(); + auto rhsType = rhsParam.getType(); if (lhsType->is() || rhsType->is()) { @@ -176,8 +223,8 @@ bool ParamPackMatcher::match() { if (lhsParam.getLabel() != rhsParam.getLabel()) break; - auto lhsType = lhsParam.getPlainType(); - auto rhsType = rhsParam.getPlainType(); + auto lhsType = lhsParam.getType(); + auto rhsType = rhsParam.getType(); if (lhsType->is() || rhsType->is()) { @@ -202,7 +249,7 @@ bool ParamPackMatcher::match() { // If the left hand side is a single pack expansion type, bind it // to what remains of the right hand side. if (lhsParams.size() == 1) { - auto lhsType = lhsParams[0].getPlainType(); + auto lhsType = lhsParams[0].getType(); if (auto *lhsExpansion = lhsType->getAs()) { unsigned lhsIdx = prefixLength; unsigned rhsIdx = prefixLength; @@ -213,7 +260,7 @@ bool ParamPackMatcher::match() { return true; // FIXME: Check rhs flags - rhsTypes.push_back(rhsParam.getPlainType()); + rhsTypes.push_back(rhsParam.getType()); } auto rhs = createPackBinding(ctx, rhsTypes); @@ -226,7 +273,7 @@ bool ParamPackMatcher::match() { // If the right hand side is a single pack expansion type, bind it // to what remains of the left hand side. if (rhsParams.size() == 1) { - auto rhsType = rhsParams[0].getPlainType(); + auto rhsType = rhsParams[0].getType(); if (auto *rhsExpansion = rhsType->getAs()) { unsigned lhsIdx = prefixLength; unsigned rhsIdx = prefixLength; @@ -237,7 +284,7 @@ bool ParamPackMatcher::match() { return true; // FIXME: Check lhs flags - lhsTypes.push_back(lhsParam.getPlainType()); + lhsTypes.push_back(lhsParam.getType()); } auto lhs = createPackBinding(ctx, lhsTypes); @@ -255,100 +302,3 @@ bool ParamPackMatcher::match() { // like {T..., Int} vs {Float, U...}. return true; } - -PackMatcher::PackMatcher( - ArrayRef lhsTypes, - ArrayRef rhsTypes, - ASTContext &ctx) - : lhsTypes(lhsTypes), rhsTypes(rhsTypes), ctx(ctx) {} - -bool PackMatcher::match() { - unsigned minLength = std::min(lhsTypes.size(), rhsTypes.size()); - - // Consume the longest possible prefix where neither type in - // the pair is a pack expansion type. - unsigned prefixLength = 0; - for (unsigned i = 0; i < minLength; ++i) { - unsigned lhsIdx = i; - unsigned rhsIdx = i; - - auto lhsType = lhsTypes[lhsIdx]; - auto rhsType = rhsTypes[rhsIdx]; - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++prefixLength; - } - - // Consume the longest possible suffix where neither type in - // the pair is a pack expansion type. - unsigned suffixLength = 0; - for (unsigned i = 0; i < minLength - prefixLength; ++i) { - unsigned lhsIdx = lhsTypes.size() - i - 1; - unsigned rhsIdx = rhsTypes.size() - i - 1; - - auto lhsType = lhsTypes[lhsIdx]; - auto rhsType = rhsTypes[rhsIdx]; - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++suffixLength; - } - - assert(prefixLength + suffixLength <= lhsTypes.size()); - assert(prefixLength + suffixLength <= rhsTypes.size()); - - // Drop the consumed prefix and suffix from each list of types. - lhsTypes = lhsTypes.drop_front(prefixLength).drop_back(suffixLength); - rhsTypes = rhsTypes.drop_front(prefixLength).drop_back(suffixLength); - - // If nothing remains, we're done. - if (lhsTypes.empty() && rhsTypes.empty()) - return false; - - // If the left hand side is a single pack expansion type, bind it - // to what remains of the right hand side. - if (lhsTypes.size() == 1) { - auto lhsType = lhsTypes[0]; - if (auto *lhsExpansion = lhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - auto rhs = createPackBinding(ctx, rhsTypes); - - pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); - return false; - } - } - - // If the right hand side is a single pack expansion type, bind it - // to what remains of the left hand side. - if (rhsTypes.size() == 1) { - auto rhsType = rhsTypes[0]; - if (auto *rhsExpansion = rhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - auto lhs = createPackBinding(ctx, lhsTypes); - - pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); - return false; - } - } - - // Otherwise, all remaining possibilities are invalid: - // - Neither side has any pack expansions, and they have different lengths. - // - One side has a pack expansion but the other side is too short, eg - // {Int, T..., Float} vs {Int}. - // - The prefix and suffix are mismatched, so we're left with something - // like {T..., Int} vs {Float, U...}. - return true; -} From 659f5dfcf5d3431b47868e2bd9c130e417f1192c Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Thu, 23 Mar 2023 14:35:19 -0700 Subject: [PATCH 28/30] [AST] PackExpansionMatcher: use common prefix/suffix algorithm for tuple matching --- include/swift/AST/PackExpansionMatcher.h | 42 ++++------ lib/AST/PackExpansionMatcher.cpp | 84 ------------------- .../pack-expansion-expressions.swift | 37 +++++++- 3 files changed, 54 insertions(+), 109 deletions(-) diff --git a/include/swift/AST/PackExpansionMatcher.h b/include/swift/AST/PackExpansionMatcher.h index 6d1acac35de48..e3d131efc8b26 100644 --- a/include/swift/AST/PackExpansionMatcher.h +++ b/include/swift/AST/PackExpansionMatcher.h @@ -41,28 +41,6 @@ struct MatchedPair { : lhs(lhs), rhs(rhs), lhsIdx(lhsIdx), rhsIdx(rhsIdx) {} }; -/// Performs a structural match of two lists of tuple elements. The invariant -/// is that a pack expansion type must not be followed by an unlabeled -/// element, that is, it is either the last element or the next element has -/// a label. -/// -/// In this manner, an element with a pack expansion type "absorbs" all -/// unlabeled elements up to the next label. An element with any other type -/// matches exactly one element on the other side. -class TuplePackMatcher { - ArrayRef lhsElts; - ArrayRef rhsElts; - - ASTContext &ctx; - -public: - SmallVector pairs; - - TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple); - - bool match(); -}; - /// Performs a structural match of two lists of types. /// /// The invariant is that each list must only contain at most one pack @@ -76,16 +54,18 @@ class TypeListPackMatcher { Type type; ParameterTypeFlags flags; - public: Element(Identifier label, Type type, ParameterTypeFlags flags = ParameterTypeFlags()) : label(label), type(type), flags(flags) {} + public: bool hasLabel() const { return !label.empty(); } Identifier getLabel() const { return label; } Type getType() const { return type; } + ParameterTypeFlags getFlags() const { return flags; } + static Element from(const TupleTypeElt &tupleElt); static Element from(const AnyFunctionType::Param &funcParam); static Element from(Type type); @@ -108,7 +88,21 @@ class TypeListPackMatcher { public: SmallVector pairs; - bool match(); + [[nodiscard]] bool match(); +}; + +/// Performs a structural match of two lists of tuple elements. +/// +/// The invariant is that each list must only contain at most one pack +/// expansion type. After collecting a common prefix and suffix, the +/// pack expansion on either side asborbs the remaining elements on the +/// other side. +class TuplePackMatcher : public TypeListPackMatcher { +public: + TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple) + : TypeListPackMatcher(lhsTuple->getASTContext(), + lhsTuple->getElements(), + rhsTuple->getElements()) {} }; /// Performs a structural match of two lists of (unlabeled) function diff --git a/lib/AST/PackExpansionMatcher.cpp b/lib/AST/PackExpansionMatcher.cpp index 54b804f0b83de..ee71c747d6fb0 100644 --- a/lib/AST/PackExpansionMatcher.cpp +++ b/lib/AST/PackExpansionMatcher.cpp @@ -39,90 +39,6 @@ static PackExpansionType *createPackBinding(ASTContext &ctx, return PackExpansionType::get(packType, packType); } -static PackExpansionType *gatherTupleElements(ArrayRef &elts, - Identifier name, - ASTContext &ctx) { - SmallVector types; - - if (!elts.empty() && elts.front().getName() == name) { - do { - types.push_back(elts.front().getType()); - elts = elts.slice(1); - } while (!elts.empty() && !elts.front().hasName()); - } - - return createPackBinding(ctx, types); -} - -TuplePackMatcher::TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple) - : lhsElts(lhsTuple->getElements()), - rhsElts(rhsTuple->getElements()), - ctx(lhsTuple->getASTContext()) {} - -bool TuplePackMatcher::match() { - unsigned lhsIdx = 0; - unsigned rhsIdx = 0; - - // Iterate over the two tuples in parallel, popping elements from - // the start. - while (true) { - // If both tuples have been exhausted, we're done. - if (lhsElts.empty() && rhsElts.empty()) - return false; - - if (lhsElts.empty()) { - assert(!rhsElts.empty()); - return true; - } - - // A pack expansion type on the left hand side absorbs all elements - // from the right hand side up to the next mismatched label. - auto lhsElt = lhsElts.front(); - if (auto *lhsExpansionType = lhsElt.getType()->getAs()) { - lhsElts = lhsElts.slice(1); - - assert(lhsElts.empty() || lhsElts.front().hasName() && - "Tuple element with pack expansion type cannot be followed " - "by an unlabeled element"); - - auto rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx); - pairs.emplace_back(lhsExpansionType, rhs, lhsIdx++, rhsIdx); - continue; - } - - if (rhsElts.empty()) { - assert(!lhsElts.empty()); - return true; - } - - // A pack expansion type on the right hand side absorbs all elements - // from the left hand side up to the next mismatched label. - auto rhsElt = rhsElts.front(); - if (auto *rhsExpansionType = rhsElt.getType()->getAs()) { - rhsElts = rhsElts.slice(1); - - assert(rhsElts.empty() || rhsElts.front().hasName() && - "Tuple element with pack expansion type cannot be followed " - "by an unlabeled element"); - - auto lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx); - pairs.emplace_back(lhs, rhsExpansionType, lhsIdx, rhsIdx++); - continue; - } - - // Neither side is a pack expansion. We must have an exact match. - if (lhsElt.getName() != rhsElt.getName()) - return true; - - lhsElts = lhsElts.slice(1); - rhsElts = rhsElts.slice(1); - - pairs.emplace_back(lhsElt.getType(), rhsElt.getType(), lhsIdx++, rhsIdx++); - } - - return false; -} - TypeListPackMatcher::Element TypeListPackMatcher::Element::from(const TupleTypeElt &elt) { return {elt.getName(), elt.getType()}; diff --git a/test/Constraints/pack-expansion-expressions.swift b/test/Constraints/pack-expansion-expressions.swift index 7f82b2e1968be..94e4f0f52a3ae 100644 --- a/test/Constraints/pack-expansion-expressions.swift +++ b/test/Constraints/pack-expansion-expressions.swift @@ -172,9 +172,19 @@ do { get { 42 } set {} } + + subscript(simpleTuple args: (repeat each T)) -> Int { + get { return 0 } + set {} + } + + subscript(compoundTuple args: (String, repeat each T)) -> Int { + get { return 0 } + set {} + } } - func test_that_variadic_generics_claim_unlabeled_arguments(_ args: repeat each T, test: inout TestArgMatching) { + func test_that_variadic_generics_claim_unlabeled_arguments(_ args: repeat each T, test: inout TestArgMatching, extra: String) { func testLabeled(data: repeat each U) {} func testUnlabeled(_: repeat each U) {} func testInBetween(_: repeat each U, other: String) {} @@ -193,6 +203,31 @@ do { _ = test[data: repeat each args, 0, ""] test[data: repeat each args, "", 42] = 0 + + do { + let first = "" + let second = "" + let third = 42 + + _ = test[simpleTuple: (repeat each args)] + _ = test[simpleTuple: (repeat each args, extra)] + _ = test[simpleTuple: (first, second)] + _ = test[compoundTuple: (first, repeat each args)] + _ = test[compoundTuple: (first, repeat each args, extra)] + _ = test[compoundTuple: (first, second, third)] + } + + do { + func testRef() -> (repeat each T, String) { fatalError() } + func testResult() -> (repeat each T) { fatalError() } + + func experiment1() -> (repeat each U, String) { + testResult() // Ok + } + + func experiment2(_: () -> (repeat each U)) -> (repeat each U) { fatalError() } + let _: (Int, String) = experiment2(testRef) // Ok + } } } From 9d24f8001f0f5d2931d4b4bb05ef53e2d7345998 Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Fri, 24 Mar 2023 10:58:49 -0700 Subject: [PATCH 29/30] [AST] PackExpansionMatcher/NFC: Templatarize `TypeListPackMatcher` --- include/swift/AST/PackExpansionMatcher.h | 185 ++++++++++++++++---- lib/AST/PackExpansionMatcher.cpp | 213 ++++------------------- 2 files changed, 188 insertions(+), 210 deletions(-) diff --git a/include/swift/AST/PackExpansionMatcher.h b/include/swift/AST/PackExpansionMatcher.h index e3d131efc8b26..a3b11709475dd 100644 --- a/include/swift/AST/PackExpansionMatcher.h +++ b/include/swift/AST/PackExpansionMatcher.h @@ -47,48 +47,169 @@ struct MatchedPair { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. +template class TypeListPackMatcher { - struct Element { - private: - Identifier label; - Type type; - ParameterTypeFlags flags; + ASTContext &ctx; - Element(Identifier label, Type type, - ParameterTypeFlags flags = ParameterTypeFlags()) - : label(label), type(type), flags(flags) {} + ArrayRef lhsElements; + ArrayRef rhsElements; - public: - bool hasLabel() const { return !label.empty(); } - Identifier getLabel() const { return label; } +protected: + TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, + ArrayRef rhs) + : ctx(ctx), lhsElements(lhs), rhsElements(rhs) {} - Type getType() const { return type; } +public: + SmallVector pairs; - ParameterTypeFlags getFlags() const { return flags; } + [[nodiscard]] bool match() { + ArrayRef lhsParams(lhsElements); + ArrayRef rhsParams(rhsElements); - static Element from(const TupleTypeElt &tupleElt); - static Element from(const AnyFunctionType::Param &funcParam); - static Element from(Type type); - }; + unsigned minLength = std::min(lhsParams.size(), rhsParams.size()); - ASTContext &ctx; + // Consume the longest possible prefix where neither type in + // the pair is a pack expansion type. + unsigned prefixLength = 0; + for (unsigned i = 0; i < minLength; ++i) { + unsigned lhsIdx = i; + unsigned rhsIdx = i; - SmallVector lhsElements; - SmallVector rhsElements; + auto lhsElt = lhsParams[lhsIdx]; + auto rhsElt = rhsParams[rhsIdx]; -protected: - TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, - ArrayRef rhs); + if (getElementLabel(lhsElt) != getElementLabel(rhsElt)) + break; - TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, - ArrayRef rhs); + // FIXME: Check flags - TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, ArrayRef rhs); + auto lhsType = getElementType(lhsElt); + auto rhsType = getElementType(rhsElt); -public: - SmallVector pairs; + if (lhsType->template is() || + rhsType->template is()) { + break; + } + + // FIXME: Check flags + + pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); + ++prefixLength; + } + + // Consume the longest possible suffix where neither type in + // the pair is a pack expansion type. + unsigned suffixLength = 0; + for (unsigned i = 0; i < minLength - prefixLength; ++i) { + unsigned lhsIdx = lhsParams.size() - i - 1; + unsigned rhsIdx = rhsParams.size() - i - 1; + + auto lhsElt = lhsParams[lhsIdx]; + auto rhsElt = rhsParams[rhsIdx]; + + // FIXME: Check flags + + if (getElementLabel(lhsElt) != getElementLabel(rhsElt)) + break; + + auto lhsType = getElementType(lhsElt); + auto rhsType = getElementType(rhsElt); + + if (lhsType->template is() || + rhsType->template is()) { + break; + } + + pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); + ++suffixLength; + } + + assert(prefixLength + suffixLength <= lhsParams.size()); + assert(prefixLength + suffixLength <= rhsParams.size()); + + // Drop the consumed prefix and suffix from each list of types. + lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength); + rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength); + + // If nothing remains, we're done. + if (lhsParams.empty() && rhsParams.empty()) + return false; + + // If the left hand side is a single pack expansion type, bind it + // to what remains of the right hand side. + if (lhsParams.size() == 1) { + auto lhsType = getElementType(lhsParams[0]); + if (auto *lhsExpansion = lhsType->template getAs()) { + unsigned lhsIdx = prefixLength; + unsigned rhsIdx = prefixLength; + + SmallVector rhsTypes; + for (auto rhsElt : rhsParams) { + if (!getElementLabel(rhsElt).empty()) + return true; + + // FIXME: Check rhs flags + rhsTypes.push_back(getElementType(rhsElt)); + } + auto rhs = createPackBinding(rhsTypes); + + // FIXME: Check lhs flags + pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); + return false; + } + } + + // If the right hand side is a single pack expansion type, bind it + // to what remains of the left hand side. + if (rhsParams.size() == 1) { + auto rhsType = getElementType(rhsParams[0]); + if (auto *rhsExpansion = rhsType->template getAs()) { + unsigned lhsIdx = prefixLength; + unsigned rhsIdx = prefixLength; + + SmallVector lhsTypes; + for (auto lhsElt : lhsParams) { + if (!getElementLabel(lhsElt).empty()) + return true; + + // FIXME: Check lhs flags + lhsTypes.push_back(getElementType(lhsElt)); + } + auto lhs = createPackBinding(lhsTypes); + + // FIXME: Check rhs flags + pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); + return false; + } + } + + // Otherwise, all remaining possibilities are invalid: + // - Neither side has any pack expansions, and they have different lengths. + // - One side has a pack expansion but the other side is too short, eg + // {Int, T..., Float} vs {Int}. + // - The prefix and suffix are mismatched, so we're left with something + // like {T..., Int} vs {Float, U...}. + return true; + } + +private: + Identifier getElementLabel(const Element &) const; + Type getElementType(const Element &) const; + ParameterTypeFlags getElementFlags(const Element &) const; + + PackExpansionType *createPackBinding(ArrayRef types) const { + // If there is only one element and it's a PackExpansionType, + // return it directly. + if (types.size() == 1) { + if (auto *expansionType = types.front()->getAs()) { + return expansionType; + } + } - [[nodiscard]] bool match(); + // Otherwise, wrap the elements in PackExpansionType(PackType(...)). + auto *packType = PackType::get(ctx, types); + return PackExpansionType::get(packType, packType); + } }; /// Performs a structural match of two lists of tuple elements. @@ -97,7 +218,7 @@ class TypeListPackMatcher { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class TuplePackMatcher : public TypeListPackMatcher { +class TuplePackMatcher : public TypeListPackMatcher { public: TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple) : TypeListPackMatcher(lhsTuple->getASTContext(), @@ -112,7 +233,7 @@ class TuplePackMatcher : public TypeListPackMatcher { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class ParamPackMatcher : public TypeListPackMatcher { +class ParamPackMatcher : public TypeListPackMatcher { public: ParamPackMatcher(ArrayRef lhsParams, ArrayRef rhsParams, ASTContext &ctx) @@ -125,7 +246,7 @@ class ParamPackMatcher : public TypeListPackMatcher { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class PackMatcher : public TypeListPackMatcher { +class PackMatcher : public TypeListPackMatcher { public: PackMatcher(ArrayRef lhsTypes, ArrayRef rhsTypes, ASTContext &ctx) : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {} diff --git a/lib/AST/PackExpansionMatcher.cpp b/lib/AST/PackExpansionMatcher.cpp index ee71c747d6fb0..61df8064a232f 100644 --- a/lib/AST/PackExpansionMatcher.cpp +++ b/lib/AST/PackExpansionMatcher.cpp @@ -24,197 +24,54 @@ using namespace swift; -static PackExpansionType *createPackBinding(ASTContext &ctx, - ArrayRef types) { - // If there is only one element and it's a PackExpansionType, - // return it directly. - if (types.size() == 1) { - if (auto *expansionType = types.front()->getAs()) { - return expansionType; - } - } - - // Otherwise, wrap the elements in PackExpansionType(PackType(...)). - auto *packType = PackType::get(ctx, types); - return PackExpansionType::get(packType, packType); +template <> +Identifier TypeListPackMatcher::getElementLabel( + const TupleTypeElt &elt) const { + return elt.getName(); } -TypeListPackMatcher::Element -TypeListPackMatcher::Element::from(const TupleTypeElt &elt) { - return {elt.getName(), elt.getType()}; +template <> +Type TypeListPackMatcher::getElementType( + const TupleTypeElt &elt) const { + return elt.getType(); } -TypeListPackMatcher::Element -TypeListPackMatcher::Element::from(const AnyFunctionType::Param ¶m) { - return {param.getLabel(), param.getPlainType(), param.getParameterFlags()}; +template <> +ParameterTypeFlags TypeListPackMatcher::getElementFlags( + const TupleTypeElt &elt) const { + return ParameterTypeFlags(); } -TypeListPackMatcher::Element TypeListPackMatcher::Element::from(Type type) { - return {/*label=*/Identifier(), type}; +template <> +Identifier TypeListPackMatcher::getElementLabel( + const AnyFunctionType::Param &elt) const { + return elt.getLabel(); } -TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx, - ArrayRef lhsParams, - ArrayRef rhsParams) - : ctx(ctx) { - llvm::transform(lhsParams, std::back_inserter(lhsElements), - [&](const auto &elt) { return Element::from(elt); }); - llvm::transform(rhsParams, std::back_inserter(rhsElements), - [&](const auto &elt) { return Element::from(elt); }); +template <> +Type TypeListPackMatcher::getElementType( + const AnyFunctionType::Param &elt) const { + return elt.getPlainType(); } -TypeListPackMatcher::TypeListPackMatcher( - ASTContext &ctx, ArrayRef lhsParams, - ArrayRef rhsParams) - : ctx(ctx) { - llvm::transform(lhsParams, std::back_inserter(lhsElements), - [&](const auto &elt) { - assert(!elt.hasLabel()); - return Element::from(elt); - }); - llvm::transform(rhsParams, std::back_inserter(rhsElements), - [&](const auto &elt) { - assert(!elt.hasLabel()); - return Element::from(elt); - }); +template <> +ParameterTypeFlags TypeListPackMatcher::getElementFlags( + const AnyFunctionType::Param &elt) const { + return elt.getParameterFlags(); } -TypeListPackMatcher::TypeListPackMatcher(ASTContext &ctx, - ArrayRef lhsParams, - ArrayRef rhsParams) - : ctx(ctx) { - llvm::transform(lhsParams, std::back_inserter(lhsElements), - [&](const auto &elt) { return Element::from(elt); }); - llvm::transform(rhsParams, std::back_inserter(rhsElements), - [&](const auto &elt) { return Element::from(elt); }); +template <> +Identifier TypeListPackMatcher::getElementLabel(const Type &elt) const { + return Identifier(); } -bool TypeListPackMatcher::match() { - ArrayRef lhsParams(lhsElements); - ArrayRef rhsParams(rhsElements); - - unsigned minLength = std::min(lhsParams.size(), rhsParams.size()); - - // Consume the longest possible prefix where neither type in - // the pair is a pack expansion type. - unsigned prefixLength = 0; - for (unsigned i = 0; i < minLength; ++i) { - unsigned lhsIdx = i; - unsigned rhsIdx = i; - - auto lhsParam = lhsParams[lhsIdx]; - auto rhsParam = rhsParams[rhsIdx]; - - if (lhsParam.getLabel() != rhsParam.getLabel()) - break; - - // FIXME: Check flags - - auto lhsType = lhsParam.getType(); - auto rhsType = rhsParam.getType(); - - if (lhsType->is() || - rhsType->is()) { - break; - } - - // FIXME: Check flags - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++prefixLength; - } - - // Consume the longest possible suffix where neither type in - // the pair is a pack expansion type. - unsigned suffixLength = 0; - for (unsigned i = 0; i < minLength - prefixLength; ++i) { - unsigned lhsIdx = lhsParams.size() - i - 1; - unsigned rhsIdx = rhsParams.size() - i - 1; - - auto lhsParam = lhsParams[lhsIdx]; - auto rhsParam = rhsParams[rhsIdx]; - - // FIXME: Check flags - - if (lhsParam.getLabel() != rhsParam.getLabel()) - break; - - auto lhsType = lhsParam.getType(); - auto rhsType = rhsParam.getType(); - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++suffixLength; - } - - assert(prefixLength + suffixLength <= lhsParams.size()); - assert(prefixLength + suffixLength <= rhsParams.size()); - - // Drop the consumed prefix and suffix from each list of types. - lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength); - rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength); - - // If nothing remains, we're done. - if (lhsParams.empty() && rhsParams.empty()) - return false; - - // If the left hand side is a single pack expansion type, bind it - // to what remains of the right hand side. - if (lhsParams.size() == 1) { - auto lhsType = lhsParams[0].getType(); - if (auto *lhsExpansion = lhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - SmallVector rhsTypes; - for (auto rhsParam : rhsParams) { - if (rhsParam.hasLabel()) - return true; - - // FIXME: Check rhs flags - rhsTypes.push_back(rhsParam.getType()); - } - auto rhs = createPackBinding(ctx, rhsTypes); - - // FIXME: Check lhs flags - pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); - return false; - } - } - - // If the right hand side is a single pack expansion type, bind it - // to what remains of the left hand side. - if (rhsParams.size() == 1) { - auto rhsType = rhsParams[0].getType(); - if (auto *rhsExpansion = rhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - SmallVector lhsTypes; - for (auto lhsParam : lhsParams) { - if (lhsParam.hasLabel()) - return true; - - // FIXME: Check lhs flags - lhsTypes.push_back(lhsParam.getType()); - } - auto lhs = createPackBinding(ctx, lhsTypes); - - // FIXME: Check rhs flags - pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); - return false; - } - } +template <> +Type TypeListPackMatcher::getElementType(const Type &elt) const { + return elt; +} - // Otherwise, all remaining possibilities are invalid: - // - Neither side has any pack expansions, and they have different lengths. - // - One side has a pack expansion but the other side is too short, eg - // {Int, T..., Float} vs {Int}. - // - The prefix and suffix are mismatched, so we're left with something - // like {T..., Int} vs {Float, U...}. - return true; +template <> +ParameterTypeFlags +TypeListPackMatcher::getElementFlags(const Type &elt) const { + return ParameterTypeFlags(); } From 92afbc5e1577daafb16712f692c2e8fb5e336d59 Mon Sep 17 00:00:00 2001 From: Pavel Yaskevich Date: Fri, 24 Mar 2023 11:02:05 -0700 Subject: [PATCH 30/30] [AST] PackExpansionMatcher/NFC: Rename `{lhs, rhs}Params` to `{lhs, rhs}Elts` --- include/swift/AST/PackExpansionMatcher.h | 40 ++++++++++++------------ 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/include/swift/AST/PackExpansionMatcher.h b/include/swift/AST/PackExpansionMatcher.h index a3b11709475dd..0d21c70c59554 100644 --- a/include/swift/AST/PackExpansionMatcher.h +++ b/include/swift/AST/PackExpansionMatcher.h @@ -63,10 +63,10 @@ class TypeListPackMatcher { SmallVector pairs; [[nodiscard]] bool match() { - ArrayRef lhsParams(lhsElements); - ArrayRef rhsParams(rhsElements); + ArrayRef lhsElts(lhsElements); + ArrayRef rhsElts(rhsElements); - unsigned minLength = std::min(lhsParams.size(), rhsParams.size()); + unsigned minLength = std::min(lhsElts.size(), rhsElts.size()); // Consume the longest possible prefix where neither type in // the pair is a pack expansion type. @@ -75,8 +75,8 @@ class TypeListPackMatcher { unsigned lhsIdx = i; unsigned rhsIdx = i; - auto lhsElt = lhsParams[lhsIdx]; - auto rhsElt = rhsParams[rhsIdx]; + auto lhsElt = lhsElts[lhsIdx]; + auto rhsElt = rhsElts[rhsIdx]; if (getElementLabel(lhsElt) != getElementLabel(rhsElt)) break; @@ -101,11 +101,11 @@ class TypeListPackMatcher { // the pair is a pack expansion type. unsigned suffixLength = 0; for (unsigned i = 0; i < minLength - prefixLength; ++i) { - unsigned lhsIdx = lhsParams.size() - i - 1; - unsigned rhsIdx = rhsParams.size() - i - 1; + unsigned lhsIdx = lhsElts.size() - i - 1; + unsigned rhsIdx = rhsElts.size() - i - 1; - auto lhsElt = lhsParams[lhsIdx]; - auto rhsElt = rhsParams[rhsIdx]; + auto lhsElt = lhsElts[lhsIdx]; + auto rhsElt = rhsElts[rhsIdx]; // FIXME: Check flags @@ -124,27 +124,27 @@ class TypeListPackMatcher { ++suffixLength; } - assert(prefixLength + suffixLength <= lhsParams.size()); - assert(prefixLength + suffixLength <= rhsParams.size()); + assert(prefixLength + suffixLength <= lhsElts.size()); + assert(prefixLength + suffixLength <= rhsElts.size()); // Drop the consumed prefix and suffix from each list of types. - lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength); - rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength); + lhsElts = lhsElts.drop_front(prefixLength).drop_back(suffixLength); + rhsElts = rhsElts.drop_front(prefixLength).drop_back(suffixLength); // If nothing remains, we're done. - if (lhsParams.empty() && rhsParams.empty()) + if (lhsElts.empty() && rhsElts.empty()) return false; // If the left hand side is a single pack expansion type, bind it // to what remains of the right hand side. - if (lhsParams.size() == 1) { - auto lhsType = getElementType(lhsParams[0]); + if (lhsElts.size() == 1) { + auto lhsType = getElementType(lhsElts[0]); if (auto *lhsExpansion = lhsType->template getAs()) { unsigned lhsIdx = prefixLength; unsigned rhsIdx = prefixLength; SmallVector rhsTypes; - for (auto rhsElt : rhsParams) { + for (auto rhsElt : rhsElts) { if (!getElementLabel(rhsElt).empty()) return true; @@ -161,14 +161,14 @@ class TypeListPackMatcher { // If the right hand side is a single pack expansion type, bind it // to what remains of the left hand side. - if (rhsParams.size() == 1) { - auto rhsType = getElementType(rhsParams[0]); + if (rhsElts.size() == 1) { + auto rhsType = getElementType(rhsElts[0]); if (auto *rhsExpansion = rhsType->template getAs()) { unsigned lhsIdx = prefixLength; unsigned rhsIdx = prefixLength; SmallVector lhsTypes; - for (auto lhsElt : lhsParams) { + for (auto lhsElt : lhsElts) { if (!getElementLabel(lhsElt).empty()) return true;