Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions src/dotty/tools/dotc/transform/FullParameterization.scala
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,13 @@ trait FullParameterization {
*
* If a self type is present, $this has this self type as its type.
*/
def fullyParameterizedType(info: Type, clazz: ClassSymbol)(implicit ctx: Context): Type = {
def fullyParameterizedType(info: Type, clazz: ClassSymbol, abstractOverClass: Boolean = true)(implicit ctx: Context): Type = {
val (mtparamCount, origResult) = info match {
case info @ PolyType(mtnames) => (mtnames.length, info.resultType)
case info: ExprType => (0, info.resultType)
case _ => (0, info)
}
val ctparams = clazz.typeParams
val ctparams = if(abstractOverClass) clazz.typeParams else Nil
val ctnames = ctparams.map(_.name.unexpandedName())

/** The method result type */
Expand All @@ -104,7 +104,7 @@ trait FullParameterization {
/** Replace class type parameters by the added type parameters of the polytype `pt` */
def mapClassParams(tp: Type, pt: PolyType): Type = {
val classParamsRange = (mtparamCount until mtparamCount + ctparams.length).toList
tp.substDealias(clazz.typeParams, classParamsRange map (PolyParam(pt, _)))
tp.substDealias(ctparams, classParamsRange map (PolyParam(pt, _)))
}

/** The bounds for the added type paraneters of the polytype `pt` */
Expand Down Expand Up @@ -141,19 +141,23 @@ trait FullParameterization {
/** The type parameters (skolems) of the method definition `originalDef`,
* followed by the class parameters of its enclosing class.
*/
private def allInstanceTypeParams(originalDef: DefDef)(implicit ctx: Context): List[Symbol] =
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams
private def allInstanceTypeParams(originalDef: DefDef, abstractOverClass: Boolean)(implicit ctx: Context): List[Symbol] =
if (abstractOverClass)
originalDef.tparams.map(_.symbol) ::: originalDef.symbol.enclosingClass.typeParams
else originalDef.tparams.map(_.symbol)

/** Given an instance method definition `originalDef`, return a
* fully parameterized method definition derived from `originalDef`, which
* has `derived` as symbol and `fullyParameterizedType(originalDef.symbol.info)`
* as info.
* `abstractOverClass` defines weather the DefDef should abstract over type parameters
* of class that contained original defDef
*/
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
def fullyParameterizedDef(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
polyDefDef(derived, trefs => vrefss => {
val origMeth = originalDef.symbol
val origClass = origMeth.enclosingClass.asClass
val origTParams = allInstanceTypeParams(originalDef)
val origTParams = allInstanceTypeParams(originalDef, abstractOverClass)
val origVParams = originalDef.vparamss.flatten map (_.symbol)
val thisRef :: argRefs = vrefss.flatten

Expand Down Expand Up @@ -214,13 +218,13 @@ trait FullParameterization {
})

/** A forwarder expression which calls `derived`, passing along
* - the type parameters and enclosing class parameters of `originalDef`,
* - if `abstractOverClass` the type parameters and enclosing class parameters of originalDef`,
* - the `this` of the enclosing class,
* - the value parameters of the original method `originalDef`.
*/
def forwarder(derived: TermSymbol, originalDef: DefDef)(implicit ctx: Context): Tree =
def forwarder(derived: TermSymbol, originalDef: DefDef, abstractOverClass: Boolean = true)(implicit ctx: Context): Tree =
ref(derived.termRef)
.appliedToTypes(allInstanceTypeParams(originalDef).map(_.typeRef))
.appliedToTypes(allInstanceTypeParams(originalDef, abstractOverClass).map(_.typeRef))
.appliedTo(This(originalDef.symbol.enclosingClass.asClass))
.appliedToArgss(originalDef.vparamss.nestedMap(vparam => ref(vparam.symbol)))
.withPos(originalDef.rhs.pos)
Expand Down
40 changes: 24 additions & 16 deletions src/dotty/tools/dotc/transform/TailRec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,23 +74,25 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
final val labelPrefix = "tailLabel"
final val labelFlags = Flags.Synthetic | Flags.Label

private def mkLabel(method: Symbol)(implicit c: Context): TermSymbol = {
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
val name = c.freshName(labelPrefix)

c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass))
c.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass))
}

override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
val sym = tree.symbol
tree match {
case dd@DefDef(name, tparams, vparamss0, tpt, rhs0)
if (dd.symbol.isEffectivelyFinal) && !((dd.symbol is Flags.Accessor) || (rhs0 eq EmptyTree) || (dd.symbol is Flags.Label)) =>
val mandatory = dd.symbol.hasAnnotation(defn.TailrecAnnotationClass)
if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (rhs0 eq EmptyTree) || (sym is Flags.Label)) =>
val mandatory = sym.hasAnnotation(defn.TailrecAnnotationClass)
atGroupEnd { implicit ctx: Context =>

cpy.DefDef(dd)(rhs = {

val origMeth = tree.symbol
val label = mkLabel(dd.symbol)
val defIsTopLevel = sym.owner.isClass
val origMeth = sym
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
val owner = ctx.owner.enclosingClass.asClass
val thisTpe = owner.thisType.widen

Expand All @@ -101,16 +103,16 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
// and second one will actually apply,
// now this speculatively transforms tree and throws away result in many cases
val rhsSemiTransformed = {
val transformer = new TailRecElimination(dd.symbol, owner, thisTpe, mandatory, label)
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
val rhs = atGroupEnd(transformer.transform(rhs0)(_))
rewrote = transformer.rewrote
rhs
}

if (rewrote) {
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
val res = fullyParameterizedDef(label, dummyDefDef)
val call = forwarder(label, dd)
val res = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel)
Block(List(res), call)
} else {
if (mandatory)
Expand All @@ -130,7 +132,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete

}

class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {

import dotty.tools.dotc.ast.tpd._

Expand Down Expand Up @@ -179,9 +181,11 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
val targs = typeArguments.map(noTailTransform)
val argumentss = arguments.map(noTailTransforms)

val receiverIsSame = enclosingClass.typeRef.widen =:= recv.tpe.widen
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recv.tpe.widen
val receiverIsThis = recv.tpe.widen =:= thisType
val recvWiden = recv.tpe.widenDealias

val receiverIsSame = enclosingClass.typeRef.widenDealias =:= recvWiden
val receiverIsSuper = (method.name eq sym) && enclosingClass.typeRef.widen <:< recvWiden
val receiverIsThis = recv.tpe =:= thisType

val isRecursiveCall = (method eq sym)

Expand All @@ -204,9 +208,13 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
c.debuglog("Rewriting tail recursive call: " + tree.pos)
rewrote = true
val reciever = noTailTransform(recv)
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
val trz = classTypeArgs.map(x => ref(x.typeSymbol))
val callTargs: List[tpd.Tree] = targs ::: trz

val callTargs: List[tpd.Tree] =
if(abstractOverClass) {
val classTypeArgs = recv.tpe.baseTypeWithArgs(enclosingClass).argInfos
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
} else targs

val method = Apply(if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef),
List(reciever))

Expand Down
26 changes: 26 additions & 0 deletions tests/pos/tailcall/i321.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import scala.annotation.tailrec
/**
* Illustrates that abstracting over type arguments without triggering Ycheck failure is tricky
*
* go1.loop refers to type parameter of i321, and captures value f
* if go1.loop will abstract over T it will need to cast f or will trigger a Ycheck failure.
* One could decide to not abstract over type parameters in tail calls, but this leads us to go2 example
*
* In go2 we should abstract over i321.T, as we need to change it in recursive call.
*
* For now decision is such - we will abstract for top-level methods, but will not for inner ones.
*/

class i321[T >: Null <: AnyRef] {

def go1(f: T => Int): Int = {
@tailrec def loop(pending: T): Int = {
val head1 = f(pending)
loop(pending)
}
loop(null)
}

final def go2[U >: Null <: AnyRef](t: i321[U]): Int = t.go2(this)

}