@@ -8,12 +8,13 @@ import MegaPhase._
88import SymUtils ._
99import ast .untpd
1010import ast .Trees ._
11+ import dotty .tools .dotc .reporting .diagnostic .messages .TypeMismatch
1112import dotty .tools .dotc .util .Positions .Position
1213
1314/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
1415 * These fall into five categories
1516 *
16- * 1. Partial function closures, we need to generate a isDefinedAt method for these.
17+ * 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
1718 * 2. Closures implementing non-trait classes.
1819 * 3. Closures implementing classes that inherit from a class other than Object
1920 * (a lambda cannot not be a run-time subtype of such a class)
@@ -35,8 +36,8 @@ class ExpandSAMs extends MiniPhase {
3536 tpt.tpe match {
3637 case NoType => tree // it's a plain function
3738 case tpe @ SAMType (_) if tpe.isRef(defn.PartialFunctionClass ) =>
38- checkRefinements(tpe, fn.pos)
39- toPartialFunction(tree)
39+ val tpe1 = checkRefinements(tpe, fn.pos)
40+ toPartialFunction(tree, tpe1 )
4041 case tpe @ SAMType (_) if isPlatformSam(tpe.classSymbol.asClass) =>
4142 checkRefinements(tpe, fn.pos)
4243 tree
@@ -50,50 +51,83 @@ class ExpandSAMs extends MiniPhase {
5051 tree
5152 }
5253
53- private def toPartialFunction (tree : Block )(implicit ctx : Context ): Tree = {
54- val Block (
55- (applyDef @ DefDef (nme.ANON_FUN , Nil , List (List (param)), _, _)) :: Nil ,
56- Closure (_, _, tpt)) = tree
57- val applyRhs : Tree = applyDef.rhs
58- val applyFn = applyDef.symbol.asTerm
59-
60- val MethodTpe (paramNames, paramTypes, _) = applyFn.info
61- val isDefinedAtFn = applyFn.copy(
62- name = nme.isDefinedAt,
63- flags = Synthetic | Method ,
64- info = MethodType (paramNames, paramTypes, defn.BooleanType )).asTerm
65- val tru = Literal (Constant (true ))
66- def isDefinedAtRhs (paramRefss : List [List [Tree ]]) = applyRhs match {
67- case Match (selector, cases) =>
68- assert(selector.symbol == param.symbol)
69- val paramRef = paramRefss.head.head
70- // Again, the alternative
71- // val List(List(paramRef)) = paramRefs
72- // fails with a similar self instantiation error
73- def translateCase (cdef : CaseDef ): CaseDef =
74- cpy.CaseDef (cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
75- val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD , Synthetic , selector.tpe.widen)
76- val defaultCase =
77- CaseDef (
78- Bind (defaultSym, Underscore (selector.tpe.widen)),
79- EmptyTree ,
80- Literal (Constant (false )))
81- val annotated = Annotated (paramRef, New (ref(defn.UncheckedAnnotType )))
82- cpy.Match (applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
54+ private def toPartialFunction (tree : Block , tpe : Type )(implicit ctx : Context ): Tree = {
55+ // /** An extractor for match, either contained in a block or standalone. */
56+ object PartialFunctionRHS {
57+ def unapply (tree : Tree ): Option [Match ] = tree match {
58+ case Block (Nil , expr) => unapply(expr)
59+ case m : Match => Some (m)
60+ case _ => None
61+ }
62+ }
63+
64+ val closureDef(anon @ DefDef (_, _, List (List (param)), _, _)) = tree
65+ anon.rhs match {
66+ case PartialFunctionRHS (pf) =>
67+ val anonSym = anon.symbol
68+
69+ def overrideSym (sym : Symbol ) = sym.copy(
70+ owner = anonSym.owner,
71+ flags = Synthetic | Method | Final ,
72+ info = tpe.memberInfo(sym),
73+ coord = tree.pos).asTerm
74+ val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
75+ val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
76+
77+ def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree ) = {
78+ val selector = tree.selector
79+ val selectorTpe = selector.tpe.widen
80+ val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD , Synthetic , selectorTpe)
81+ val defaultCase =
82+ CaseDef (
83+ Bind (defaultSym, Underscore (selectorTpe)),
84+ EmptyTree ,
85+ defaultValue)
86+ val unchecked = Annotated (selector, New (ref(defn.UncheckedAnnotType )))
87+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
88+ .subst(param.symbol :: Nil , pfParam :: Nil )
89+ // Needed because a partial function can be written as:
90+ // param => param match { case "foo" if foo(param) => param }
91+ // And we need to update all references to 'param'
92+ }
93+
94+ def isDefinedAtRhs (paramRefss : List [List [Tree ]]) = {
95+ val tru = Literal (Constant (true ))
96+ def translateCase (cdef : CaseDef ) =
97+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
98+ val paramRef = paramRefss.head.head
99+ val defaultValue = Literal (Constant (false ))
100+ translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
101+ }
102+
103+ def applyOrElseRhs (paramRefss : List [List [Tree ]]) = {
104+ val List (paramRef, defaultRef) = paramRefss.head
105+ def translateCase (cdef : CaseDef ) =
106+ cdef.changeOwner(anonSym, applyOrElseFn)
107+ val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
108+ translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
109+ }
110+
111+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
112+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)))
113+
114+ val parent = defn.AbstractPartialFunctionType .appliedTo(tpe.argInfos)
115+ val anonCls = AnonClass (parent :: Nil , List (isDefinedAtFn, applyOrElseFn), List (nme.isDefinedAt, nme.applyOrElse))
116+ cpy.Block (tree)(List (isDefinedAtDef, applyOrElseDef), anonCls)
117+
83118 case _ =>
84- tru
119+ val found = tpe.baseType(defn.FunctionClass (1 ))
120+ ctx.error(TypeMismatch (found, tpe), tree.pos)
121+ tree
85122 }
86- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
87- val anonCls = AnonClass (tpt.tpe :: Nil , List (applyFn, isDefinedAtFn), List (nme.apply, nme.isDefinedAt))
88- cpy.Block (tree)(List (applyDef, isDefinedAtDef), anonCls)
89123 }
90124
91125 private def checkRefinements (tpe : Type , pos : Position )(implicit ctx : Context ): Type = tpe.dealias match {
92126 case RefinedType (parent, name, _) =>
93127 if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
94128 ctx.error(" Lambda does not define " + name, pos)
95129 checkRefinements(parent, pos)
96- case _ =>
130+ case tpe =>
97131 tpe
98132 }
99133
0 commit comments