Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 182 additions & 38 deletions include/swift/AST/PackExpansionMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,26 +41,189 @@ struct MatchedPair {
: lhs(lhs), rhs(rhs), lhsIdx(lhsIdx), rhsIdx(rhsIdx) {}
};

/// Performs a structural match of two lists of tuple elements. The invariant
/// is that a pack expansion type must not be followed by an unlabeled
/// element, that is, it is either the last element or the next element has
/// a label.
/// Performs a structural match of two lists of types.
///
/// In this manner, an element with a pack expansion type "absorbs" all
/// unlabeled elements up to the next label. An element with any other type
/// matches exactly one element on the other side.
class TuplePackMatcher {
ArrayRef<TupleTypeElt> lhsElts;
ArrayRef<TupleTypeElt> rhsElts;

/// The invariant is that each list must only contain at most one pack
/// expansion type. After collecting a common prefix and suffix, the
/// pack expansion on either side asborbs the remaining elements on the
/// other side.
template <typename Element>
class TypeListPackMatcher {
ASTContext &ctx;

ArrayRef<Element> lhsElements;
ArrayRef<Element> rhsElements;

protected:
TypeListPackMatcher(ASTContext &ctx, ArrayRef<Element> lhs,
ArrayRef<Element> rhs)
: ctx(ctx), lhsElements(lhs), rhsElements(rhs) {}

public:
SmallVector<MatchedPair, 4> pairs;

TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple);
[[nodiscard]] bool match() {
ArrayRef<Element> lhsElts(lhsElements);
ArrayRef<Element> rhsElts(rhsElements);

unsigned minLength = std::min(lhsElts.size(), rhsElts.size());

// Consume the longest possible prefix where neither type in
// the pair is a pack expansion type.
unsigned prefixLength = 0;
for (unsigned i = 0; i < minLength; ++i) {
unsigned lhsIdx = i;
unsigned rhsIdx = i;

auto lhsElt = lhsElts[lhsIdx];
auto rhsElt = rhsElts[rhsIdx];

if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
break;

// FIXME: Check flags

auto lhsType = getElementType(lhsElt);
auto rhsType = getElementType(rhsElt);

if (lhsType->template is<PackExpansionType>() ||
rhsType->template is<PackExpansionType>()) {
break;
}

// FIXME: Check flags

pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
++prefixLength;
}

// Consume the longest possible suffix where neither type in
// the pair is a pack expansion type.
unsigned suffixLength = 0;
for (unsigned i = 0; i < minLength - prefixLength; ++i) {
unsigned lhsIdx = lhsElts.size() - i - 1;
unsigned rhsIdx = rhsElts.size() - i - 1;

auto lhsElt = lhsElts[lhsIdx];
auto rhsElt = rhsElts[rhsIdx];

// FIXME: Check flags

if (getElementLabel(lhsElt) != getElementLabel(rhsElt))
break;

auto lhsType = getElementType(lhsElt);
auto rhsType = getElementType(rhsElt);

if (lhsType->template is<PackExpansionType>() ||
rhsType->template is<PackExpansionType>()) {
break;
}

pairs.emplace_back(lhsType, rhsType, lhsIdx, rhsIdx);
++suffixLength;
}

assert(prefixLength + suffixLength <= lhsElts.size());
assert(prefixLength + suffixLength <= rhsElts.size());

// Drop the consumed prefix and suffix from each list of types.
lhsElts = lhsElts.drop_front(prefixLength).drop_back(suffixLength);
rhsElts = rhsElts.drop_front(prefixLength).drop_back(suffixLength);

// If nothing remains, we're done.
if (lhsElts.empty() && rhsElts.empty())
return false;

// If the left hand side is a single pack expansion type, bind it
// to what remains of the right hand side.
if (lhsElts.size() == 1) {
auto lhsType = getElementType(lhsElts[0]);
if (auto *lhsExpansion = lhsType->template getAs<PackExpansionType>()) {
unsigned lhsIdx = prefixLength;
unsigned rhsIdx = prefixLength;

SmallVector<Type, 2> rhsTypes;
for (auto rhsElt : rhsElts) {
if (!getElementLabel(rhsElt).empty())
return true;

bool match();
// FIXME: Check rhs flags
rhsTypes.push_back(getElementType(rhsElt));
}
auto rhs = createPackBinding(rhsTypes);

// FIXME: Check lhs flags
pairs.emplace_back(lhsExpansion, rhs, lhsIdx, rhsIdx);
return false;
}
}

// If the right hand side is a single pack expansion type, bind it
// to what remains of the left hand side.
if (rhsElts.size() == 1) {
auto rhsType = getElementType(rhsElts[0]);
if (auto *rhsExpansion = rhsType->template getAs<PackExpansionType>()) {
unsigned lhsIdx = prefixLength;
unsigned rhsIdx = prefixLength;

SmallVector<Type, 2> lhsTypes;
for (auto lhsElt : lhsElts) {
if (!getElementLabel(lhsElt).empty())
return true;

// FIXME: Check lhs flags
lhsTypes.push_back(getElementType(lhsElt));
}
auto lhs = createPackBinding(lhsTypes);

// FIXME: Check rhs flags
pairs.emplace_back(lhs, rhsExpansion, lhsIdx, rhsIdx);
return false;
}
}

// Otherwise, all remaining possibilities are invalid:
// - Neither side has any pack expansions, and they have different lengths.
// - One side has a pack expansion but the other side is too short, eg
// {Int, T..., Float} vs {Int}.
// - The prefix and suffix are mismatched, so we're left with something
// like {T..., Int} vs {Float, U...}.
return true;
}

private:
Identifier getElementLabel(const Element &) const;
Type getElementType(const Element &) const;
ParameterTypeFlags getElementFlags(const Element &) const;

PackExpansionType *createPackBinding(ArrayRef<Type> types) const {
// If there is only one element and it's a PackExpansionType,
// return it directly.
if (types.size() == 1) {
if (auto *expansionType = types.front()->getAs<PackExpansionType>()) {
return expansionType;
}
}

// Otherwise, wrap the elements in PackExpansionType(PackType(...)).
auto *packType = PackType::get(ctx, types);
return PackExpansionType::get(packType, packType);
}
};

