@@ -85,10 +85,49 @@ object Inliner {
8585 * @return An `Inlined` node that refers to the original call and the inlined bindings
8686 * and body that replace it.
8787 */
88- def inlineCall (tree : Tree , pt : Type )(implicit ctx : Context ): Tree = tree match {
89- case Block (stats, expr) =>
90- cpy.Block (tree)(stats, inlineCall(expr, pt))
91- case _ if (enclosingInlineds.length < ctx.settings.XmaxInlines .value) =>
88+ def inlineCall (tree : Tree , pt : Type )(implicit ctx : Context ): Tree = {
89+
90+ /** Set the position of all trees logically contained in the expansion of
91+ * inlined call `call` to the position of `call`. This transform is necessary
92+ * when lifting bindings from the expansion to the outside of the call.
93+ */
94+ def liftFromInlined (call : Tree ) = new TreeMap {
95+ override def transform (t : Tree )(implicit ctx : Context ) = {
96+ t match {
97+ case Inlined (t, Nil , expr) if t.isEmpty => expr
98+ case _ => super .transform(t.withPos(call.pos))
99+ }
100+ }
101+ }
102+
103+ val bindings = new mutable.ListBuffer [Tree ]
104+
105+ /** Lift bindings around inline call or in its function part to
106+ * the `bindings` buffer. This is done as an optimization to keep
107+ * inline call expansions smaller.
108+ */
109+ def liftBindings (tree : Tree , liftPos : Tree => Tree ): Tree = tree match {
110+ case Block (stats, expr) =>
111+ bindings ++= stats.map(liftPos)
112+ liftBindings(expr, liftPos)
113+ case Inlined (call, stats, expr) =>
114+ bindings ++= stats.map(liftPos)
115+ val lifter = liftFromInlined(call)
116+ cpy.Inlined (tree)(call, Nil , liftBindings(expr, liftFromInlined(call).transform(_)))
117+ case Apply (fn, args) =>
118+ cpy.Apply (tree)(liftBindings(fn, liftPos), args)
119+ case TypeApply (fn, args) =>
120+ cpy.TypeApply (tree)(liftBindings(fn, liftPos), args)
121+ case Select (qual, name) =>
122+ cpy.Select (tree)(liftBindings(qual, liftPos), name)
123+ case _ =>
124+ tree
125+ }
126+
127+ val tree1 = liftBindings(tree, identity)
128+ if (bindings.nonEmpty)
129+ cpy.Block (tree)(bindings.toList, inlineCall(tree1, pt))
130+ else if (enclosingInlineds.length < ctx.settings.XmaxInlines .value) {
92131 val body = bodyToInline(tree.symbol) // can typecheck the tree and thereby produce errors
93132 if (ctx.reporter.hasErrors) tree
94133 else {
@@ -97,7 +136,8 @@ object Inliner {
97136 else ctx.fresh.setProperty(InlineBindings , newMutableSymbolMap[Tree ])
98137 new Inliner (tree, body)(inlinerCtx).inlined(pt)
99138 }
100- case _ =>
139+ }
140+ else
101141 errorTree(
102142 tree,
103143 i """ |Maximal number of successive inlines ( ${ctx.settings.XmaxInlines .value}) exceeded,
@@ -469,10 +509,13 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
469509 val argInPlace =
470510 if (trailing.isEmpty) arg
471511 else letBindUnless(TreeInfo .Pure , arg)(seq(trailing, _))
472- seq(prefix, seq(leading, argInPlace))
512+ val fullArg = seq(prefix, seq(leading, argInPlace))
513+ new TreeTypeMap ().transform(fullArg) // make sure local bindings in argument have fresh symbols
473514 .reporting(res => i " projecting $tree -> $res" , inlining)
474515 }
475516 else tree
517+ case Block (stats, expr) if stats.forall(isPureBinding) =>
518+ cpy.Block (tree)(stats, reduceProjection(expr))
476519 case _ => tree
477520 }
478521 }
@@ -793,6 +836,14 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
793836 }
794837 countRefs.traverse(tree)
795838 for (binding <- bindings) countRefs.traverse(binding)
839+
840+ def retain (boundSym : Symbol ) = {
841+ refCount.get(boundSym) match {
842+ case Some (x) => x > 1 || x == 1 && ! boundSym.is(Method )
843+ case none => true
844+ }
845+ } && ! boundSym.is(TransparentImplicitMethod )
846+
796847 val inlineBindings = new TreeMap {
797848 override def transform (t : Tree )(implicit ctx : Context ) = t match {
798849 case t : RefTree =>
@@ -812,11 +863,8 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
812863 super .transform(t)
813864 }
814865 }
815- def retain (binding : MemberDef ) = refCount.get(binding.symbol) match {
816- case Some (x) => x > 1 || x == 1 && ! binding.symbol.is(Method )
817- case none => true
818- }
819- val retained = bindings.filterConserve(retain)
866+
867+ val retained = bindings.filterConserve(binding => retain(binding.symbol))
820868 if (retained `eq` bindings) {
821869 (bindings, tree)
822870 }
0 commit comments