Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -506,21 +506,21 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
def Inlined_copy(original: Tree)(call: Option[Term | TypeTree], bindings: List[Definition], expansion: Term)(implicit ctx: Context): Inlined =
tpd.cpy.Inlined(original)(call.getOrElse(tpd.EmptyTree), bindings.asInstanceOf[List[tpd.MemberDef]], expansion)

type Lambda = tpd.Closure
type Closure = tpd.Closure

def matchLambda(x: Term)(implicit ctx: Context): Option[Lambda] = x match {
def matchClosure(x: Term)(implicit ctx: Context): Option[Closure] = x match {
case x: tpd.Closure => Some(x)
case _ => None
}

def Lambda_meth(self: Lambda)(implicit ctx: Context): Term = self.meth
def Lambda_tptOpt(self: Lambda)(implicit ctx: Context): Option[TypeTree] = optional(self.tpt)
def Closure_meth(self: Closure)(implicit ctx: Context): Term = self.meth
def Closure_tpeOpt(self: Closure)(implicit ctx: Context): Option[Type] = optional(self.tpt).map(_.tpe)

def Lambda_apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
withDefaultPos(ctx => tpd.Closure(Nil, meth, tpt.getOrElse(tpd.EmptyTree))(ctx))
def Closure_apply(meth: Term, tpe: Option[Type])(implicit ctx: Context): Closure =
withDefaultPos(ctx => tpd.Closure(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))(ctx))

def Lambda_copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
tpd.cpy.Closure(original)(Nil, meth, tpt.getOrElse(tpd.EmptyTree))
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(implicit ctx: Context): Closure =
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))

