From d2153ed5a40a5a1f14cd80f3ec2565ba071309a4 Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Mon, 31 Jul 2023 16:01:56 +0200 Subject: [PATCH] Reimplement support for type aliases in SAM types This was dropped in #18201 which restricted SAM types to valid parent types, but it turns out that there is code in the wild that relies on refinements being allowed here. To support this properly, we had to enhance ExpandSAMs to move refinements into type members to pass Ycheck (previous Scala 3 releases would accept the code in tests/run/i18315.scala but fail Ycheck). Fixes #18315. --- compiler/src/dotty/tools/dotc/ast/tpd.scala | 23 ++++---- .../src/dotty/tools/dotc/core/Types.scala | 59 ++++++++++++------- .../tools/dotc/transform/ExpandSAMs.scala | 32 +++++----- tests/run/i18315.scala | 15 +++++ 4 files changed, 84 insertions(+), 45 deletions(-) create mode 100644 tests/run/i18315.scala diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 4ff5c4c8c41d..ad2676624b0f 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -349,24 +349,27 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { /** An anonymous class * - * new parents { forwarders } + * new parents { termForwarders; typeAliases } * - * where `forwarders` contains forwarders for all functions in `fns`. - * @param parents a non-empty list of class types - * @param fns a non-empty of functions for which forwarders should be defined in the class. - * The class has the same owner as the first function in `fns`. - * Its position is the union of all functions in `fns`. + * @param parents a non-empty list of class types + * @param termForwarders a non-empty list of forwarding definitions specified by their name and the definition they forward to. + * @param typeMembers a possibly-empty list of type members specified by their name and their right hand side. + * + * The class has the same owner as the first function in `termForwarders`. + * Its position is the union of all symbols in `termForwarders`. */ - def AnonClass(parents: List[Type], fns: List[TermSymbol], methNames: List[TermName])(using Context): Block = { - AnonClass(fns.head.owner, parents, fns.map(_.span).reduceLeft(_ union _)) { cls => - def forwarder(fn: TermSymbol, name: TermName) = { + def AnonClass(parents: List[Type], termForwarders: List[(TermName, TermSymbol)], + typeMembers: List[(TypeName, TypeBounds)] = Nil)(using Context): Block = { + AnonClass(termForwarders.head._2.owner, parents, termForwarders.map(_._2.span).reduceLeft(_ union _)) { cls => + def forwarder(name: TermName, fn: TermSymbol) = { val fwdMeth = fn.copy(cls, name, Synthetic | Method | Final).entered.asTerm for overridden <- fwdMeth.allOverriddenSymbols do if overridden.is(Extension) then fwdMeth.setFlag(Extension) if !overridden.is(Deferred) then fwdMeth.setFlag(Override) DefDef(fwdMeth, ref(fn).appliedToArgss(_)) } - fns.lazyZip(methNames).map(forwarder) + termForwarders.map((name, sym) => forwarder(name, sym)) ++ + typeMembers.map((name, info) => TypeDef(newSymbol(cls, name, Synthetic, info).entered)) } } diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index b6581177b4be..d68ab1aedf49 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -5536,13 +5536,16 @@ object Types { * and PolyType not allowed!) according to `possibleSamMethods`. * - can be instantiated without arguments or with just () as argument. * + * Additionally, a SAM type may contain type aliases refinements if they refine + * an existing type member. + * * The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the * type of the single abstract method and `samParent` is a subtype of the matched * SAM type which has been stripped of wildcards to turn it into a valid parent * type. */ object SAMType { - /** If possible, return a type which is both a subtype of `origTp` and a type + /** If possible, return a type which is both a subtype of `origTp` and a (possibly refined) type * application of `samClass` where none of the type arguments are * wildcards (thus making it a valid parent type), otherwise return * NoType. @@ -5572,27 +5575,41 @@ object Types { * we arbitrarily pick the upper-bound. */ def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type = - val tp = origTp.baseType(samClass) + val tp0 = origTp.baseType(samClass) + + /** Copy type aliases refinements to `toTp` from `fromTp` */ + def withRefinements(toType: Type, fromTp: Type): Type = fromTp.dealias match + case RefinedType(fromParent, name, info: TypeAlias) if tp0.member(name).exists => + val parent1 = withRefinements(toType, fromParent) + RefinedType(toType, name, info) + case _ => toType + val tp = withRefinements(tp0, origTp) + if !(tp <:< origTp) then NoType - else tp match - case tp @ AppliedType(tycon, args) if tp.hasWildcardArg => - val accu = new TypeAccumulator[VarianceMap[Symbol]]: - def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match - case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) => - vmap.recordLocalVariance(tp.symbol, variance) - case _ => - foldOver(vmap, t) - val vmap = accu(VarianceMap.empty, samMeth.info) - val tparams = tycon.typeParamSymbols - val args1 = args.zipWithConserve(tparams): - case (arg @ TypeBounds(lo, hi), tparam) => - val v = vmap.computedVariance(tparam) - if v.uncheckedNN < 0 then lo - else hi - case (arg, _) => arg - tp.derivedAppliedType(tycon, args1) - case _ => - tp + else + def approxWildcardArgs(tp: Type): Type = tp match + case tp @ AppliedType(tycon, args) if tp.hasWildcardArg => + val accu = new TypeAccumulator[VarianceMap[Symbol]]: + def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match + case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) => + vmap.recordLocalVariance(tp.symbol, variance) + case _ => + foldOver(vmap, t) + val vmap = accu(VarianceMap.empty, samMeth.info) + val tparams = tycon.typeParamSymbols + val args1 = args.zipWithConserve(tparams): + case (arg @ TypeBounds(lo, hi), tparam) => + val v = vmap.computedVariance(tparam) + if v.uncheckedNN < 0 then lo + else hi + case (arg, _) => arg + tp.derivedAppliedType(tycon, args1) + case tp @ RefinedType(parent, name, info) => + tp.derivedRefinedType(approxWildcardArgs(parent), name, info) + case _ => + tp + approxWildcardArgs(tp) + end samParent def samClass(tp: Type)(using Context): Symbol = tp match case tp: ClassInfo => diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index a933b247a85f..6dae564041ee 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -6,6 +6,7 @@ import core._ import Scopes.newScope import Contexts._, Symbols._, Types._, Flags._, Decorators._, StdNames._, Constants._ import MegaPhase._ +import Names.TypeName import SymUtils._ import NullOpsDecorator._ import ast.untpd @@ -51,16 +52,28 @@ class ExpandSAMs extends MiniPhase: case tpe if defn.isContextFunctionType(tpe) => tree case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) => - val tpe1 = checkRefinements(tpe, fn) - toPartialFunction(tree, tpe1) + toPartialFunction(tree, tpe) case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) => - checkRefinements(tpe, fn) tree case tpe => - val tpe1 = checkRefinements(tpe.stripNull, fn) + // A SAM type is allowed to have type aliases refinements (see + // SAMType#samParent) which must be converted into type members if + // the closure is desugared into a class. + val refinements = collection.mutable.ListBuffer[(TypeName, TypeAlias)]() + def collectAndStripRefinements(tp: Type): Type = tp match + case RefinedType(parent, name, info: TypeAlias) => + val res = collectAndStripRefinements(parent) + refinements += ((name.asTypeName, info)) + res + case _ => tp + val tpe1 = collectAndStripRefinements(tpe) val Seq(samDenot) = tpe1.possibleSamMethods cpy.Block(tree)(stats, - AnonClass(tpe1 :: Nil, fn.symbol.asTerm :: Nil, samDenot.symbol.asTerm.name :: Nil)) + AnonClass(List(tpe1), + List(samDenot.symbol.asTerm.name -> fn.symbol.asTerm), + refinements.toList + ) + ) } case _ => tree @@ -171,13 +184,4 @@ class ExpandSAMs extends MiniPhase: List(isDefinedAtDef, applyOrElseDef) } } - - private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match { - case RefinedType(parent, name, _) => - if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement - report.error(em"Lambda does not define $name", tree.srcPos) - checkRefinements(parent, tree) - case tpe => - tpe - } end ExpandSAMs diff --git a/tests/run/i18315.scala b/tests/run/i18315.scala new file mode 100644 index 000000000000..85824920efbd --- /dev/null +++ b/tests/run/i18315.scala @@ -0,0 +1,15 @@ +trait Sam1: + type T + def apply(x: T): T + +trait Sam2: + var x: Int = 1 // To force anonymous class generation + type T + def apply(x: T): T + +object Test: + def main(args: Array[String]): Unit = + val s1: Sam1 { type T = String } = x => x.trim + s1.apply("foo") + val s2: Sam2 { type T = Int } = x => x + 1 + s2.apply(1)