Skip to content

Commit 7e927f4

Browse files
committed
Even more careful handling of tailcalls.
See i321 doc for description of problem and decision taken.
1 parent 13f25af commit 7e927f4

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -74,23 +74,25 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
7474
final val labelPrefix = "tailLabel"
7575
final val labelFlags = Flags.Synthetic | Flags.Label
7676

77-
private def mkLabel(method: Symbol)(implicit c: Context): TermSymbol = {
77+
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
7878
val name = c.freshName(labelPrefix)
7979

80-
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass = false))
80+
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass))
8181
}
8282

8383
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
84+
val sym = tree.symbol
8485
tree match {
8586
case dd@DefDef(name, tparams, vparamss0, tpt, rhs0)
86-
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree) || (dd.symbol is Flags.Label)) =>
87-
val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
87+
if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (rhs0 eq EmptyTree) || (sym is Flags.Label)) =>
88+
val mandatory = sym.hasAnnotation(defn.TailrecAnnotationClass)
8889
atGroupEnd { implicit ctx: Context =>
8990

9091
cpy.DefDef(dd)(rhs = {
9192

92-
val origMeth = tree.symbol
93-
val label = mkLabel(dd.symbol)
93+
val defIsTopLevel = sym.owner.isClass
94+
val origMeth = sym
95+
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
9496
val owner = ctx.owner.enclosingClass.asClass
9597
val thisTpe = owner.thisType.widen
9698

@@ -101,16 +103,16 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
101103
// and second one will actually apply,
102104
// now this speculatively transforms tree and throws away result in many cases
103105
val rhsSemiTransformed = {
104-
val transformer = new TailRecElimination(dd.symbol, owner, thisTpe, mandatory, label)
106+
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
105107
val rhs = atGroupEnd(transformer.transform(rhs0)(_))
106108
rewrote = transformer.rewrote
107109
rhs
108110
}
109111

110112
if (rewrote) {
111113
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
112-
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = false)
113-
val call = forwarder(label, dd, abstractOverClass = false)
114+
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
115+
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel)
114116
Block(List(res), call)
115117
} else {
116118
if (mandatory)
@@ -130,7 +132,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
130132

131133
}
132134

133-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
135+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
134136

135137
import dotty.tools.dotc.ast.tpd._
136138

@@ -179,9 +181,11 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
179181
val targs = typeArguments.map(noTailTransform)
180182
val argumentss = arguments.map(noTailTransforms)
181183

182-
val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
183-
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
184-
val receiverIsThis = recv.tpe.widen =:= thisType
184+
val recvWiden = recv.tpe.widenDealias
185+
186+
val receiverIsSame = enclosingClass.typeRef.widenDealias =:= recvWiden
187+
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recvWiden
188+
val receiverIsThis = recv.tpe =:= thisType
185189

186190
val isRecursiveCall = (method eq sym)
187191

@@ -205,12 +209,12 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
205209
rewrote = true
206210
val reciever = noTailTransform(recv)
207211

208-
/*
209-
handling changed type arguments in sound way is hard, see test `i321`
210-
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
211-
val trz = classTypeArgs.map(x => ref(x.typeSymbol))
212-
*/
213-
val callTargs: List[tpd.Tree] = targs
212+
val callTargs: List[tpd.Tree] =
213+
if(abstractOverClass) {
214+
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
215+
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
216+
} else targs
217+
214218
val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
215219
List(reciever))
216220

tests/pos/tailcall/i321.scala

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,26 @@
1+
import scala.annotation.tailrec
2+
/**
3+
* Illustrates that abstracting over type arguments without triggering Ycheck failure is tricky
4+
*
5+
* go1.loop refers to type parameter of i321, and captures value f
6+
* if goo1.loop will abstract over T it will need to cast f or will trigger a Ycheck failure.
7+
* One could decide to not abstract over type parameters in tail calls, but this leads us to go2 example
8+
*
9+
* In go2 we should abstract over i321.T, as we need to change it in recursive call.
10+
*
11+
* For now decision is such - we will abstract for top-level methods, but will not for inner ones.
12+
*/
13+
114
class i321[T >: Null <: AnyRef] {
215

3-
def mapconserve(f: T => Int): Int = {
4-
def loop(pending: T): Int = {
5-
val head1 = f(pending)
6-
loop(pending)
7-
}
16+
def go1(f: T => Int): Int = {
17+
@tailrec def loop(pending: T): Int = {
18+
val head1 = f(pending)
19+
loop(pending)
20+
}
821
loop(null)
922
}
23+
24+
final def go2[U >: Null <: AnyRef](t: i321[U]): Int = t.go2(this)
25+
1026
}

0 commit comments

Comments
 (0)