type If = tpd.If

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ object Matcher {
tpt1 =#= tpt2 &&
withEnv(rhsEnv)(rhs1 =#= rhs2)

case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
case (Closure(_, tpt1), Closure(_, tpt2)) =>
// TODO match tpt1 with tpt2?
matched

Expand Down
14 changes: 11 additions & 3 deletions library/src/scala/tasty/reflect/Core.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ package scala.tasty.reflect
* | +- Typed
* | +- Assign
* | +- Block
* | +- Lambda
* | +- Closure
* | +- If
* | +- Match
* | +- ImpliedMatch
Expand Down Expand Up @@ -200,8 +200,16 @@ trait Core {
/** Tree representing a block `{ ... }` in the source code */
type Block = kernel.Block

/** Tree representing a lambda `(...) => ...` in the source code */
type Lambda = kernel.Lambda
/** A lambda `(...) => ...` in the source code is represented as
* a local method and a closure:
*
* {
* def m(...) = ...
* closure(m)
* }
*
*/
type Closure = kernel.Closure

/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
type If = kernel.If
Expand Down
24 changes: 16 additions & 8 deletions library/src/scala/tasty/reflect/Kernel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package scala.tasty.reflect
* | +- Typed
* | +- Assign
* | +- Block
* | +- Lambda
* | +- Closure
* | +- If
* | +- Match
* | +- ImpliedMatch
Expand Down Expand Up @@ -436,16 +436,24 @@ trait Kernel {
def Block_apply(stats: List[Statement], expr: Term)(implicit ctx: Context): Block
def Block_copy(original: Tree)(stats: List[Statement], expr: Term)(implicit ctx: Context): Block

/** Tree representing a lambda `(...) => ...` in the source code */
type Lambda <: Term
/** A lambda `(...) => ...` in the source code is represented as
* a local method and a closure:
*
* {
* def m(...) = ...
* closure(m)
* }
*
*/
type Closure <: Term

def matchLambda(tree: Tree)(implicit ctx: Context): Option[Lambda]
def matchClosure(tree: Tree)(implicit ctx: Context): Option[Closure]

def Lambda_meth(self: Lambda)(implicit ctx: Context): Term
def Lambda_tptOpt(self: Lambda)(implicit ctx: Context): Option[TypeTree]
def Closure_meth(self: Closure)(implicit ctx: Context): Term
def Closure_tpeOpt(self: Closure)(implicit ctx: Context): Option[Type]

def Lambda_apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda
def Lambda_copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda
def Closure_apply(meth: Term, tpe: Option[Type])(implicit ctx: Context): Closure
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(implicit ctx: Context): Closure

/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
type If <: Term
Expand Down
47 changes: 19 additions & 28 deletions library/src/scala/tasty/reflect/Printers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ trait Printers
this += "Block(" ++= stats += ", " += expr += ")"
case If(cond, thenp, elsep) =>
this += "If(" += cond += ", " += thenp += ", " += elsep += ")"
case Lambda(meth, tpt) =>
this += "Lambda(" += meth += ", " += tpt += ")"
case Closure(meth, tpt) =>
this += "Closure(" += meth += ", " += tpt += ")"
case Match(selector, cases) =>
this += "Match(" += selector += ", " ++= cases += ")"
case ImpliedMatch(cases) =>
Expand Down Expand Up @@ -406,6 +406,7 @@ trait Printers

private implicit class TypeOps(buff: Buffer) {
def +=(x: TypeOrBounds): Buffer = { visitType(x); buff }
def +=(x: Option[TypeOrBounds]): Buffer = { visitOption(x, visitType); buff }
def ++=(x: List[TypeOrBounds]): Buffer = { visitList(x, visitType); buff }
}

Expand Down Expand Up @@ -740,17 +741,6 @@ trait Printers
printTree(body)
}

case IsDefDef(ddef @ DefDef(name, targs, argss, _, rhsOpt)) if name.startsWith("$anonfun") =>
// Decompile lambda definition
assert(targs.isEmpty)
val args :: Nil = argss
val Some(rhs) = rhsOpt
inParens {
printArgsDefs(args)
this += " => "
printTree(rhs)
}

case IsDefDef(ddef @ DefDef(name, targs, argss, tpt, rhs)) =>
printDefAnnotations(ddef)

Expand Down Expand Up @@ -901,6 +891,13 @@ trait Printers
this += " = "
printTree(rhs)

case Lambda(params, body) => // must come before `Block`
inParens {
printArgsDefs(params)
this += " => "
printTree(body)
}

case Block(stats0, expr) =>
val stats = stats0.filter {
case IsValDef(tree) => !tree.symbol.flags.is(Flags.Object)
Expand All @@ -911,10 +908,6 @@ trait Printers
case Inlined(_, bindings, expansion) =>
printFlatBlock(bindings, expansion)

case Lambda(meth, tpt) =>
// Printed in by it's DefDef
this

case If(cond, thenp, elsep) =>
this += highlightKeyword("if ")
inParens(printTree(cond))
Expand Down Expand Up @@ -982,6 +975,8 @@ trait Printers
def flatBlock(stats: List[Statement], expr: Term): (List[Statement], Term) = {
val flatStats = List.newBuilder[Statement]
def extractFlatStats(stat: Statement): Unit = stat match {
case Lambda(_, _) => // must come before `Block`
flatStats += stat
case Block(stats1, expr1) =>
val it = stats1.iterator
while (it.hasNext)
Expand All @@ -996,6 +991,8 @@ trait Printers
case stat => flatStats += stat
}
def extractFlatExpr(term: Term): Term = term match {
case Lambda(_, _) => // must come before `Block`
term
case Block(stats1, expr1) =>
val it = stats1.iterator
while (it.hasNext)
Expand All @@ -1017,23 +1014,16 @@ trait Printers

def printFlatBlock(stats: List[Statement], expr: Term)(implicit elideThis: Option[Symbol]): Buffer = {
val (stats1, expr1) = flatBlock(stats, expr)
// Remove Lambda nodes, lambdas are printed by their definition
val stats2 = stats1.filter {
case Lambda(_, _) => false
case IsTypeDef(tree) => !tree.symbol.annots.exists(_.symbol.owner.fullName == "scala.internal.Quoted$.quoteTypeTag")
case _ => true
}
val (stats3, expr3) = expr1 match {
case Lambda(_, _) =>
val init :+ last = stats2
(init, last)
case _ => (stats2, expr1)
}
if (stats3.isEmpty) {
printTree(expr3)
if (stats2.isEmpty) {
printTree(expr1)
} else {
this += "{"
indented {
printStats(stats3, expr3)
printStats(stats2, expr1)
}
this += lineBreak() += "}"
}
Expand All @@ -1043,6 +1033,7 @@ trait Printers
def printSeparator(next: Tree): Unit = {
// Avoid accidental application of opening `{` on next line with a double break
def rec(next: Tree): Unit = next match {
case Lambda(_, _) => this += lineBreak()
case Block(stats, _) if stats.nonEmpty => this += doubleLineBreak()
case Inlined(_, bindings, _) if bindings.nonEmpty => this += doubleLineBreak()
case Select(qual, _) => rec(qual)
Expand Down
48 changes: 35 additions & 13 deletions library/src/scala/tasty/reflect/TreeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -552,26 +552,48 @@ trait TreeOps extends Core {
def expr(implicit ctx: Context): Term = kernel.Block_expr(self)
}

object IsLambda {
/** Matches any Lambda and returns it */
def unapply(tree: Tree)(implicit ctx: Context): Option[Lambda] = kernel.matchLambda(tree)
object IsClosure {
/** Matches any Closure and returns it */
def unapply(tree: Tree)(implicit ctx: Context): Option[Closure] = kernel.matchClosure(tree)
}

object Lambda {
object Closure {

def apply(meth: Term, tpt: Option[Type])(implicit ctx: Context): Closure =
kernel.Closure_apply(meth, tpt)

def apply(meth: Term, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
kernel.Lambda_apply(meth, tpt)
def copy(original: Tree)(meth: Tree, tpt: Option[Type])(implicit ctx: Context): Closure =
kernel.Closure_copy(original)(meth, tpt)

def copy(original: Tree)(meth: Tree, tpt: Option[TypeTree])(implicit ctx: Context): Lambda =
kernel.Lambda_copy(original)(meth, tpt)
def unapply(tree: Tree)(implicit ctx: Context): Option[(Term, Option[Type])] =
kernel.matchClosure(tree).map(x => (x.meth, x.tpeOpt))
}

def unapply(tree: Tree)(implicit ctx: Context): Option[(Term, Option[TypeTree])] =
kernel.matchLambda(tree).map(x => (x.meth, x.tptOpt))
implicit class ClosureAPI(self: Closure) {
def meth(implicit ctx: Context): Term = kernel.Closure_meth(self)
def tpeOpt(implicit ctx: Context): Option[Type] = kernel.Closure_tpeOpt(self)
}

implicit class LambdaAPI(self: Lambda) {
def meth(implicit ctx: Context): Term = kernel.Lambda_meth(self)
def tptOpt(implicit ctx: Context): Option[TypeTree] = kernel.Lambda_tptOpt(self)
/** A lambda `(...) => ...` in the source code is represented as
* a local method and a closure:
*
* {
* def m(...) = ...
* closure(m)
* }
*
* @note Due to the encoding, in pattern matches the case for `Lambda`
* should come before the case for `Block` to avoid mishandling
* of `Lambda`.
*/
object Lambda {
def unapply(tree: Tree)(implicit ctx: Context): Option[(List[ValDef], Term)] = tree match {
case Block((ddef @ DefDef(_, _, params :: Nil, _, Some(body))) :: Nil, Closure(meth, _))
if ddef.symbol == meth.symbol =>
Some(params, body)

case _ => None
}
}

object IsIf {
Expand Down
9 changes: 4 additions & 5 deletions library/src/scala/tasty/reflect/TreeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@ trait TreeUtils
foldTree(foldTree(foldTree(x, cond), thenp), elsep)
case While(cond, body) =>
foldTree(foldTree(x, cond), body)
case Lambda(meth, tpt) =>
val a = foldTree(x, meth)
tpt.fold(a)(b => foldTree(a, b))
case Closure(meth, tpt) =>
foldTree(x, meth)
case Match(selector, cases) =>
foldTrees(foldTree(x, selector), cases)
case Return(expr) =>
Expand Down Expand Up @@ -193,8 +192,8 @@ trait TreeUtils
Block.copy(tree)(transformStats(stats), transformTerm(expr))
case If(cond, thenp, elsep) =>
If.copy(tree)(transformTerm(cond), transformTerm(thenp), transformTerm(elsep))
case Lambda(meth, tpt) =>
Lambda.copy(tree)(transformTerm(meth), tpt.map(x => transformTypeTree(x)))
case Closure(meth, tpt) =>
Closure.copy(tree)(transformTerm(meth), tpt)
case Match(selector, cases) =>
Match.copy(tree)(transformTerm(selector), transformCaseDefs(cases))
case Return(expr) =>
Expand Down
8 changes: 1 addition & 7 deletions tests/run-macros/i5941/macro_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,7 @@ object Lens {

object Function {
def unapply(t: Term): Option[(List[ValDef], Term)] = t match {
case Inlined(
None, Nil,
Block(
(ddef @ DefDef(_, Nil, params :: Nil, _, Some(body))) :: Nil,
Lambda(meth, _)
)
) if meth.symbol == ddef.symbol => Some((params, body))
case Inlined(None, Nil, Lambda(params, body)) => Some((params, body))
case _ => None
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/run-macros/reflect-lambda/assert_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import scala.quoted._
import scala.tasty._

object lib {

inline def assert(condition: => Boolean): Unit = ${ assertImpl('condition, '{""}) }

def assertImpl(cond: Expr[Boolean], clue: Expr[Any])(implicit refl: Reflection): Expr[Unit] = {
import refl._
import util._

cond.unseal.underlyingArgument match {
case t @ Apply(Select(lhs, op), Lambda(param :: Nil, Apply(Select(a, "=="), b :: Nil)) :: Nil)
if a.symbol == param.symbol || b.symbol == param.symbol =>
'{ scala.Predef.assert($cond) }
}
}
}
14 changes: 14 additions & 0 deletions tests/run-macros/reflect-lambda/test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
object Test {
import lib._

case class IntList(args: Int*) {
def exists(f: Int => Boolean): Boolean = args.exists(f)
}

def main(args: Array[String]): Unit = {
assert(IntList(3, 5).exists(_ == 3))
assert(IntList(3, 5).exists(5 == _))
assert(IntList(3, 5).exists(x => x == 3))
assert(IntList(3, 5).exists(x => 5 == x))
}
}
4 changes: 2 additions & 2 deletions tests/run-macros/tasty-extractors-2.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Inlined(None, Nil, Block(List(ValDef("x", Inferred(), Some(Literal(Constant(1))))), Assign(Ident("x"), Literal(Constant(2)))))
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))

Inlined(None, Nil, Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", TypeIdent("Int"), None))), Inferred(), Some(Ident("x")))), Lambda(Ident("$anonfun"), None)))
Inlined(None, Nil, Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", TypeIdent("Int"), None))), Inferred(), Some(Ident("x")))), Closure(Ident("$anonfun"), None)))
Type.AppliedType(Type.SymRef(IsClassDefSymbol(<scala.Function1>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix()))), List(Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix())))), Type.SymRef(IsClassDefSymbol(<scala.Int>), Type.SymRef(IsPackageDefSymbol(<scala>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<<root>>), NoPrefix()))))))

Inlined(None, Nil, Ident("???"))
Expand Down Expand Up @@ -100,6 +100,6 @@ Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageD
Inlined(None, Nil, Block(List(ClassDef("Foo", DefDef("<init>", Nil, List(Nil), Inferred(), None), List(Apply(Select(New(Inferred()), "<init>"), Nil)), Nil, None, List(TypeDef("X", TypeBoundsTree(Inferred(), Inferred())))), DefDef("f", Nil, List(List(ValDef("a", Refined(TypeIdent("Foo"), List(TypeDef("X", TypeIdent("Int")))), None))), TypeSelect(Ident("a"), "X"), Some(Ident("???")))), Literal(Constant(()))))
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))

Inlined(None, Nil, Block(List(ValDef("lambda", Applied(Inferred(), List(TypeIdent("Int"), TypeIdent("Int"))), Some(Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", Inferred(), None))), Inferred(), Some(Ident("x")))), Lambda(Ident("$anonfun"), None))))), Literal(Constant(()))))
Inlined(None, Nil, Block(List(ValDef("lambda", Applied(Inferred(), List(TypeIdent("Int"), TypeIdent("Int"))), Some(Block(List(DefDef("$anonfun", Nil, List(List(ValDef("x", Inferred(), None))), Inferred(), Some(Ident("x")))), Closure(Ident("$anonfun"), None))))), Literal(Constant(()))))
Type.SymRef(IsClassDefSymbol(<scala.Unit>), Type.ThisType(Type.SymRef(IsPackageDefSymbol(<scala>), NoPrefix())))