Skip to content

Commit 7ee1d79

Browse files
committed
Merge pull request #336 from dotty-staging/fix-321
Fix transformation of inner tail recursive methods
2 parents b09c2e8 + 1957093 commit 7ee1d79

File tree

3 files changed

+64
-26
lines changed

3 files changed

+64
-26
lines changed

src/dotty/tools/dotc/transform/FullParameterization.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,13 @@ trait FullParameterization {
8585
*
8686
* If a self type is present, $this has this self type as its type.
8787
*/
88-
def fullyParameterizedType(info: Type, clazz: ClassSymbol)(implicit ctx: Context): Type = {
88+
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = {
8989
val (mtparamCount, origResult) = info match {
9090
case info @ PolyType(mtnames) => (mtnames.length, info.resultType)
9191
case info: ExprType => (0, info.resultType)
9292
case _ => (0, info)
9393
}
94-
val ctparams = clazz.typeParams
94+
val ctparams = if(abstractOverClass) clazz.typeParams else Nil
9595
val ctnames = ctparams.map(_.name.unexpandedName())
9696

9797
/** The method result type */
@@ -104,7 +104,7 @@ trait FullParameterization {
104104
/** Replace class type parameters by the added type parameters of the polytype `pt` */
105105
def mapClassParams(tp: Type, pt: PolyType): Type = {
106106
val classParamsRange = (mtparamCount until mtparamCount + ctparams.length).toList
107-
tp.substDealias(clazz.typeParams, classParamsRange map (PolyParam(pt, _)))
107+
tp.substDealias(ctparams, classParamsRange map (PolyParam(pt, _)))
108108
}
109109

110110
/** The bounds for the added type paraneters of the polytype `pt` */
@@ -141,19 +141,23 @@ trait FullParameterization {
141141
/** The type parameters (skolems) of the method definition `originalDef`,
142142
* followed by the class parameters of its enclosing class.
143143
*/
144-
private def allInstanceTypeParams(originalDef: DefDef)(implicit ctx: Context): List[Symbol] =
145-
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams
144+
private def allInstanceTypeParams(originalDef: DefDef, abstractOverClass: Boolean)(implicit ctx: Context): List[Symbol] =
145+
if (abstractOverClass)
146+
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams
147+
else originalDef.tparams.map(_.symbol)
146148

147149
/** Given an instance method definition `originalDef`, return a
148150
* fully parameterized method definition derived from `originalDef`, which
149151
* has `derived` as symbol and `fullyParameterizedType(originalDef.symbol.info)`
150152
* as info.
153+
* `abstractOverClass` defines weather the DefDef should abstract over type parameters
154+
* of class that contained original defDef
151155
*/
152-
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
156+
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
153157
polyDefDef(derived, trefs => vrefss => {
154158
val origMeth = originalDef.symbol
155159
val origClass = origMeth.enclosingClass.asClass
156-
val origTParams = allInstanceTypeParams(originalDef)
160+
val origTParams = allInstanceTypeParams(originalDef, abstractOverClass)
157161
val origVParams = originalDef.vparamss.flatten map (_.symbol)
158162
val thisRef :: argRefs = vrefss.flatten
159163

@@ -214,13 +218,13 @@ trait FullParameterization {
214218
})
215219

216220
/** A forwarder expression which calls `derived`, passing along
217-
* - the type parameters and enclosing class parameters of `originalDef`,
221+
* - if `abstractOverClass` the type parameters and enclosing class parameters of originalDef`,
218222
* - the `this` of the enclosing class,
219223
* - the value parameters of the original method `originalDef`.
220224
*/
221-
def forwarder(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
225+
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
222226
ref(derived.termRef)
223-
.appliedToTypes(allInstanceTypeParams(originalDef).map(_.typeRef))
227+
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
224228
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
225229
.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
226230
.withPos(originalDef.rhs.pos)

src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,25 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
7474
final val labelPrefix = "tailLabel"
7575
final val labelFlags = Flags.Synthetic | Flags.Label
7676

77-
private def mkLabel(method: Symbol)(implicit c: Context): TermSymbol = {
77+
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
7878
val name = c.freshName(labelPrefix)
7979

80-
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass))
80+
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass))
8181
}
8282

8383
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
84+
val sym = tree.symbol
8485
tree match {
8586
case dd@DefDef(name, tparams, vparamss0, tpt, rhs0)
86-
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree) || (dd.symbol is Flags.Label)) =>
87-
val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
87+
if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (rhs0 eq EmptyTree) || (sym is Flags.Label)) =>
88+
val mandatory = sym.hasAnnotation(defn.TailrecAnnotationClass)
8889
atGroupEnd { implicit ctx: Context =>
8990

9091
cpy.DefDef(dd)(rhs = {
9192

92-
val origMeth = tree.symbol
93-
val label = mkLabel(dd.symbol)
93+
val defIsTopLevel = sym.owner.isClass
94+
val origMeth = sym
95+
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
9496
val owner = ctx.owner.enclosingClass.asClass
9597
val thisTpe = owner.thisType.widen
9698

@@ -101,16 +103,16 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
101103
// and second one will actually apply,
102104
// now this speculatively transforms tree and throws away result in many cases
103105
val rhsSemiTransformed = {
104-
val transformer = new TailRecElimination(dd.symbol, owner, thisTpe, mandatory, label)
106+
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
105107
val rhs = atGroupEnd(transformer.transform(rhs0)(_))
106108
rewrote = transformer.rewrote
107109
rhs
108110
}
109111

110112
if (rewrote) {
111113
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
112-
val res = fullyParameterizedDef(label, dummyDefDef)
113-
val call = forwarder(label, dd)
114+
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
115+
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel)
114116
Block(List(res), call)
115117
} else {
116118
if (mandatory)
@@ -130,7 +132,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
130132

131133
}
132134

133-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
135+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
134136

135137
import dotty.tools.dotc.ast.tpd._
136138

@@ -179,9 +181,11 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
179181
val targs = typeArguments.map(noTailTransform)
180182
val argumentss = arguments.map(noTailTransforms)
181183

182-
val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
183-
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
184-
val receiverIsThis = recv.tpe.widen =:= thisType
184+
val recvWiden = recv.tpe.widenDealias
185+
186+
val receiverIsSame = enclosingClass.typeRef.widenDealias =:= recvWiden
187+
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recvWiden
188+
val receiverIsThis = recv.tpe =:= thisType
185189

186190
val isRecursiveCall = (method eq sym)
187191

@@ -204,9 +208,13 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
204208
c.debuglog("Rewriting tail recursive call: " + tree.pos)
205209
rewrote = true
206210
val reciever = noTailTransform(recv)
207-
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
208-
val trz = classTypeArgs.map(x => ref(x.typeSymbol))
209-
val callTargs: List[tpd.Tree] = targs ::: trz
211+
212+
val callTargs: List[tpd.Tree] =
213+
if(abstractOverClass) {
214+
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
215+
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
216+
} else targs
217+
210218
val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
211219
List(reciever))
212220

tests/pos/tailcall/i321.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import scala.annotation.tailrec
2+
/**
3+
* Illustrates that abstracting over type arguments without triggering Ycheck failure is tricky
4+
*
5+
* go1.loop refers to type parameter of i321, and captures value f
6+
* if go1.loop will abstract over T it will need to cast f or will trigger a Ycheck failure.
7+
* One could decide to not abstract over type parameters in tail calls, but this leads us to go2 example
8+
*
9+
* In go2 we should abstract over i321.T, as we need to change it in recursive call.
10+
*
11+
* For now decision is such - we will abstract for top-level methods, but will not for inner ones.
12+
*/
13+
14+
class i321[T >: Null <: AnyRef] {
15+
16+
def go1(f: T => Int): Int = {
17+
@tailrec def loop(pending: T): Int = {
18+
val head1 = f(pending)
19+
loop(pending)
20+
}
21+
loop(null)
22+
}
23+
24+
final def go2[U >: Null <: AnyRef](t: i321[U]): Int = t.go2(this)
25+
26+
}

0 commit comments

Comments
 (0)