Skip to content

Commit 8a69700

Browse files
committed
[AST] Add Type::transformRec() to provide better recursion control
Type::transformRec() is intended to supercede Type::transform(). It improves things along two axes: * It provides better control over recursion, so that the callback function can indicate "I handled this" (by returning a Type, which may be null to propagate out a failure condition) vs. "please recurse" (by returning None). * It passes the types along the way as TypeBase * rather than Type, to encourages callback functions to dyn_cast/isa rather than getAs/is. The latter is unnecessary because the transform operations already handle type sugar.
1 parent de92082 commit 8a69700

File tree

2 files changed

+80
-39
lines changed

2 files changed

+80
-39
lines changed

include/swift/AST/Type.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,30 @@ class Type {
216216
/// \returns the result of transforming the type.
217217
Type transform(llvm::function_ref<Type(Type)> fn) const;
218218

219+
/// Transform the given type by applying the user-provided function to
220+
/// each type.
221+
///
222+
/// This routine applies the given function to transform one type into
223+
/// another. If the function leaves the type unchanged, recurse into the
224+
/// child type nodes and transform those. If any child type node changes,
225+
/// the parent type node will be rebuilt.
226+
///
227+
/// If at any time the function returns a null type, the null will be
228+
/// propagated out.
229+
///
230+
/// If the the function returns \c None, the transform operation will
231+
///
232+
/// \param fn A function object with the signature
233+
/// \c Optional<Type>(TypeBase *), which accepts a type pointer and returns a
234+
/// transformed type, a null type (which will propagate the null type to the
235+
/// outermost \c transform() call), or None (to indicate that the transform
236+
/// operation should recursively transform the subtypes). The function object
237+
/// should use \c dyn_cast rather \c getAs, because the transform itself
238+
/// handles desugaring.
239+
///
240+
/// \returns the result of transforming the type.
241+
Type transformRec(llvm::function_ref<Optional<Type>(TypeBase *)> fn) const;
242+
219243
/// Look through the given type and its children and apply fn to them.
220244
void visit(llvm::function_ref<void (Type)> fn) const {
221245
findIf([&fn](Type t) -> bool {

lib/AST/Type.cpp

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3259,9 +3259,10 @@ Identifier DependentMemberType::getName() const {
32593259
return NameOrAssocType.get<AssociatedTypeDecl *>()->getName();
32603260
}
32613261

3262-
static bool transformSILResult(SILResultInfo &result, bool &changed,
3263-
llvm::function_ref<Type(Type)> fn) {
3264-
Type transType = result.getType().transform(fn);
3262+
static bool transformSILResult(
3263+
SILResultInfo &result, bool &changed,
3264+
llvm::function_ref<Optional<Type>(TypeBase *)> fn) {
3265+
Type transType = result.getType().transformRec(fn);
32653266
if (!transType) return true;
32663267

32673268
CanType canTransType = transType->getCanonicalType();
@@ -3272,9 +3273,10 @@ static bool transformSILResult(SILResultInfo &result, bool &changed,
32723273
return false;
32733274
}
32743275

3275-
static bool transformSILParameter(SILParameterInfo &param, bool &changed,
3276-
llvm::function_ref<Type(Type)> fn) {
3277-
Type transType = param.getType().transform(fn);
3276+
static bool transformSILParameter(
3277+
SILParameterInfo &param, bool &changed,
3278+
llvm::function_ref<Optional<Type>(TypeBase *)> fn) {
3279+
Type transType = param.getType().transformRec(fn);
32783280
if (!transType) return true;
32793281

32803282
CanType canTransType = transType->getCanonicalType();
@@ -3286,13 +3288,28 @@ static bool transformSILParameter(SILParameterInfo &param, bool &changed,
32863288
}
32873289

32883290
Type Type::transform(llvm::function_ref<Type(Type)> fn) const {
3291+
return transformRec([&fn](TypeBase *type) -> Optional<Type> {
3292+
Type transformed = fn(Type(type));
3293+
if (!transformed)
3294+
return Type();
3295+
3296+
// If the function didn't change the type at all, let transformRec()
3297+
// recurse.
3298+
if (transformed.getPointer() == type)
3299+
return None;
3300+
3301+
return transformed;
3302+
});
3303+
}
3304+
3305+
Type Type::transformRec(
3306+
llvm::function_ref<Optional<Type>(TypeBase *)> fn) const {
32893307
if (!isa<ParenType>(getPointer())) {
32903308
// Transform this type node.
3291-
Type transformed = fn(*this);
3309+
if (Optional<Type> transformed = fn(getPointer()))
3310+
return *transformed;
32923311

3293-
// If the client changed the type, we're done.
3294-
if (!transformed || transformed.getPointer() != getPointer())
3295-
return transformed;
3312+
// Recurse.
32963313
}
32973314

32983315
// Recursive into children of this type.
@@ -3314,7 +3331,7 @@ case TypeKind::Id:
33143331
case TypeKind::Protocol: {
33153332
auto nominalTy = cast<NominalType>(base);
33163333
if (auto parentTy = nominalTy->getParent()) {
3317-
parentTy = parentTy.transform(fn);
3334+
parentTy = parentTy.transformRec(fn);
33183335
if (!parentTy)
33193336
return Type();
33203337

@@ -3330,7 +3347,7 @@ case TypeKind::Id:
33303347

33313348
case TypeKind::SILBlockStorage: {
33323349
auto storageTy = cast<SILBlockStorageType>(base);
3333-
Type transCap = storageTy->getCaptureType().transform(fn);
3350+
Type transCap = storageTy->getCaptureType().transformRec(fn);
33343351
if (!transCap)
33353352
return Type();
33363353
CanType canTransCap = transCap->getCanonicalType();
@@ -3345,7 +3362,7 @@ case TypeKind::Id:
33453362
// generic SILBox.
33463363
auto boxTy = cast<SILBoxType>(base);
33473364
for (auto &arg : boxTy->getGenericArgs())
3348-
assert(arg.getReplacement()->isEqual(arg.getReplacement().transform(fn))
3365+
assert(arg.getReplacement()->isEqual(arg.getReplacement().transformRec(fn))
33493366
&& "SILBoxType can't be transformed");
33503367
#endif
33513368
return base;
@@ -3390,7 +3407,7 @@ case TypeKind::Id:
33903407
case TypeKind::WeakStorage: {
33913408
auto storageTy = cast<ReferenceStorageType>(base);
33923409
Type refTy = storageTy->getReferentType();
3393-
Type substRefTy = refTy.transform(fn);
3410+
Type substRefTy = refTy.transformRec(fn);
33943411
if (!substRefTy)
33953412
return Type();
33963413

@@ -3405,7 +3422,7 @@ case TypeKind::Id:
34053422
auto unbound = cast<UnboundGenericType>(base);
34063423
Type substParentTy;
34073424
if (auto parentTy = unbound->getParent()) {
3408-
substParentTy = parentTy.transform(fn);
3425+
substParentTy = parentTy.transformRec(fn);
34093426
if (!substParentTy)
34103427
return Type();
34113428

@@ -3427,7 +3444,7 @@ case TypeKind::Id:
34273444
bool anyChanged = false;
34283445
Type substParentTy;
34293446
if (auto parentTy = bound->getParent()) {
3430-
substParentTy = parentTy.transform(fn);
3447+
substParentTy = parentTy.transformRec(fn);
34313448
if (!substParentTy)
34323449
return Type();
34333450

@@ -3436,7 +3453,7 @@ case TypeKind::Id:
34363453
}
34373454

34383455
for (auto arg : bound->getGenericArgs()) {
3439-
Type substArg = arg.transform(fn);
3456+
Type substArg = arg.transformRec(fn);
34403457
if (!substArg)
34413458
return Type();
34423459
substArgs.push_back(substArg);
@@ -3452,7 +3469,7 @@ case TypeKind::Id:
34523469

34533470
case TypeKind::ExistentialMetatype: {
34543471
auto meta = cast<ExistentialMetatypeType>(base);
3455-
auto instanceTy = meta->getInstanceType().transform(fn);
3472+
auto instanceTy = meta->getInstanceType().transformRec(fn);
34563473
if (!instanceTy)
34573474
return Type();
34583475

@@ -3467,7 +3484,7 @@ case TypeKind::Id:
34673484

34683485
case TypeKind::Metatype: {
34693486
auto meta = cast<MetatypeType>(base);
3470-
auto instanceTy = meta->getInstanceType().transform(fn);
3487+
auto instanceTy = meta->getInstanceType().transformRec(fn);
34713488
if (!instanceTy)
34723489
return Type();
34733490

@@ -3481,7 +3498,7 @@ case TypeKind::Id:
34813498

34823499
case TypeKind::DynamicSelf: {
34833500
auto dynamicSelf = cast<DynamicSelfType>(base);
3484-
auto selfTy = dynamicSelf->getSelfType().transform(fn);
3501+
auto selfTy = dynamicSelf->getSelfType().transformRec(fn);
34853502
if (!selfTy)
34863503
return Type();
34873504

@@ -3494,7 +3511,7 @@ case TypeKind::Id:
34943511
case TypeKind::NameAlias: {
34953512
auto alias = cast<NameAliasType>(base);
34963513
auto underlyingTy = Type(alias->getSinglyDesugaredType());
3497-
auto transformedTy = underlyingTy.transform(fn);
3514+
auto transformedTy = underlyingTy.transformRec(fn);
34983515
if (!transformedTy)
34993516
return Type();
35003517

@@ -3506,7 +3523,7 @@ case TypeKind::Id:
35063523

35073524
case TypeKind::Paren: {
35083525
auto paren = cast<ParenType>(base);
3509-
Type underlying = paren->getUnderlyingType().transform(fn);
3526+
Type underlying = paren->getUnderlyingType().transformRec(fn);
35103527
if (!underlying)
35113528
return Type();
35123529

@@ -3522,7 +3539,7 @@ case TypeKind::Id:
35223539
SmallVector<TupleTypeElt, 4> elements;
35233540
unsigned Index = 0;
35243541
for (const auto &elt : tuple->getElements()) {
3525-
Type eltTy = elt.getType().transform(fn);
3542+
Type eltTy = elt.getType().transformRec(fn);
35263543
if (!eltTy)
35273544
return Type();
35283545

@@ -3555,7 +3572,7 @@ case TypeKind::Id:
35553572

35563573
case TypeKind::DependentMember: {
35573574
auto dependent = cast<DependentMemberType>(base);
3558-
auto dependentBase = dependent->getBase().transform(fn);
3575+
auto dependentBase = dependent->getBase().transformRec(fn);
35593576
if (!dependentBase)
35603577
return Type();
35613578

@@ -3570,10 +3587,10 @@ case TypeKind::Id:
35703587

35713588
case TypeKind::Function: {
35723589
auto function = cast<AnyFunctionType>(base);
3573-
auto inputTy = function->getInput().transform(fn);
3590+
auto inputTy = function->getInput().transformRec(fn);
35743591
if (!inputTy)
35753592
return Type();
3576-
auto resultTy = function->getResult().transform(fn);
3593+
auto resultTy = function->getResult().transformRec(fn);
35773594
if (!resultTy)
35783595
return Type();
35793596

@@ -3592,7 +3609,7 @@ case TypeKind::Id:
35923609
// Transform generic parameters.
35933610
SmallVector<GenericTypeParamType *, 4> genericParams;
35943611
for (auto param : function->getGenericParams()) {
3595-
Type paramTy = Type(param).transform(fn);
3612+
Type paramTy = Type(param).transformRec(fn);
35963613
if (!paramTy)
35973614
return Type();
35983615

@@ -3609,7 +3626,7 @@ case TypeKind::Id:
36093626
// Transform requirements.
36103627
SmallVector<Requirement, 4> requirements;
36113628
for (const auto &req : function->getRequirements()) {
3612-
auto firstType = req.getFirstType().transform(fn);
3629+
auto firstType = req.getFirstType().transformRec(fn);
36133630
if (!firstType)
36143631
return Type();
36153632

@@ -3618,7 +3635,7 @@ case TypeKind::Id:
36183635

36193636
Type secondType = req.getSecondType();
36203637
if (secondType) {
3621-
secondType = secondType.transform(fn);
3638+
secondType = secondType.transformRec(fn);
36223639
if (!secondType)
36233640
return Type();
36243641

@@ -3637,12 +3654,12 @@ case TypeKind::Id:
36373654
}
36383655

36393656
// Transform input type.
3640-
auto inputTy = function->getInput().transform(fn);
3657+
auto inputTy = function->getInput().transformRec(fn);
36413658
if (!inputTy)
36423659
return Type();
36433660

36443661
// Transform result type.
3645-
auto resultTy = function->getResult().transform(fn);
3662+
auto resultTy = function->getResult().transformRec(fn);
36463663
if (!resultTy)
36473664
return Type();
36483665

@@ -3695,7 +3712,7 @@ case TypeKind::Id:
36953712

36963713
case TypeKind::ArraySlice: {
36973714
auto slice = cast<ArraySliceType>(base);
3698-
auto baseTy = slice->getBaseType().transform(fn);
3715+
auto baseTy = slice->getBaseType().transformRec(fn);
36993716
if (!baseTy)
37003717
return Type();
37013718

@@ -3707,7 +3724,7 @@ case TypeKind::Id:
37073724

37083725
case TypeKind::Optional: {
37093726
auto optional = cast<OptionalType>(base);
3710-
auto baseTy = optional->getBaseType().transform(fn);
3727+
auto baseTy = optional->getBaseType().transformRec(fn);
37113728
if (!baseTy)
37123729
return Type();
37133730

@@ -3719,7 +3736,7 @@ case TypeKind::Id:
37193736

37203737
case TypeKind::ImplicitlyUnwrappedOptional: {
37213738
auto optional = cast<ImplicitlyUnwrappedOptionalType>(base);
3722-
auto baseTy = optional->getBaseType().transform(fn);
3739+
auto baseTy = optional->getBaseType().transformRec(fn);
37233740
if (!baseTy)
37243741
return Type();
37253742

@@ -3731,11 +3748,11 @@ case TypeKind::Id:
37313748

37323749
case TypeKind::Dictionary: {
37333750
auto dict = cast<DictionaryType>(base);
3734-
auto keyTy = dict->getKeyType().transform(fn);
3751+
auto keyTy = dict->getKeyType().transformRec(fn);
37353752
if (!keyTy)
37363753
return Type();
37373754

3738-
auto valueTy = dict->getValueType().transform(fn);
3755+
auto valueTy = dict->getValueType().transformRec(fn);
37393756
if (!valueTy)
37403757
return Type();
37413758

@@ -3748,7 +3765,7 @@ case TypeKind::Id:
37483765

37493766
case TypeKind::LValue: {
37503767
auto lvalue = cast<LValueType>(base);
3751-
auto objectTy = lvalue->getObjectType().transform(fn);
3768+
auto objectTy = lvalue->getObjectType().transformRec(fn);
37523769
if (!objectTy || objectTy->hasError())
37533770
return objectTy;
37543771

@@ -3758,7 +3775,7 @@ case TypeKind::Id:
37583775

37593776
case TypeKind::InOut: {
37603777
auto inout = cast<InOutType>(base);
3761-
auto objectTy = inout->getObjectType().transform(fn);
3778+
auto objectTy = inout->getObjectType().transformRec(fn);
37623779
if (!objectTy || objectTy->hasError())
37633780
return objectTy;
37643781

@@ -3772,7 +3789,7 @@ case TypeKind::Id:
37723789
bool anyChanged = false;
37733790
unsigned index = 0;
37743791
for (auto proto : pc->getProtocols()) {
3775-
auto substProto = proto.transform(fn);
3792+
auto substProto = proto.transformRec(fn);
37763793
if (!substProto)
37773794
return Type();
37783795

0 commit comments

Comments
 (0)