/// Performs a structural match of two lists of tuple elements.
///
/// The invariant is that each list must only contain at most one pack
/// expansion type. After collecting a common prefix and suffix, the
/// pack expansion on either side asborbs the remaining elements on the
/// other side.
class TuplePackMatcher : public TypeListPackMatcher<TupleTypeElt> {
public:
TuplePackMatcher(TupleType *lhsTuple, TupleType *rhsTuple)
: TypeListPackMatcher(lhsTuple->getASTContext(),
lhsTuple->getElements(),
rhsTuple->getElements()) {}
};

/// Performs a structural match of two lists of (unlabeled) function
Expand All @@ -70,20 +233,11 @@ class TuplePackMatcher {
/// expansion type. After collecting a common prefix and suffix, the
/// pack expansion on either side asborbs the remaining elements on the
/// other side.
class ParamPackMatcher {
ArrayRef<AnyFunctionType::Param> lhsParams;
ArrayRef<AnyFunctionType::Param> rhsParams;

ASTContext &ctx;

class ParamPackMatcher : public TypeListPackMatcher<AnyFunctionType::Param> {
public:
SmallVector<MatchedPair, 4> pairs;

ParamPackMatcher(ArrayRef<AnyFunctionType::Param> lhsParams,
ArrayRef<AnyFunctionType::Param> rhsParams,
ASTContext &ctx);

bool match();
ArrayRef<AnyFunctionType::Param> rhsParams, ASTContext &ctx)
: TypeListPackMatcher(ctx, lhsParams, rhsParams) {}
};

/// Performs a structural match of two lists of types.
Expand All @@ -92,20 +246,10 @@ class ParamPackMatcher {
/// expansion type. After collecting a common prefix and suffix, the
/// pack expansion on either side asborbs the remaining elements on the
/// other side.
class PackMatcher {
ArrayRef<Type> lhsTypes;
ArrayRef<Type> rhsTypes;

ASTContext &ctx;

class PackMatcher : public TypeListPackMatcher<Type> {
public:
SmallVector<MatchedPair, 4> pairs;

PackMatcher(ArrayRef<Type> lhsTypes,
ArrayRef<Type> rhsTypes,
ASTContext &ctx);

bool match();
PackMatcher(ArrayRef<Type> lhsTypes, ArrayRef<Type> rhsTypes, ASTContext &ctx)
: TypeListPackMatcher(ctx, lhsTypes, rhsTypes) {}
};

} // end namespace swift
Expand Down
Loading