diff --git a/include/swift/AST/PackExpansionMatcher.h b/include/swift/AST/PackExpansionMatcher.h index f24fd59e406bd..0d21c70c59554 100644 --- a/include/swift/AST/PackExpansionMatcher.h +++ b/include/swift/AST/PackExpansionMatcher.h @@ -41,26 +41,189 @@ struct MatchedPair { : lhs(lhs), rhs(rhs), lhsIdx(lhsIdx), rhsIdx(rhsIdx) {} }; -/// Performs a structural match of two lists of tuple elements. The invariant -/// is that a pack expansion type must not be followed by an unlabeled -/// element, that is, it is either the last element or the next element has -/// a label. +/// Performs a structural match of two lists of types. /// -/// In this manner, an element with a pack expansion type "absorbs" all -/// unlabeled elements up to the next label. An element with any other type -/// matches exactly one element on the other side. -class TuplePackMatcher { - ArrayRef lhsElts; - ArrayRef rhsElts; - +/// The invariant is that each list must only contain at most one pack +/// expansion type. After collecting a common prefix and suffix, the +/// pack expansion on either side asborbs the remaining elements on the +/// other side. +template +class TypeListPackMatcher { ASTContext &ctx; + ArrayRef lhsElements; + ArrayRef rhsElements; + +protected: + TypeListPackMatcher(ASTContext &ctx, ArrayRef lhs, + ArrayRef rhs) + : ctx(ctx), lhsElements(lhs), rhsElements(rhs) {} + public: SmallVector pairs; - TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple); + [[nodiscard]] bool match() { + ArrayRef lhsElts(lhsElements); + ArrayRef rhsElts(rhsElements); + + unsigned minLength = std::min(lhsElts.size(), rhsElts.size()); + + // Consume the longest possible prefix where neither type in + // the pair is a pack expansion type. + unsigned prefixLength = 0; + for (unsigned i = 0; i < minLength; ++i) { + unsigned lhsIdx = i; + unsigned rhsIdx = i; + + auto lhsElt = lhsElts[lhsIdx]; + auto rhsElt = rhsElts[rhsIdx]; + + if (getElementLabel(lhsElt) != getElementLabel(rhsElt)) + break; + + // FIXME: Check flags + + auto lhsType = getElementType(lhsElt); + auto rhsType = getElementType(rhsElt); + + if (lhsType->template is() || + rhsType->template is()) { + break; + } + + // FIXME: Check flags + + pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); + ++prefixLength; + } + + // Consume the longest possible suffix where neither type in + // the pair is a pack expansion type. + unsigned suffixLength = 0; + for (unsigned i = 0; i < minLength - prefixLength; ++i) { + unsigned lhsIdx = lhsElts.size() - i - 1; + unsigned rhsIdx = rhsElts.size() - i - 1; + + auto lhsElt = lhsElts[lhsIdx]; + auto rhsElt = rhsElts[rhsIdx]; + + // FIXME: Check flags + + if (getElementLabel(lhsElt) != getElementLabel(rhsElt)) + break; + + auto lhsType = getElementType(lhsElt); + auto rhsType = getElementType(rhsElt); + + if (lhsType->template is() || + rhsType->template is()) { + break; + } + + pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); + ++suffixLength; + } + + assert(prefixLength + suffixLength <= lhsElts.size()); + assert(prefixLength + suffixLength <= rhsElts.size()); + + // Drop the consumed prefix and suffix from each list of types. + lhsElts = lhsElts.drop_front(prefixLength).drop_back(suffixLength); + rhsElts = rhsElts.drop_front(prefixLength).drop_back(suffixLength); + + // If nothing remains, we're done. + if (lhsElts.empty() && rhsElts.empty()) + return false; + + // If the left hand side is a single pack expansion type, bind it + // to what remains of the right hand side. + if (lhsElts.size() == 1) { + auto lhsType = getElementType(lhsElts[0]); + if (auto *lhsExpansion = lhsType->template getAs()) { + unsigned lhsIdx = prefixLength; + unsigned rhsIdx = prefixLength; + + SmallVector rhsTypes; + for (auto rhsElt : rhsElts) { + if (!getElementLabel(rhsElt).empty()) + return true; - bool match(); + // FIXME: Check rhs flags + rhsTypes.push_back(getElementType(rhsElt)); + } + auto rhs = createPackBinding(rhsTypes); + + // FIXME: Check lhs flags + pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); + return false; + } + } + + // If the right hand side is a single pack expansion type, bind it + // to what remains of the left hand side. + if (rhsElts.size() == 1) { + auto rhsType = getElementType(rhsElts[0]); + if (auto *rhsExpansion = rhsType->template getAs()) { + unsigned lhsIdx = prefixLength; + unsigned rhsIdx = prefixLength; + + SmallVector lhsTypes; + for (auto lhsElt : lhsElts) { + if (!getElementLabel(lhsElt).empty()) + return true; + + // FIXME: Check lhs flags + lhsTypes.push_back(getElementType(lhsElt)); + } + auto lhs = createPackBinding(lhsTypes); + + // FIXME: Check rhs flags + pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); + return false; + } + } + + // Otherwise, all remaining possibilities are invalid: + // - Neither side has any pack expansions, and they have different lengths. + // - One side has a pack expansion but the other side is too short, eg + // {Int, T..., Float} vs {Int}. + // - The prefix and suffix are mismatched, so we're left with something + // like {T..., Int} vs {Float, U...}. + return true; + } + +private: + Identifier getElementLabel(const Element &) const; + Type getElementType(const Element &) const; + ParameterTypeFlags getElementFlags(const Element &) const; + + PackExpansionType *createPackBinding(ArrayRef types) const { + // If there is only one element and it's a PackExpansionType, + // return it directly. + if (types.size() == 1) { + if (auto *expansionType = types.front()->getAs()) { + return expansionType; + } + } + + // Otherwise, wrap the elements in PackExpansionType(PackType(...)). + auto *packType = PackType::get(ctx, types); + return PackExpansionType::get(packType, packType); + } +}; + +/// Performs a structural match of two lists of tuple elements. +/// +/// The invariant is that each list must only contain at most one pack +/// expansion type. After collecting a common prefix and suffix, the +/// pack expansion on either side asborbs the remaining elements on the +/// other side. +class TuplePackMatcher : public TypeListPackMatcher { +public: + TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple) + : TypeListPackMatcher(lhsTuple->getASTContext(), + lhsTuple->getElements(), + rhsTuple->getElements()) {} }; /// Performs a structural match of two lists of (unlabeled) function @@ -70,20 +233,11 @@ class TuplePackMatcher { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class ParamPackMatcher { - ArrayRef lhsParams; - ArrayRef rhsParams; - - ASTContext &ctx; - +class ParamPackMatcher : public TypeListPackMatcher { public: - SmallVector pairs; - ParamPackMatcher(ArrayRef lhsParams, - ArrayRef rhsParams, - ASTContext &ctx); - - bool match(); + ArrayRef rhsParams, ASTContext &ctx) + : TypeListPackMatcher(ctx, lhsParams, rhsParams) {} }; /// Performs a structural match of two lists of types. @@ -92,20 +246,10 @@ class ParamPackMatcher { /// expansion type. After collecting a common prefix and suffix, the /// pack expansion on either side asborbs the remaining elements on the /// other side. -class PackMatcher { - ArrayRef lhsTypes; - ArrayRef rhsTypes; - - ASTContext &ctx; - +class PackMatcher : public TypeListPackMatcher { public: - SmallVector pairs; - - PackMatcher(ArrayRef lhsTypes, - ArrayRef rhsTypes, - ASTContext &ctx); - - bool match(); + PackMatcher(ArrayRef lhsTypes, ArrayRef rhsTypes, ASTContext &ctx) + : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {} }; } // end namespace swift diff --git a/lib/AST/PackExpansionMatcher.cpp b/lib/AST/PackExpansionMatcher.cpp index 3c78e36b7ba58..61df8064a232f 100644 --- a/lib/AST/PackExpansionMatcher.cpp +++ b/lib/AST/PackExpansionMatcher.cpp @@ -24,319 +24,54 @@ using namespace swift; -static PackExpansionType *createPackBinding(ASTContext &ctx, - ArrayRef types) { - // If there is only one element and it's a PackExpansionType, - // return it directly. - if (types.size() == 1) { - if (auto *expansionType = types.front()->getAs()) { - return expansionType; - } - } - - // Otherwise, wrap the elements in PackExpansionType(PackType(...)). - auto *packType = PackType::get(ctx, types); - return PackExpansionType::get(packType, packType); +template <> +Identifier TypeListPackMatcher::getElementLabel( + const TupleTypeElt &elt) const { + return elt.getName(); } -static PackExpansionType *gatherTupleElements(ArrayRef &elts, - Identifier name, - ASTContext &ctx) { - SmallVector types; - - if (!elts.empty() && elts.front().getName() == name) { - do { - types.push_back(elts.front().getType()); - elts = elts.slice(1); - } while (!elts.empty() && !elts.front().hasName()); - } - - return createPackBinding(ctx, types); +template <> +Type TypeListPackMatcher::getElementType( + const TupleTypeElt &elt) const { + return elt.getType(); } -TuplePackMatcher::TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple) - : lhsElts(lhsTuple->getElements()), - rhsElts(rhsTuple->getElements()), - ctx(lhsTuple->getASTContext()) {} - -bool TuplePackMatcher::match() { - unsigned lhsIdx = 0; - unsigned rhsIdx = 0; - - // Iterate over the two tuples in parallel, popping elements from - // the start. - while (true) { - // If both tuples have been exhausted, we're done. - if (lhsElts.empty() && rhsElts.empty()) - return false; - - if (lhsElts.empty()) { - assert(!rhsElts.empty()); - return true; - } - - // A pack expansion type on the left hand side absorbs all elements - // from the right hand side up to the next mismatched label. - auto lhsElt = lhsElts.front(); - if (auto *lhsExpansionType = lhsElt.getType()->getAs()) { - lhsElts = lhsElts.slice(1); - - assert(lhsElts.empty() || lhsElts.front().hasName() && - "Tuple element with pack expansion type cannot be followed " - "by an unlabeled element"); - - auto rhs = gatherTupleElements(rhsElts, lhsElt.getName(), ctx); - pairs.emplace_back(lhsExpansionType, rhs, lhsIdx++, rhsIdx); - continue; - } - - if (rhsElts.empty()) { - assert(!lhsElts.empty()); - return true; - } - - // A pack expansion type on the right hand side absorbs all elements - // from the left hand side up to the next mismatched label. - auto rhsElt = rhsElts.front(); - if (auto *rhsExpansionType = rhsElt.getType()->getAs()) { - rhsElts = rhsElts.slice(1); - - assert(rhsElts.empty() || rhsElts.front().hasName() && - "Tuple element with pack expansion type cannot be followed " - "by an unlabeled element"); - - auto lhs = gatherTupleElements(lhsElts, rhsElt.getName(), ctx); - pairs.emplace_back(lhs, rhsExpansionType, lhsIdx, rhsIdx++); - continue; - } - - // Neither side is a pack expansion. We must have an exact match. - if (lhsElt.getName() != rhsElt.getName()) - return true; - - lhsElts = lhsElts.slice(1); - rhsElts = rhsElts.slice(1); - - pairs.emplace_back(lhsElt.getType(), rhsElt.getType(), lhsIdx++, rhsIdx++); - } - - return false; +template <> +ParameterTypeFlags TypeListPackMatcher::getElementFlags( + const TupleTypeElt &elt) const { + return ParameterTypeFlags(); } -ParamPackMatcher::ParamPackMatcher( - ArrayRef lhsParams, - ArrayRef rhsParams, - ASTContext &ctx) - : lhsParams(lhsParams), rhsParams(rhsParams), ctx(ctx) {} - -bool ParamPackMatcher::match() { - unsigned minLength = std::min(lhsParams.size(), rhsParams.size()); - - // Consume the longest possible prefix where neither type in - // the pair is a pack expansion type. - unsigned prefixLength = 0; - for (unsigned i = 0; i < minLength; ++i) { - unsigned lhsIdx = i; - unsigned rhsIdx = i; - - auto lhsParam = lhsParams[lhsIdx]; - auto rhsParam = rhsParams[rhsIdx]; - - // FIXME: Check flags - - auto lhsType = lhsParam.getPlainType(); - auto rhsType = rhsParam.getPlainType(); - - if (lhsType->is() || - rhsType->is()) { - break; - } - - // FIXME: Check flags - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++prefixLength; - } - - // Consume the longest possible suffix where neither type in - // the pair is a pack expansion type. - unsigned suffixLength = 0; - for (unsigned i = 0; i < minLength - prefixLength; ++i) { - unsigned lhsIdx = lhsParams.size() - i - 1; - unsigned rhsIdx = rhsParams.size() - i - 1; - - auto lhsParam = lhsParams[lhsIdx]; - auto rhsParam = rhsParams[rhsIdx]; - - // FIXME: Check flags - - auto lhsType = lhsParam.getPlainType(); - auto rhsType = rhsParam.getPlainType(); - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++suffixLength; - } - - assert(prefixLength + suffixLength <= lhsParams.size()); - assert(prefixLength + suffixLength <= rhsParams.size()); - - // Drop the consumed prefix and suffix from each list of types. - lhsParams = lhsParams.drop_front(prefixLength).drop_back(suffixLength); - rhsParams = rhsParams.drop_front(prefixLength).drop_back(suffixLength); - - // If nothing remains, we're done. - if (lhsParams.empty() && rhsParams.empty()) - return false; - - // If the left hand side is a single pack expansion type, bind it - // to what remains of the right hand side. - if (lhsParams.size() == 1) { - auto lhsType = lhsParams[0].getPlainType(); - if (auto *lhsExpansion = lhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - SmallVector rhsTypes; - for (auto rhsParam : rhsParams) { - // FIXME: Check rhs flags - rhsTypes.push_back(rhsParam.getPlainType()); - } - auto rhs = createPackBinding(ctx, rhsTypes); - - // FIXME: Check lhs flags - pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); - return false; - } - } - - // If the right hand side is a single pack expansion type, bind it - // to what remains of the left hand side. - if (rhsParams.size() == 1) { - auto rhsType = rhsParams[0].getPlainType(); - if (auto *rhsExpansion = rhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - SmallVector lhsTypes; - for (auto lhsParam : lhsParams) { - // FIXME: Check lhs flags - lhsTypes.push_back(lhsParam.getPlainType()); - } - auto lhs = createPackBinding(ctx, lhsTypes); - - // FIXME: Check rhs flags - pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); - return false; - } - } - - // Otherwise, all remaining possibilities are invalid: - // - Neither side has any pack expansions, and they have different lengths. - // - One side has a pack expansion but the other side is too short, eg - // {Int, T..., Float} vs {Int}. - // - The prefix and suffix are mismatched, so we're left with something - // like {T..., Int} vs {Float, U...}. - return true; +template <> +Identifier TypeListPackMatcher::getElementLabel( + const AnyFunctionType::Param &elt) const { + return elt.getLabel(); } -PackMatcher::PackMatcher( - ArrayRef lhsTypes, - ArrayRef rhsTypes, - ASTContext &ctx) - : lhsTypes(lhsTypes), rhsTypes(rhsTypes), ctx(ctx) {} - -bool PackMatcher::match() { - unsigned minLength = std::min(lhsTypes.size(), rhsTypes.size()); - - // Consume the longest possible prefix where neither type in - // the pair is a pack expansion type. - unsigned prefixLength = 0; - for (unsigned i = 0; i < minLength; ++i) { - unsigned lhsIdx = i; - unsigned rhsIdx = i; - - auto lhsType = lhsTypes[lhsIdx]; - auto rhsType = rhsTypes[rhsIdx]; - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++prefixLength; - } - - // Consume the longest possible suffix where neither type in - // the pair is a pack expansion type. - unsigned suffixLength = 0; - for (unsigned i = 0; i < minLength - prefixLength; ++i) { - unsigned lhsIdx = lhsTypes.size() - i - 1; - unsigned rhsIdx = rhsTypes.size() - i - 1; - - auto lhsType = lhsTypes[lhsIdx]; - auto rhsType = rhsTypes[rhsIdx]; - - if (lhsType->is() || - rhsType->is()) { - break; - } - - pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx); - ++suffixLength; - } - - assert(prefixLength + suffixLength <= lhsTypes.size()); - assert(prefixLength + suffixLength <= rhsTypes.size()); - - // Drop the consumed prefix and suffix from each list of types. - lhsTypes = lhsTypes.drop_front(prefixLength).drop_back(suffixLength); - rhsTypes = rhsTypes.drop_front(prefixLength).drop_back(suffixLength); - - // If nothing remains, we're done. - if (lhsTypes.empty() && rhsTypes.empty()) - return false; - - // If the left hand side is a single pack expansion type, bind it - // to what remains of the right hand side. - if (lhsTypes.size() == 1) { - auto lhsType = lhsTypes[0]; - if (auto *lhsExpansion = lhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; - - auto rhs = createPackBinding(ctx, rhsTypes); - - pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx); - return false; - } - } +template <> +Type TypeListPackMatcher::getElementType( + const AnyFunctionType::Param &elt) const { + return elt.getPlainType(); +} - // If the right hand side is a single pack expansion type, bind it - // to what remains of the left hand side. - if (rhsTypes.size() == 1) { - auto rhsType = rhsTypes[0]; - if (auto *rhsExpansion = rhsType->getAs()) { - unsigned lhsIdx = prefixLength; - unsigned rhsIdx = prefixLength; +template <> +ParameterTypeFlags TypeListPackMatcher::getElementFlags( + const AnyFunctionType::Param &elt) const { + return elt.getParameterFlags(); +} - auto lhs = createPackBinding(ctx, lhsTypes); +template <> +Identifier TypeListPackMatcher::getElementLabel(const Type &elt) const { + return Identifier(); +} - pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx); - return false; - } - } +template <> +Type TypeListPackMatcher::getElementType(const Type &elt) const { + return elt; +} - // Otherwise, all remaining possibilities are invalid: - // - Neither side has any pack expansions, and they have different lengths. - // - One side has a pack expansion but the other side is too short, eg - // {Int, T..., Float} vs {Int}. - // - The prefix and suffix are mismatched, so we're left with something - // like {T..., Int} vs {Float, U...}. - return true; +template <> +ParameterTypeFlags +TypeListPackMatcher::getElementFlags(const Type &elt) const { + return ParameterTypeFlags(); } diff --git a/lib/Sema/CSSimplify.cpp b/lib/Sema/CSSimplify.cpp index d9f1d46a238df..142d3cc91c52e 100644 --- a/lib/Sema/CSSimplify.cpp +++ b/lib/Sema/CSSimplify.cpp @@ -2062,13 +2062,19 @@ class TupleMatcher { TupleType *tuple2; public: + enum class MatchKind : uint8_t { + Equality, + Subtype, + Conversion, + }; + SmallVector pairs; bool hasLabelMismatch = false; TupleMatcher(TupleType *tuple1, TupleType *tuple2) - : tuple1(tuple1), tuple2(tuple2) {} + : tuple1(tuple1), tuple2(tuple2) {} - bool matchBind() { + bool match(MatchKind kind, ConstraintLocatorBuilder locator) { // FIXME: TuplePackMatcher should completely replace the non-variadic // case too eventually. if (tuple1->containsPackExpansionType() || @@ -2084,42 +2090,32 @@ class TupleMatcher { if (tuple1->getNumElements() != tuple2->getNumElements()) return true; - for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { - const auto &elt1 = tuple1->getElement(i); - const auto &elt2 = tuple2->getElement(i); + switch (kind) { + case MatchKind::Equality: + return matchEquality(isInPatternMatchingContext(locator)); - // If the names don't match, we have a conflict. - if (elt1.getName() != elt2.getName()) - return true; + case MatchKind::Subtype: + return matchSubtype(); - pairs.emplace_back(elt1.getType(), elt2.getType(), i, i); + case MatchKind::Conversion: + return matchConversion(); } - - return false; } - bool matchInPatternMatchingContext() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - - if (tuple1->getNumElements() != tuple2->getNumElements()) - return true; - +private: + bool matchEquality(bool inPatternMatchingContext) { for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { const auto &elt1 = tuple1->getElement(i); const auto &elt2 = tuple2->getElement(i); - if (elt1.hasName() && elt1.getName() != elt2.getName()) - return true; + if (inPatternMatchingContext) { + if (elt1.hasName() && elt1.getName() != elt2.getName()) + return true; + } else { + // If the names don't match, we have a conflict. + if (elt1.getName() != elt2.getName()) + return true; + } pairs.emplace_back(elt1.getType(), elt2.getType(), i, i); } @@ -2128,21 +2124,6 @@ class TupleMatcher { } bool matchSubtype() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - - if (tuple1->getNumElements() != tuple2->getNumElements()) - return true; - for (unsigned i = 0, n = tuple1->getNumElements(); i != n; ++i) { const auto &elt1 = tuple1->getElement(i); const auto &elt2 = tuple2->getElement(i); @@ -2166,18 +2147,6 @@ class TupleMatcher { } bool matchConversion() { - // FIXME: TuplePackMatcher should completely replace the non-variadic - // case too eventually. - if (tuple1->containsPackExpansionType() || - tuple2->containsPackExpansionType()) { - TuplePackMatcher matcher(tuple1, tuple2); - if (matcher.match()) - return true; - - pairs = matcher.pairs; - return false; - } - SmallVector sources; if (computeTupleShuffle(tuple1, tuple2, sources)) return true; @@ -2201,23 +2170,16 @@ ConstraintSystem::TypeMatchResult ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, ConstraintKind kind, TypeMatchOptions flags, ConstraintLocatorBuilder locator) { - TupleMatcher matcher(tuple1, tuple2); + using TupleMatchKind = TupleMatcher::MatchKind; ConstraintKind subkind; + TupleMatchKind matchKind; switch (kind) { case ConstraintKind::Bind: case ConstraintKind::Equal: { subkind = kind; - - if (isInPatternMatchingContext(locator)) { - if (matcher.matchInPatternMatchingContext()) - return getTypeMatchFailure(locator); - } else { - if (matcher.matchBind()) - return getTypeMatchFailure(locator); - } - + matchKind = TupleMatchKind::Equality; break; } @@ -2227,17 +2189,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::Subtype: case ConstraintKind::BindToPointerType: { subkind = kind; - - if (matcher.matchSubtype()) - return getTypeMatchFailure(locator); - - if (matcher.hasLabelMismatch) { - // If we had a label mismatch, emit a warning. This is something we - // shouldn't permit, as it's more permissive than what a conversion would - // allow. Ideally we'd turn this into an error in Swift 6 mode. - recordFix(AllowTupleLabelMismatch::create( - *this, tuple1, tuple2, getConstraintLocator(locator))); - } + matchKind = TupleMatchKind::Subtype; break; } @@ -2245,11 +2197,7 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, case ConstraintKind::ArgumentConversion: case ConstraintKind::OperatorArgumentConversion: { subkind = ConstraintKind::Conversion; - - // Compute the element shuffles for conversions. - if (matcher.matchConversion()) - return getTypeMatchFailure(locator); - + matchKind = TupleMatchKind::Conversion; break; } @@ -2289,6 +2237,19 @@ ConstraintSystem::matchTupleTypes(TupleType *tuple1, TupleType *tuple2, llvm_unreachable("Bad constraint kind in matchTupleTypes()"); } + TupleMatcher matcher(tuple1, tuple2); + + if (matcher.match(matchKind, locator)) + return getTypeMatchFailure(locator); + + if (matcher.hasLabelMismatch) { + // If we had a label mismatch, emit a warning. This is something we + // shouldn't permit, as it's more permissive than what a conversion would + // allow. Ideally we'd turn this into an error in Swift 6 mode. + recordFix(AllowTupleLabelMismatch::create(*this, tuple1, tuple2, + getConstraintLocator(locator))); + } + TypeMatchOptions subflags = getDefaultDecompositionOptions(flags); for (auto pair : matcher.pairs) { diff --git a/test/Constraints/pack-expansion-expressions.swift b/test/Constraints/pack-expansion-expressions.swift index af71c93c095db..9b292e1e3ae36 100644 --- a/test/Constraints/pack-expansion-expressions.swift +++ b/test/Constraints/pack-expansion-expressions.swift @@ -172,9 +172,19 @@ do { get { 42 } set {} } + + subscript(simpleTuple args: (repeat each T)) -> Int { + get { return 0 } + set {} + } + + subscript(compoundTuple args: (String, repeat each T)) -> Int { + get { return 0 } + set {} + } } - func test_that_variadic_generics_claim_unlabeled_arguments(_ args: repeat each T, test: inout TestArgMatching) { + func test_that_variadic_generics_claim_unlabeled_arguments(_ args: repeat each T, test: inout TestArgMatching, extra: String) { func testLabeled(data: repeat each U) {} func testUnlabeled(_: repeat each U) {} func testInBetween(_: repeat each U, other: String) {} @@ -193,6 +203,31 @@ do { _ = test[data: repeat each args, 0, ""] test[data: repeat each args, "", 42] = 0 + + do { + let first = "" + let second = "" + let third = 42 + + _ = test[simpleTuple: (repeat each args)] + _ = test[simpleTuple: (repeat each args, extra)] + _ = test[simpleTuple: (first, second)] + _ = test[compoundTuple: (first, repeat each args)] + _ = test[compoundTuple: (first, repeat each args, extra)] + _ = test[compoundTuple: (first, second, third)] + } + + do { + func testRef() -> (repeat each T, String) { fatalError() } + func testResult() -> (repeat each T) { fatalError() } + + func experiment1() -> (repeat each U, String) { + testResult() // Ok + } + + func experiment2(_: () -> (repeat each U)) -> (repeat each U) { fatalError() } + let _: (Int, String) = experiment2(testRef) // Ok + } } }