@@ -470,35 +470,25 @@ bool AbstractionPattern::doesTupleContainPackExpansionType() const {
470470}
471471
472472void AbstractionPattern::forEachTupleElement (CanTupleType substType,
473- llvm::function_ref<void (unsigned origEltIndex,
474- unsigned substEltIndex,
475- AbstractionPattern origEltType,
476- CanType substEltType)>
477- handleScalar,
478- llvm::function_ref<void(unsigned origEltIndex,
479- unsigned substEltIndex,
480- AbstractionPattern origExpansionType,
481- CanTupleEltTypeArrayRef substEltTypes)>
482- handleExpansion) const {
483- assert (isTuple () && " can only call on a tuple expansion" );
484- assert (matchesTuple (substType));
485-
486- size_t substEltIndex = 0 ;
487- auto substEltTypes = substType.getElementTypes ();
488- for (size_t origEltIndex : range (getNumTupleElements ())) {
489- auto origEltType = getTupleElementType (origEltIndex);
490- if (!origEltType.isPackExpansion ()) {
491- handleScalar (origEltIndex, substEltIndex,
492- origEltType, substEltTypes[substEltIndex]);
493- substEltIndex++;
494- } else {
495- auto numComponents = origEltType.getNumPackExpandedComponents ();
496- handleExpansion (origEltIndex, substEltIndex, origEltType,
497- substEltTypes.slice (substEltIndex, numComponents));
498- substEltIndex += numComponents;
499- }
473+ llvm::function_ref<void (TupleElementGenerator &)> handleElement) const {
474+ TupleElementGenerator elt (*this , substType);
475+ for (; !elt.isFinished (); elt.advance ()) {
476+ handleElement (elt);
500477 }
501- assert (substEltIndex == substEltTypes.size ());
478+ elt.finish ();
479+ }
480+
481+ TupleElementGenerator::TupleElementGenerator (
482+ AbstractionPattern origTupleType,
483+ CanTupleType substTupleType)
484+ : origTupleType(origTupleType), substTupleType(substTupleType) {
485+ assert (origTupleType.isTuple ());
486+ assert (origTupleType.matchesTuple (substTupleType));
487+
488+ origTupleTypeIsOpaque = origTupleType.isOpaqueTuple ();
489+ numOrigElts = origTupleType.getNumTupleElements ();
490+
491+ if (!isFinished ()) loadElement ();
502492}
503493
504494void AbstractionPattern::forEachExpandedTupleElement (CanTupleType substType,
@@ -2196,28 +2186,19 @@ class SubstFunctionTypePatternVisitor
21962186 CanType visitTupleType (CanTupleType tuple, AbstractionPattern pattern) {
21972187 assert (pattern.isTuple ());
21982188
2199- // It's pretty weird for us to end up in this case with an
2200- // open-coded tuple pattern, but it happens with opaque derivative
2201- // functions in autodiff.
2202- CanTupleType origTupleTypeForLabels = pattern.getAs <TupleType>();
2203- if (!origTupleTypeForLabels) origTupleTypeForLabels = tuple;
2204-
22052189 SmallVector<TupleTypeElt, 4 > tupleElts;
2206- pattern.forEachTupleElement (tuple,
2207- [&](unsigned origEltIndex, unsigned substEltIndex,
2208- AbstractionPattern origEltType, CanType substEltType) {
2209- auto eltTy = visit (substEltType, origEltType);
2210- auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2211- tupleElts.push_back (origElt.getWithType (eltTy));
2212- }, [&](unsigned origEltIndex, unsigned substEltIndex,
2213- AbstractionPattern origExpansionType,
2214- CanTupleEltTypeArrayRef substEltTypes) {
2215- CanType candidateSubstType;
2216- if (!substEltTypes.empty ())
2217- candidateSubstType = substEltTypes[0 ];
2218- auto eltTy = handlePackExpansion (origExpansionType, candidateSubstType);
2219- auto &origElt = origTupleTypeForLabels->getElement (origEltIndex);
2220- tupleElts.push_back (origElt.getWithType (eltTy));
2190+ pattern.forEachTupleElement (tuple, [&](TupleElementGenerator &elt) {
2191+ auto substEltTypes = elt.getSubstTypes ();
2192+ CanType eltTy;
2193+ if (!elt.isOrigPackExpansion ()) {
2194+ eltTy = visit (substEltTypes[0 ], elt.getOrigType ());
2195+ } else {
2196+ CanType candidateSubstType;
2197+ if (!substEltTypes.empty ())
2198+ candidateSubstType = substEltTypes[0 ];
2199+ eltTy = handlePackExpansion (elt.getOrigType (), candidateSubstType);
2200+ }
2201+ tupleElts.push_back (elt.getOrigElement ().getWithType (eltTy));
22212202 });
22222203
22232204 return CanType (TupleType::get (tupleElts, TC.Context ));
@@ -2236,7 +2217,7 @@ class SubstFunctionTypePatternVisitor
22362217
22372218 pattern.forEachFunctionParam (func.getParams (), /* ignore self*/ false ,
22382219 [&](FunctionParamGenerator ¶m) {
2239- if (!param.isPackExpansion ()) {
2220+ if (!param.isOrigPackExpansion ()) {
22402221 auto newParamTy = visit (param.getSubstParams ()[0 ].getParameterType (),
22412222 param.getOrigType ());
22422223 addParam (param.getOrigFlags (), newParamTy);
0 commit comments