@@ -64,15 +64,85 @@ object Splicer {
6464 */
6565 def checkValidMacroBody (tree : Tree )(implicit ctx : Context ): Unit = tree match {
6666 case Quoted (_) => // ok
67- case _ => (new CheckValidMacroBody ).apply(tree)
67+ case _ =>
68+ def checkValidStat (tree : Tree ): Unit = tree match {
69+ case tree : ValDef if tree.symbol.is(Synthetic ) =>
70+ // Check val from `foo(j = x, i = y)` which it is expanded to
71+ // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
72+ checkIfValidArgument(tree.rhs)
73+ case _ =>
74+ ctx.error(" Macro should not have statements" , tree.sourcePos)
75+ }
76+ def checkIfValidArgument (tree : Tree ): Unit = tree match {
77+ case Block (Nil , expr) => checkIfValidArgument(expr)
78+ case Typed (expr, _) => checkIfValidArgument(expr)
79+
80+ case Apply (TypeApply (fn, _), quoted :: Nil ) if fn.symbol == defn.InternalQuoted_exprQuote =>
81+ // OK
82+
83+ case TypeApply (fn, quoted :: Nil ) if fn.symbol == defn.InternalQuoted_typeQuote =>
84+ // OK
85+
86+ case Literal (Constant (value)) =>
87+ // OK
88+
89+ case _ if tree.symbol == defn.QuoteContext_macroContext =>
90+ // OK
91+
92+ case Call (fn, args)
93+ if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package )) ||
94+ fn.symbol.is(Module ) || fn.symbol.isStatic ||
95+ (fn.qualifier.symbol.is(Module ) && fn.qualifier.symbol.isStatic) =>
96+ args.foreach(_.foreach(checkIfValidArgument))
97+
98+ case NamedArg (_, arg) =>
99+ checkIfValidArgument(arg)
100+
101+ case SeqLiteral (elems, _) =>
102+ elems.foreach(checkIfValidArgument)
103+
104+ case tree : Ident if tree.symbol.is(Inline ) || tree.symbol.is(Synthetic ) =>
105+ // OK
106+
107+ case _ =>
108+ ctx.error(
109+ """ Malformed macro parameter
110+ |
111+ |Parameters may be:
112+ | * Quoted parameters or fields
113+ | * References to inline parameters
114+ | * Literal values of primitive types
115+ |""" .stripMargin, tree.sourcePos)
116+ }
117+ def checkIfValidStaticCall (tree : Tree ): Unit = tree match {
118+ case Block (stats, expr) =>
119+ stats.foreach(checkValidStat)
120+ checkIfValidStaticCall(expr)
121+
122+ case Typed (expr, _) =>
123+ checkIfValidStaticCall(expr)
124+
125+ case Call (fn, args)
126+ if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package )) ||
127+ fn.symbol.is(Module ) || fn.symbol.isStatic ||
128+ (fn.qualifier.symbol.is(Module ) && fn.qualifier.symbol.isStatic) =>
129+ args.flatten.foreach(checkIfValidArgument)
130+
131+ case _ =>
132+ ctx.error(
133+ """ Malformed macro.
134+ |
135+ |Expected the splice ${...} to contain a single call to a static method.
136+ |""" .stripMargin, tree.sourcePos)
137+ }
138+
139+ checkIfValidStaticCall(tree)
68140 }
69141
70142 /** Tree interpreter that evaluates the tree */
71- private class Interpreter (pos : SourcePosition , classLoader : ClassLoader )(implicit ctx : Context ) extends AbstractInterpreter {
143+ private class Interpreter (pos : SourcePosition , classLoader : ClassLoader )(implicit ctx : Context ) {
72144
73- def checking : Boolean = false
74-
75- type Result = Object
145+ type Env = Map [Name , Object ]
76146
77147 /** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
78148 * Return Some of the result or None if some error happen during the interpretation.
@@ -93,22 +163,92 @@ object Splicer {
93163 }
94164 }
95165
96- protected def interpretQuote (tree : Tree )(implicit env : Env ): Object =
166+ def interpretTree (tree : Tree )(implicit env : Env ): Object = tree match {
167+ case Apply (TypeApply (fn, _), quoted :: Nil ) if fn.symbol == defn.InternalQuoted_exprQuote =>
168+ val quoted1 = quoted match {
169+ case quoted : Ident if quoted.symbol.isAllOf(InlineByNameProxy ) =>
170+ // inline proxy for by-name parameter
171+ quoted.symbol.defTree.asInstanceOf [DefDef ].rhs
172+ case Inlined (EmptyTree , _, quoted) => quoted
173+ case _ => quoted
174+ }
175+ interpretQuote(quoted1)
176+
177+ case TypeApply (fn, quoted :: Nil ) if fn.symbol == defn.InternalQuoted_typeQuote =>
178+ interpretTypeQuote(quoted)
179+
180+ case Literal (Constant (value)) =>
181+ interpretLiteral(value)
182+
183+ case _ if tree.symbol == defn.QuoteContext_macroContext =>
184+ interpretQuoteContext()
185+
186+ // TODO disallow interpreted method calls as arguments
187+ case Call (fn, args) =>
188+ if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package )) {
189+ interpretNew(fn.symbol, args.flatten.map(interpretTree))
190+ } else if (fn.symbol.is(Module )) {
191+ interpretModuleAccess(fn.symbol)
192+ } else if (fn.symbol.isStatic) {
193+ val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
194+ staticMethodCall(args.flatten.map(interpretTree))
195+ } else if (fn.qualifier.symbol.is(Module ) && fn.qualifier.symbol.isStatic) {
196+ val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
197+ staticMethodCall(args.flatten.map(interpretTree))
198+ } else if (env.contains(fn.name)) {
199+ env(fn.name)
200+ } else if (tree.symbol.is(InlineProxy )) {
201+ interpretTree(tree.symbol.defTree.asInstanceOf [ValOrDefDef ].rhs)
202+ } else {
203+ unexpectedTree(tree)
204+ }
205+
206+ // Interpret `foo(j = x, i = y)` which it is expanded to
207+ // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
208+ case Block (stats, expr) => interpretBlock(stats, expr)
209+ case NamedArg (_, arg) => interpretTree(arg)
210+
211+ case Inlined (_, bindings, expansion) => interpretBlock(bindings, expansion)
212+
213+ case Typed (expr, _) =>
214+ interpretTree(expr)
215+
216+ case SeqLiteral (elems, _) =>
217+ interpretVarargs(elems.map(e => interpretTree(e)))
218+
219+ case _ =>
220+ unexpectedTree(tree)
221+ }
222+
223+ private def interpretBlock (stats : List [Tree ], expr : Tree )(implicit env : Env ) = {
224+ var unexpected : Option [Object ] = None
225+ val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
226+ case stat : ValDef =>
227+ accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
228+ case stat =>
229+ if (unexpected.isEmpty)
230+ unexpected = Some (unexpectedTree(stat))
231+ accEnv
232+ })
233+ unexpected.getOrElse(interpretTree(expr)(newEnv))
234+ }
235+
236+ private def interpretQuote (tree : Tree )(implicit env : Env ): Object =
97237 new scala.internal.quoted.TastyTreeExpr (Inlined (EmptyTree , Nil , tree).withSpan(tree.span))
98238
99- protected def interpretTypeQuote (tree : Tree )(implicit env : Env ): Object =
239+ private def interpretTypeQuote (tree : Tree )(implicit env : Env ): Object =
100240 new scala.internal.quoted.TreeType (tree)
101241
102- protected def interpretLiteral (value : Any )(implicit env : Env ): Object =
242+ private def interpretLiteral (value : Any )(implicit env : Env ): Object =
103243 value.asInstanceOf [Object ]
104244
105- protected def interpretVarargs (args : List [Object ])(implicit env : Env ): Object =
245+ private def interpretVarargs (args : List [Object ])(implicit env : Env ): Object =
106246 args.toSeq
107247
108- protected def interpretQuoteContext ()(implicit env : Env ): Object =
248+ private def interpretQuoteContext ()(implicit env : Env ): Object =
109249 new scala.quoted.QuoteContext (ReflectionImpl (ctx, pos))
110250
111- protected def interpretStaticMethodCall (moduleClass : Symbol , fn : Symbol , args : => List [ Object ] )(implicit env : Env ): Object = {
251+ private def interpretedStaticMethodCall (moduleClass : Symbol , fn : Symbol )(implicit env : Env ): List [ Object ] => Object = {
112252 val (inst, clazz) =
113253 if (moduleClass.name.startsWith(str.REPL_SESSION_LINE )) {
114254 (null , loadReplLineClass(moduleClass))
@@ -125,19 +265,20 @@ object Splicer {
125265
126266 val name = getDirectName(fn.info.finalResultType, fn.name.asTermName)
127267 val method = getMethod(clazz, name, paramsSig(fn))
128- stopIfRuntimeException(method.invoke(inst, args : _* ))
268+
269+ (args : List [Object ]) => stopIfRuntimeException(method.invoke(inst, args : _* ))
129270 }
130271
131- protected def interpretModuleAccess (fn : Symbol )(implicit env : Env ): Object =
272+ private def interpretModuleAccess (fn : Symbol )(implicit env : Env ): Object =
132273 loadModule(fn.moduleClass)
133274
134- protected def interpretNew (fn : Symbol , args : => List [Result ])(implicit env : Env ): Object = {
275+ private def interpretNew (fn : Symbol , args : => List [Object ])(implicit env : Env ): Object = {
135276 val clazz = loadClass(fn.owner.fullName.toString)
136277 val constr = clazz.getConstructor(paramsSig(fn): _* )
137278 constr.newInstance(args : _* ).asInstanceOf [Object ]
138279 }
139280
140- protected def unexpectedTree (tree : Tree )(implicit env : Env ): Object =
281+ private def unexpectedTree (tree : Tree )(implicit env : Env ): Object =
141282 throw new StopInterpretation (" Unexpected tree could not be interpreted: " + tree, tree.sourcePos)
142283
143284 private def loadModule (sym : Symbol ): Object = {
@@ -265,158 +406,25 @@ object Splicer {
265406
266407 }
267408
268- /** Tree interpreter that tests if tree can be interpreted */
269- private class CheckValidMacroBody (implicit ctx : Context ) extends AbstractInterpreter {
270- def checking : Boolean = true
271-
272- type Result = Unit
273-
274- def apply (tree : Tree ): Unit = interpretTree(tree)(Map .empty)
275-
276- protected def interpretQuote (tree : tpd.Tree )(implicit env : Env ): Unit = ()
277- protected def interpretTypeQuote (tree : tpd.Tree )(implicit env : Env ): Unit = ()
278- protected def interpretLiteral (value : Any )(implicit env : Env ): Unit = ()
279- protected def interpretVarargs (args : List [Unit ])(implicit env : Env ): Unit = ()
280- protected def interpretQuoteContext ()(implicit env : Env ): Unit = ()
281- protected def interpretStaticMethodCall (module : Symbol , fn : Symbol , args : => List [Unit ])(implicit env : Env ): Unit = args.foreach(identity)
282- protected def interpretModuleAccess (fn : Symbol )(implicit env : Env ): Unit = ()
283- protected def interpretNew (fn : Symbol , args : => List [Unit ])(implicit env : Env ): Unit = args.foreach(identity)
284-
285- def unexpectedTree (tree : tpd.Tree )(implicit env : Env ): Unit = {
286- // Assuming that top-level splices can only be in inline methods
287- // and splices are expanded at inline site, references to inline values
288- // will be known literal constant trees.
289- if (! tree.symbol.is(Inline ))
290- ctx.error(
291- """ Malformed macro.
292- |
293- |Expected the splice ${...} to contain a single call to a static method.
294- |
295- |Where parameters may be:
296- | * Quoted paramers or fields
297- | * References to inline parameters
298- | * Literal values of primitive types
299- """ .stripMargin, tree.sourcePos)
300- }
301- }
302-
303- /** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
304- private abstract class AbstractInterpreter (implicit ctx : Context ) {
305-
306- def checking : Boolean
307-
308- type Env = Map [Name , Result ]
309- type Result
310-
311- protected def interpretQuote (tree : Tree )(implicit env : Env ): Result
312- protected def interpretTypeQuote (tree : Tree )(implicit env : Env ): Result
313- protected def interpretLiteral (value : Any )(implicit env : Env ): Result
314- protected def interpretVarargs (args : List [Result ])(implicit env : Env ): Result
315- protected def interpretQuoteContext ()(implicit env : Env ): Result
316- protected def interpretStaticMethodCall (module : Symbol , fn : Symbol , args : => List [Result ])(implicit env : Env ): Result
317- protected def interpretModuleAccess (fn : Symbol )(implicit env : Env ): Result
318- protected def interpretNew (fn : Symbol , args : => List [Result ])(implicit env : Env ): Result
319- protected def unexpectedTree (tree : Tree )(implicit env : Env ): Result
320-
321- private final def removeErasedArguments (args : List [List [Tree ]], fnTpe : Type ): List [List [Tree ]] =
322- fnTpe match {
323- case tp : TermRef => removeErasedArguments(args, tp.underlying)
324- case tp : PolyType => removeErasedArguments(args, tp.resType)
325- case tp : ExprType => removeErasedArguments(args, tp.resType)
326- case tp : MethodType =>
327- val tail = removeErasedArguments(args.tail, tp.resType)
328- if (tp.isErasedMethod) tail else args.head :: tail
329- case tp : AppliedType if defn.isImplicitFunctionType(tp) =>
330- val tail = removeErasedArguments(args.tail, tp.args.last)
331- if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
332- case tp => assert(args.isEmpty, tp); Nil
333- }
334-
335- protected final def interpretTree (tree : Tree )(implicit env : Env ): Result = tree match {
336- case Apply (TypeApply (fn, _), quoted :: Nil ) if fn.symbol == defn.InternalQuoted_exprQuote =>
337- val quoted1 = quoted match {
338- case quoted : Ident if quoted.symbol.isAllOf(InlineByNameProxy ) =>
339- // inline proxy for by-name parameter
340- quoted.symbol.defTree.asInstanceOf [DefDef ].rhs
341- case Inlined (EmptyTree , _, quoted) => quoted
342- case _ => quoted
343- }
344- interpretQuote(quoted1)
345-
346- case TypeApply (fn, quoted :: Nil ) if fn.symbol == defn.InternalQuoted_typeQuote =>
347- interpretTypeQuote(quoted)
348-
349- case Literal (Constant (value)) =>
350- interpretLiteral(value)
351-
352- case _ if tree.symbol == defn.QuoteContext_macroContext =>
353- interpretQuoteContext()
354-
355- case Call (fn, args) =>
356- if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package )) {
357- interpretNew(fn.symbol, args.flatten.map(interpretTree))
358- } else if (fn.symbol.is(Module )) {
359- interpretModuleAccess(fn.symbol)
360- } else if (fn.symbol.isStatic) {
361- val module = fn.symbol.owner
362- def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
363- interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
364- } else if (fn.qualifier.symbol.is(Module ) && fn.qualifier.symbol.isStatic) {
365- val module = fn.qualifier.symbol.moduleClass
366- def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
367- interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
368- } else if (env.contains(fn.name)) {
369- env(fn.name)
370- } else if (tree.symbol.is(InlineProxy )) {
371- interpretTree(tree.symbol.defTree.asInstanceOf [ValOrDefDef ].rhs)
372- } else {
373- unexpectedTree(tree)
374- }
375-
376- // Interpret `foo(j = x, i = y)` which it is expanded to
377- // `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
378- case Block (stats, expr) => interpretBlock(stats, expr)
379- case NamedArg (_, arg) => interpretTree(arg)
380-
381- case Inlined (_, bindings, expansion) => interpretBlock(bindings, expansion)
382-
383- case Typed (expr, _) =>
384- interpretTree(expr)
385-
386- case SeqLiteral (elems, _) =>
387- interpretVarargs(elems.map(e => interpretTree(e)))
388-
389- case _ =>
390- unexpectedTree(tree)
391- }
392-
393- private def interpretBlock (stats : List [Tree ], expr : Tree )(implicit env : Env ) = {
394- var unexpected : Option [Result ] = None
395- val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
396- case stat : ValDef if stat.symbol.is(Synthetic ) || ! checking =>
397- accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
398- case stat =>
399- if (unexpected.isEmpty)
400- unexpected = Some (unexpectedTree(stat))
401- accEnv
402- })
403- unexpected.getOrElse(interpretTree(expr)(newEnv))
404- }
405-
406- object Call {
407- def unapply (arg : Tree ): Option [(RefTree , List [List [Tree ]])] =
408- Call0 .unapply(arg).map((fn, args) => (fn, args.reverse))
409-
410- object Call0 {
411- def unapply (arg : Tree ): Option [(RefTree , List [List [Tree ]])] = arg match {
412- case Select (Call0 (fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
413- Some ((fn, args))
414- case fn : RefTree => Some ((fn, Nil ))
415- case Apply (Call0 (fn, args1), args2) => Some ((fn, args2 :: args1))
416- case TypeApply (Call0 (fn, args), _) => Some ((fn, args))
417- case _ => None
418- }
409+ object Call {
410+ /** Matches an expression that is either a field access or an application
411+ * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
412+ */
413+ def unapply (arg : Tree )(implicit ctx : Context ): Option [(RefTree , List [List [Tree ]])] =
414+ Call0 .unapply(arg).map((fn, args) => (fn, args.reverse))
415+
416+ private object Call0 {
417+ def unapply (arg : Tree )(implicit ctx : Context ): Option [(RefTree , List [List [Tree ]])] = arg match {
418+ case Select (Call0 (fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
419+ Some ((fn, args))
420+ case fn : RefTree => Some ((fn, Nil ))
421+ case Apply (f @ Call0 (fn, args1), args2) =>
422+ if (f.tpe.widenDealias.isErasedMethod) Some ((fn, args1))
423+ else Some ((fn, args2 :: args1))
424+ case TypeApply (Call0 (fn, args), _) => Some ((fn, args))
425+ case _ => None
419426 }
420427 }
421428 }
429+
422430}
0 commit comments