@@ -2051,44 +2051,51 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
20512051 }}
20522052 val argVals = argVals0.reverse
20532053 val argRefs = argRefs0.reverse
2054- def rec (fn : Tree , topAscription : Option [TypeTree ]): Tree = fn match {
2055- case Typed (expr, tpt) =>
2056- // we need to retain any type ascriptions we see and:
2057- // a) if we succeed, ascribe the result type of the ascription to the inlined body
2058- // b) if we fail, re-ascribe the same type to whatever it was we couldn't inline
2059- // note: if you see many nested ascriptions, keep the top one as that's what the enclosing expression expects
2060- rec(expr, topAscription.orElse(Some (tpt)))
2061- case Inlined (call, bindings, expansion) =>
2062- // this case must go before closureDef to avoid dropping the inline node
2063- cpy.Inlined (fn)(call, bindings, rec(expansion, topAscription))
2064- case closureDef(ddef) =>
2065- val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2066- val paramToVals = paramSyms.zip(argRefs).toMap
2067- val result = new TreeTypeMap (
2068- oldOwners = ddef.symbol :: Nil ,
2069- newOwners = ctx.owner :: Nil ,
2070- treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2071- ).transform(ddef.rhs)
2072- topAscription match {
2073- case Some (tpt) =>
2074- // we assume the ascribed type has an apply that has a MethodType with a single param list (there should be no polys)
2075- val methodType = tpt.tpe.member(nme.apply).info.asInstanceOf [MethodType ]
2054+ val reducedBody = lambdaExtractor(fn, argRefs.map(_.tpe)) match {
2055+ case Some (body) => body(argRefs)
2056+ case None => fn.select(nme.apply).appliedToArgs(argRefs)
2057+ }
2058+ seq(argVals, reducedBody).withSpan(fn.span)
2059+ }
2060+
2061+ def lambdaExtractor (fn : Term , paramTypes : List [Type ])(using ctx : Context ): Option [List [Term ] => Term ] = {
2062+ def rec (fn : Term , transformBody : Term => Term ): Option [List [Term ] => Term ] = {
2063+ fn match {
2064+ case Inlined (call, bindings, expansion) =>
2065+ // this case must go before closureDef to avoid dropping the inline node
2066+ rec(expansion, cpy.Inlined (fn)(call, bindings, _))
2067+ case Typed (expr, tpt) =>
2068+ val tpe = tpt.tpe.dropDependentRefinement
2069+ // we checked that this is a plain Function closure, so there will be an apply method with a MethodType
2070+ // and the expected signature based on param types
2071+ val expectedSig = Signature .NotAMethod .prependTermParams(paramTypes, false )
2072+ val method = tpt.tpe.member(nme.apply).atSignature(expectedSig)
2073+ if method.symbol.is(Deferred ) then
2074+ val methodType = method.info.asInstanceOf [MethodType ]
20762075 // result might contain paramrefs, so we substitute them with arg termrefs
2077- val resultTypeWithSubst = methodType.resultType.substParams(methodType, argRefs.map(_.tpe))
2078- Typed (result, TypeTree (resultTypeWithSubst).withSpan(fn.span)).withSpan(fn.span)
2079- case None =>
2080- result
2081- }
2082- case tpd.Block (stats, expr) =>
2083- seq(stats, rec(expr, topAscription)).withSpan(fn.span)
2084- case _ =>
2085- val maybeAscribed = topAscription match {
2086- case Some (tpt) => Typed (fn, tpt).withSpan(fn.span)
2087- case None => fn
2088- }
2089- maybeAscribed.select(nme.apply).appliedToArgs(argRefs).withSpan(fn.span)
2076+ val resultTypeWithSubst = methodType.resultType.substParams(methodType, paramTypes)
2077+ rec(expr, Typed (_, TypeTree (resultTypeWithSubst).withSpan(tpt.span)))
2078+ else
2079+ None
2080+ case cl @ closureDef(ddef) =>
2081+ def replace (body : Term , argRefs : List [Term ]): Term = {
2082+ val paramSyms = ddef.vparamss.head.map(param => param.symbol)
2083+ val paramToVals = paramSyms.zip(argRefs).toMap
2084+ new TreeTypeMap (
2085+ oldOwners = ddef.symbol :: Nil ,
2086+ newOwners = ctx.owner :: Nil ,
2087+ treeMap = tree => paramToVals.get(tree.symbol).map(_.withSpan(tree.span)).getOrElse(tree)
2088+ ).transform(body)
2089+ }
2090+ Some (argRefs => replace(transformBody(ddef.rhs), argRefs))
2091+ case Block (stats, expr) =>
2092+ // this case must go after closureDef to avoid matching the closure
2093+ rec(expr, cpy.Block (fn)(stats, _))
2094+ case _ =>
2095+ None
2096+ }
20902097 }
2091- seq(argVals, rec(fn, None ) )
2098+ rec(fn, identity )
20922099 }
20932100
20942101 // ///////////
0 commit comments