@@ -3259,9 +3259,10 @@ Identifier DependentMemberType::getName() const {
32593259 return NameOrAssocType.get <AssociatedTypeDecl *>()->getName ();
32603260}
32613261
3262- static bool transformSILResult (SILResultInfo &result, bool &changed,
3263- llvm::function_ref<Type(Type)> fn) {
3264- Type transType = result.getType ().transform (fn);
3262+ static bool transformSILResult (
3263+ SILResultInfo &result, bool &changed,
3264+ llvm::function_ref<Optional<Type>(TypeBase *)> fn) {
3265+ Type transType = result.getType ().transformRec (fn);
32653266 if (!transType) return true ;
32663267
32673268 CanType canTransType = transType->getCanonicalType ();
@@ -3272,9 +3273,10 @@ static bool transformSILResult(SILResultInfo &result, bool &changed,
32723273 return false ;
32733274}
32743275
3275- static bool transformSILParameter (SILParameterInfo ¶m, bool &changed,
3276- llvm::function_ref<Type(Type)> fn) {
3277- Type transType = param.getType ().transform (fn);
3276+ static bool transformSILParameter (
3277+ SILParameterInfo ¶m, bool &changed,
3278+ llvm::function_ref<Optional<Type>(TypeBase *)> fn) {
3279+ Type transType = param.getType ().transformRec (fn);
32783280 if (!transType) return true ;
32793281
32803282 CanType canTransType = transType->getCanonicalType ();
@@ -3286,13 +3288,28 @@ static bool transformSILParameter(SILParameterInfo ¶m, bool &changed,
32863288}
32873289
32883290Type Type::transform (llvm::function_ref<Type(Type)> fn) const {
3291+ return transformRec ([&fn](TypeBase *type) -> Optional<Type> {
3292+ Type transformed = fn (Type (type));
3293+ if (!transformed)
3294+ return Type ();
3295+
3296+ // If the function didn't change the type at all, let transformRec()
3297+ // recurse.
3298+ if (transformed.getPointer () == type)
3299+ return None;
3300+
3301+ return transformed;
3302+ });
3303+ }
3304+
3305+ Type Type::transformRec (
3306+ llvm::function_ref<Optional<Type>(TypeBase *)> fn) const {
32893307 if (!isa<ParenType>(getPointer ())) {
32903308 // Transform this type node.
3291- Type transformed = fn (*this );
3309+ if (Optional<Type> transformed = fn (getPointer ()))
3310+ return *transformed;
32923311
3293- // If the client changed the type, we're done.
3294- if (!transformed || transformed.getPointer () != getPointer ())
3295- return transformed;
3312+ // Recurse.
32963313 }
32973314
32983315 // Recursive into children of this type.
@@ -3314,7 +3331,7 @@ case TypeKind::Id:
33143331 case TypeKind::Protocol: {
33153332 auto nominalTy = cast<NominalType>(base);
33163333 if (auto parentTy = nominalTy->getParent ()) {
3317- parentTy = parentTy.transform (fn);
3334+ parentTy = parentTy.transformRec (fn);
33183335 if (!parentTy)
33193336 return Type ();
33203337
@@ -3330,7 +3347,7 @@ case TypeKind::Id:
33303347
33313348 case TypeKind::SILBlockStorage: {
33323349 auto storageTy = cast<SILBlockStorageType>(base);
3333- Type transCap = storageTy->getCaptureType ().transform (fn);
3350+ Type transCap = storageTy->getCaptureType ().transformRec (fn);
33343351 if (!transCap)
33353352 return Type ();
33363353 CanType canTransCap = transCap->getCanonicalType ();
@@ -3345,7 +3362,7 @@ case TypeKind::Id:
33453362 // generic SILBox.
33463363 auto boxTy = cast<SILBoxType>(base);
33473364 for (auto &arg : boxTy->getGenericArgs ())
3348- assert (arg.getReplacement ()->isEqual (arg.getReplacement ().transform (fn))
3365+ assert (arg.getReplacement ()->isEqual (arg.getReplacement ().transformRec (fn))
33493366 && " SILBoxType can't be transformed" );
33503367#endif
33513368 return base;
@@ -3390,7 +3407,7 @@ case TypeKind::Id:
33903407 case TypeKind::WeakStorage: {
33913408 auto storageTy = cast<ReferenceStorageType>(base);
33923409 Type refTy = storageTy->getReferentType ();
3393- Type substRefTy = refTy.transform (fn);
3410+ Type substRefTy = refTy.transformRec (fn);
33943411 if (!substRefTy)
33953412 return Type ();
33963413
@@ -3405,7 +3422,7 @@ case TypeKind::Id:
34053422 auto unbound = cast<UnboundGenericType>(base);
34063423 Type substParentTy;
34073424 if (auto parentTy = unbound->getParent ()) {
3408- substParentTy = parentTy.transform (fn);
3425+ substParentTy = parentTy.transformRec (fn);
34093426 if (!substParentTy)
34103427 return Type ();
34113428
@@ -3427,7 +3444,7 @@ case TypeKind::Id:
34273444 bool anyChanged = false ;
34283445 Type substParentTy;
34293446 if (auto parentTy = bound->getParent ()) {
3430- substParentTy = parentTy.transform (fn);
3447+ substParentTy = parentTy.transformRec (fn);
34313448 if (!substParentTy)
34323449 return Type ();
34333450
@@ -3436,7 +3453,7 @@ case TypeKind::Id:
34363453 }
34373454
34383455 for (auto arg : bound->getGenericArgs ()) {
3439- Type substArg = arg.transform (fn);
3456+ Type substArg = arg.transformRec (fn);
34403457 if (!substArg)
34413458 return Type ();
34423459 substArgs.push_back (substArg);
@@ -3452,7 +3469,7 @@ case TypeKind::Id:
34523469
34533470 case TypeKind::ExistentialMetatype: {
34543471 auto meta = cast<ExistentialMetatypeType>(base);
3455- auto instanceTy = meta->getInstanceType ().transform (fn);
3472+ auto instanceTy = meta->getInstanceType ().transformRec (fn);
34563473 if (!instanceTy)
34573474 return Type ();
34583475
@@ -3467,7 +3484,7 @@ case TypeKind::Id:
34673484
34683485 case TypeKind::Metatype: {
34693486 auto meta = cast<MetatypeType>(base);
3470- auto instanceTy = meta->getInstanceType ().transform (fn);
3487+ auto instanceTy = meta->getInstanceType ().transformRec (fn);
34713488 if (!instanceTy)
34723489 return Type ();
34733490
@@ -3481,7 +3498,7 @@ case TypeKind::Id:
34813498
34823499 case TypeKind::DynamicSelf: {
34833500 auto dynamicSelf = cast<DynamicSelfType>(base);
3484- auto selfTy = dynamicSelf->getSelfType ().transform (fn);
3501+ auto selfTy = dynamicSelf->getSelfType ().transformRec (fn);
34853502 if (!selfTy)
34863503 return Type ();
34873504
@@ -3494,7 +3511,7 @@ case TypeKind::Id:
34943511 case TypeKind::NameAlias: {
34953512 auto alias = cast<NameAliasType>(base);
34963513 auto underlyingTy = Type (alias->getSinglyDesugaredType ());
3497- auto transformedTy = underlyingTy.transform (fn);
3514+ auto transformedTy = underlyingTy.transformRec (fn);
34983515 if (!transformedTy)
34993516 return Type ();
35003517
@@ -3506,7 +3523,7 @@ case TypeKind::Id:
35063523
35073524 case TypeKind::Paren: {
35083525 auto paren = cast<ParenType>(base);
3509- Type underlying = paren->getUnderlyingType ().transform (fn);
3526+ Type underlying = paren->getUnderlyingType ().transformRec (fn);
35103527 if (!underlying)
35113528 return Type ();
35123529
@@ -3522,7 +3539,7 @@ case TypeKind::Id:
35223539 SmallVector<TupleTypeElt, 4 > elements;
35233540 unsigned Index = 0 ;
35243541 for (const auto &elt : tuple->getElements ()) {
3525- Type eltTy = elt.getType ().transform (fn);
3542+ Type eltTy = elt.getType ().transformRec (fn);
35263543 if (!eltTy)
35273544 return Type ();
35283545
@@ -3555,7 +3572,7 @@ case TypeKind::Id:
35553572
35563573 case TypeKind::DependentMember: {
35573574 auto dependent = cast<DependentMemberType>(base);
3558- auto dependentBase = dependent->getBase ().transform (fn);
3575+ auto dependentBase = dependent->getBase ().transformRec (fn);
35593576 if (!dependentBase)
35603577 return Type ();
35613578
@@ -3570,10 +3587,10 @@ case TypeKind::Id:
35703587
35713588 case TypeKind::Function: {
35723589 auto function = cast<AnyFunctionType>(base);
3573- auto inputTy = function->getInput ().transform (fn);
3590+ auto inputTy = function->getInput ().transformRec (fn);
35743591 if (!inputTy)
35753592 return Type ();
3576- auto resultTy = function->getResult ().transform (fn);
3593+ auto resultTy = function->getResult ().transformRec (fn);
35773594 if (!resultTy)
35783595 return Type ();
35793596
@@ -3592,7 +3609,7 @@ case TypeKind::Id:
35923609 // Transform generic parameters.
35933610 SmallVector<GenericTypeParamType *, 4 > genericParams;
35943611 for (auto param : function->getGenericParams ()) {
3595- Type paramTy = Type (param).transform (fn);
3612+ Type paramTy = Type (param).transformRec (fn);
35963613 if (!paramTy)
35973614 return Type ();
35983615
@@ -3609,7 +3626,7 @@ case TypeKind::Id:
36093626 // Transform requirements.
36103627 SmallVector<Requirement, 4 > requirements;
36113628 for (const auto &req : function->getRequirements ()) {
3612- auto firstType = req.getFirstType ().transform (fn);
3629+ auto firstType = req.getFirstType ().transformRec (fn);
36133630 if (!firstType)
36143631 return Type ();
36153632
@@ -3618,7 +3635,7 @@ case TypeKind::Id:
36183635
36193636 Type secondType = req.getSecondType ();
36203637 if (secondType) {
3621- secondType = secondType.transform (fn);
3638+ secondType = secondType.transformRec (fn);
36223639 if (!secondType)
36233640 return Type ();
36243641
@@ -3637,12 +3654,12 @@ case TypeKind::Id:
36373654 }
36383655
36393656 // Transform input type.
3640- auto inputTy = function->getInput ().transform (fn);
3657+ auto inputTy = function->getInput ().transformRec (fn);
36413658 if (!inputTy)
36423659 return Type ();
36433660
36443661 // Transform result type.
3645- auto resultTy = function->getResult ().transform (fn);
3662+ auto resultTy = function->getResult ().transformRec (fn);
36463663 if (!resultTy)
36473664 return Type ();
36483665
@@ -3695,7 +3712,7 @@ case TypeKind::Id:
36953712
36963713 case TypeKind::ArraySlice: {
36973714 auto slice = cast<ArraySliceType>(base);
3698- auto baseTy = slice->getBaseType ().transform (fn);
3715+ auto baseTy = slice->getBaseType ().transformRec (fn);
36993716 if (!baseTy)
37003717 return Type ();
37013718
@@ -3707,7 +3724,7 @@ case TypeKind::Id:
37073724
37083725 case TypeKind::Optional: {
37093726 auto optional = cast<OptionalType>(base);
3710- auto baseTy = optional->getBaseType ().transform (fn);
3727+ auto baseTy = optional->getBaseType ().transformRec (fn);
37113728 if (!baseTy)
37123729 return Type ();
37133730
@@ -3719,7 +3736,7 @@ case TypeKind::Id:
37193736
37203737 case TypeKind::ImplicitlyUnwrappedOptional: {
37213738 auto optional = cast<ImplicitlyUnwrappedOptionalType>(base);
3722- auto baseTy = optional->getBaseType ().transform (fn);
3739+ auto baseTy = optional->getBaseType ().transformRec (fn);
37233740 if (!baseTy)
37243741 return Type ();
37253742
@@ -3731,11 +3748,11 @@ case TypeKind::Id:
37313748
37323749 case TypeKind::Dictionary: {
37333750 auto dict = cast<DictionaryType>(base);
3734- auto keyTy = dict->getKeyType ().transform (fn);
3751+ auto keyTy = dict->getKeyType ().transformRec (fn);
37353752 if (!keyTy)
37363753 return Type ();
37373754
3738- auto valueTy = dict->getValueType ().transform (fn);
3755+ auto valueTy = dict->getValueType ().transformRec (fn);
37393756 if (!valueTy)
37403757 return Type ();
37413758
@@ -3748,7 +3765,7 @@ case TypeKind::Id:
37483765
37493766 case TypeKind::LValue: {
37503767 auto lvalue = cast<LValueType>(base);
3751- auto objectTy = lvalue->getObjectType ().transform (fn);
3768+ auto objectTy = lvalue->getObjectType ().transformRec (fn);
37523769 if (!objectTy || objectTy->hasError ())
37533770 return objectTy;
37543771
@@ -3758,7 +3775,7 @@ case TypeKind::Id:
37583775
37593776 case TypeKind::InOut: {
37603777 auto inout = cast<InOutType>(base);
3761- auto objectTy = inout->getObjectType ().transform (fn);
3778+ auto objectTy = inout->getObjectType ().transformRec (fn);
37623779 if (!objectTy || objectTy->hasError ())
37633780 return objectTy;
37643781
@@ -3772,7 +3789,7 @@ case TypeKind::Id:
37723789 bool anyChanged = false ;
37733790 unsigned index = 0 ;
37743791 for (auto proto : pc->getProtocols ()) {
3775- auto substProto = proto.transform (fn);
3792+ auto substProto = proto.transformRec (fn);
37763793 if (!substProto)
37773794 return Type ();
37783795
0 commit comments