@@ -41,26 +41,189 @@ struct MatchedPair {
4141 : lhs(lhs), rhs(rhs), lhsIdx(lhsIdx), rhsIdx(rhsIdx) {}
4242};
4343
44- // / Performs a structural match of two lists of tuple elements. The invariant
45- // / is that a pack expansion type must not be followed by an unlabeled
46- // / element, that is, it is either the last element or the next element has
47- // / a label.
44+ // / Performs a structural match of two lists of types.
4845// /
49- // / In this manner, an element with a pack expansion type "absorbs" all
50- // / unlabeled elements up to the next label. An element with any other type
51- // / matches exactly one element on the other side.
52- class TuplePackMatcher {
53- ArrayRef<TupleTypeElt> lhsElts;
54- ArrayRef<TupleTypeElt> rhsElts;
55-
46+ // / The invariant is that each list must only contain at most one pack
47+ // / expansion type. After collecting a common prefix and suffix, the
48+ // / pack expansion on either side asborbs the remaining elements on the
49+ // / other side.
50+ template <typename Element>
51+ class TypeListPackMatcher {
5652 ASTContext &ctx;
5753
54+ ArrayRef<Element> lhsElements;
55+ ArrayRef<Element> rhsElements;
56+
57+ protected:
58+ TypeListPackMatcher (ASTContext &ctx, ArrayRef<Element> lhs,
59+ ArrayRef<Element> rhs)
60+ : ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}
61+
5862public:
5963 SmallVector<MatchedPair, 4 > pairs;
6064
61- TuplePackMatcher (TupleType *lhsTuple, TupleType *rhsTuple);
65+ [[nodiscard]] bool match () {
66+ ArrayRef<Element> lhsElts (lhsElements);
67+ ArrayRef<Element> rhsElts (rhsElements);
68+
69+ unsigned minLength = std::min (lhsElts.size (), rhsElts.size ());
70+
71+ // Consume the longest possible prefix where neither type in
72+ // the pair is a pack expansion type.
73+ unsigned prefixLength = 0 ;
74+ for (unsigned i = 0 ; i < minLength; ++i) {
75+ unsigned lhsIdx = i;
76+ unsigned rhsIdx = i;
77+
78+ auto lhsElt = lhsElts[lhsIdx];
79+ auto rhsElt = rhsElts[rhsIdx];
80+
81+ if (getElementLabel (lhsElt) != getElementLabel (rhsElt))
82+ break ;
83+
84+ // FIXME: Check flags
85+
86+ auto lhsType = getElementType (lhsElt);
87+ auto rhsType = getElementType (rhsElt);
88+
89+ if (lhsType->template is <PackExpansionType>() ||
90+ rhsType->template is <PackExpansionType>()) {
91+ break ;
92+ }
93+
94+ // FIXME: Check flags
95+
96+ pairs.emplace_back (lhsType, rhsType, lhsIdx, rhsIdx);
97+ ++prefixLength;
98+ }
99+
100+ // Consume the longest possible suffix where neither type in
101+ // the pair is a pack expansion type.
102+ unsigned suffixLength = 0 ;
103+ for (unsigned i = 0 ; i < minLength - prefixLength; ++i) {
104+ unsigned lhsIdx = lhsElts.size () - i - 1 ;
105+ unsigned rhsIdx = rhsElts.size () - i - 1 ;
106+
107+ auto lhsElt = lhsElts[lhsIdx];
108+ auto rhsElt = rhsElts[rhsIdx];
109+
110+ // FIXME: Check flags
111+
112+ if (getElementLabel (lhsElt) != getElementLabel (rhsElt))
113+ break ;
114+
115+ auto lhsType = getElementType (lhsElt);
116+ auto rhsType = getElementType (rhsElt);
117+
118+ if (lhsType->template is <PackExpansionType>() ||
119+ rhsType->template is <PackExpansionType>()) {
120+ break ;
121+ }
122+
123+ pairs.emplace_back (lhsType, rhsType, lhsIdx, rhsIdx);
124+ ++suffixLength;
125+ }
126+
127+ assert (prefixLength + suffixLength <= lhsElts.size ());
128+ assert (prefixLength + suffixLength <= rhsElts.size ());
129+
130+ // Drop the consumed prefix and suffix from each list of types.
131+ lhsElts = lhsElts.drop_front (prefixLength).drop_back (suffixLength);
132+ rhsElts = rhsElts.drop_front (prefixLength).drop_back (suffixLength);
133+
134+ // If nothing remains, we're done.
135+ if (lhsElts.empty () && rhsElts.empty ())
136+ return false ;
137+
138+ // If the left hand side is a single pack expansion type, bind it
139+ // to what remains of the right hand side.
140+ if (lhsElts.size () == 1 ) {
141+ auto lhsType = getElementType (lhsElts[0 ]);
142+ if (auto *lhsExpansion = lhsType->template getAs <PackExpansionType>()) {
143+ unsigned lhsIdx = prefixLength;
144+ unsigned rhsIdx = prefixLength;
145+
146+ SmallVector<Type, 2 > rhsTypes;
147+ for (auto rhsElt : rhsElts) {
148+ if (!getElementLabel (rhsElt).empty ())
149+ return true ;
62150
63- bool match ();
151+ // FIXME: Check rhs flags
152+ rhsTypes.push_back (getElementType (rhsElt));
153+ }
154+ auto rhs = createPackBinding (rhsTypes);
155+
156+ // FIXME: Check lhs flags
157+ pairs.emplace_back (lhsExpansion, rhs, lhsIdx, rhsIdx);
158+ return false ;
159+ }
160+ }
161+
162+ // If the right hand side is a single pack expansion type, bind it
163+ // to what remains of the left hand side.
164+ if (rhsElts.size () == 1 ) {
165+ auto rhsType = getElementType (rhsElts[0 ]);
166+ if (auto *rhsExpansion = rhsType->template getAs <PackExpansionType>()) {
167+ unsigned lhsIdx = prefixLength;
168+ unsigned rhsIdx = prefixLength;
169+
170+ SmallVector<Type, 2 > lhsTypes;
171+ for (auto lhsElt : lhsElts) {
172+ if (!getElementLabel (lhsElt).empty ())
173+ return true ;
174+
175+ // FIXME: Check lhs flags
176+ lhsTypes.push_back (getElementType (lhsElt));
177+ }
178+ auto lhs = createPackBinding (lhsTypes);
179+
180+ // FIXME: Check rhs flags
181+ pairs.emplace_back (lhs, rhsExpansion, lhsIdx, rhsIdx);
182+ return false ;
183+ }
184+ }
185+
186+ // Otherwise, all remaining possibilities are invalid:
187+ // - Neither side has any pack expansions, and they have different lengths.
188+ // - One side has a pack expansion but the other side is too short, eg
189+ // {Int, T..., Float} vs {Int}.
190+ // - The prefix and suffix are mismatched, so we're left with something
191+ // like {T..., Int} vs {Float, U...}.
192+ return true ;
193+ }
194+
195+ private:
196+ Identifier getElementLabel (const Element &) const ;
197+ Type getElementType (const Element &) const ;
198+ ParameterTypeFlags getElementFlags (const Element &) const ;
199+
200+ PackExpansionType *createPackBinding (ArrayRef<Type> types) const {
201+ // If there is only one element and it's a PackExpansionType,
202+ // return it directly.
203+ if (types.size () == 1 ) {
204+ if (auto *expansionType = types.front ()->getAs <PackExpansionType>()) {
205+ return expansionType;
206+ }
207+ }
208+
209+ // Otherwise, wrap the elements in PackExpansionType(PackType(...)).
210+ auto *packType = PackType::get (ctx, types);
211+ return PackExpansionType::get (packType, packType);
212+ }
213+ };
214+
215+ // / Performs a structural match of two lists of tuple elements.
216+ // /
217+ // / The invariant is that each list must only contain at most one pack
218+ // / expansion type. After collecting a common prefix and suffix, the
219+ // / pack expansion on either side asborbs the remaining elements on the
220+ // / other side.
221+ class TuplePackMatcher : public TypeListPackMatcher <TupleTypeElt> {
222+ public:
223+ TuplePackMatcher (TupleType *lhsTuple, TupleType *rhsTuple)
224+ : TypeListPackMatcher(lhsTuple->getASTContext (),
225+ lhsTuple->getElements(),
226+ rhsTuple->getElements()) {}
64227};
65228
66229// / Performs a structural match of two lists of (unlabeled) function
@@ -70,20 +233,11 @@ class TuplePackMatcher {
70233// / expansion type. After collecting a common prefix and suffix, the
71234// / pack expansion on either side asborbs the remaining elements on the
72235// / other side.
73- class ParamPackMatcher {
74- ArrayRef<AnyFunctionType::Param> lhsParams;
75- ArrayRef<AnyFunctionType::Param> rhsParams;
76-
77- ASTContext &ctx;
78-
236+ class ParamPackMatcher : public TypeListPackMatcher <AnyFunctionType::Param> {
79237public:
80- SmallVector<MatchedPair, 4 > pairs;
81-
82238 ParamPackMatcher (ArrayRef<AnyFunctionType::Param> lhsParams,
83- ArrayRef<AnyFunctionType::Param> rhsParams,
84- ASTContext &ctx);
85-
86- bool match ();
239+ ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
240+ : TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
87241};
88242
89243// / Performs a structural match of two lists of types.
@@ -92,20 +246,10 @@ class ParamPackMatcher {
92246// / expansion type. After collecting a common prefix and suffix, the
93247// / pack expansion on either side asborbs the remaining elements on the
94248// / other side.
95- class PackMatcher {
96- ArrayRef<Type> lhsTypes;
97- ArrayRef<Type> rhsTypes;
98-
99- ASTContext &ctx;
100-
249+ class PackMatcher : public TypeListPackMatcher <Type> {
101250public:
102- SmallVector<MatchedPair, 4 > pairs;
103-
104- PackMatcher (ArrayRef<Type> lhsTypes,
105- ArrayRef<Type> rhsTypes,
106- ASTContext &ctx);
107-
108- bool match ();
251+ PackMatcher (ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
252+ : TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
109253};
110254
111255} // end namespace swift
0 commit comments