@@ -860,20 +860,21 @@ object desugar {
860860 *
861861 * where every definition in `body` is expanded to an extension method
862862 * taking type parameters `tparams` and a leading paramter `(x: T)`.
863- * See: makeExtensionDef
863+ * See: collectiveExtensionBody
864864 */
865865 def moduleDef (mdef : ModuleDef )(implicit ctx : Context ): Tree = {
866866 val impl = mdef.impl
867867 val mods = mdef.mods
868868 impl.constr match {
869- case DefDef (_, tparams, (vparams @ (vparam :: Nil )) :: givenParamss, _, _) =>
869+ case DefDef (_, tparams, vparamss @ (vparam :: Nil ) :: givenParamss, _, _) =>
870+ // Transform collective extension
870871 assert(mods.is(Given ))
871872 return moduleDef(
872873 cpy.ModuleDef (mdef)(
873874 mdef.name,
874875 cpy.Template (impl)(
875876 constr = emptyConstructor,
876- body = impl.body.map(makeExtensionDef(_ , tparams, vparams, givenParamss) ))))
877+ body = collectiveExtensionBody( impl.body, tparams, vparamss ))))
877878 case _ =>
878879 }
879880
@@ -916,38 +917,67 @@ object desugar {
916917 }
917918 }
918919
919- /** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
920- * map a method definition
920+ /** Transform the statements of a collective extension
921+ * @param stats the original statements as they were parsed
922+ * @param tparams the collective type parameters
923+ * @param vparamss the collective value parameters, consisting
924+ * of a single leading value parameter, followed by
925+ * zero or more context parameter clauses
921926 *
922- * def foo [Us] paramss ...
927+ * Note: It is already assured by Parser.checkExtensionMethod that all
928+ * statements conform to requirements.
923929 *
924- * to
930+ * Each method in stats is transformed into an extension method. Furthermore,
931+ * identifier references to other methods are turned into selections on the common
932+ * parameter.
933+ *
934+ * Example:
925935 *
926- * <extension> def foo[Ts ++ Us](x: T) parammss ...
936+ * extension on [Ts](x: T)(using C):
937+ * def f(y: T) = ???
938+ * def g(z: T) = f(z)
927939 *
928- * If the given member `mdef` is not of this form, flag it as an error.
940+ * is turned into
941+ *
942+ * extension:
943+ * <extension> def f[Ts](x: T)(using C)(y: T) = ???
944+ * <extension> def g[Ts](x: T)(using C)(z: T) = x.f(z)
929945 */
930-
931- def makeExtensionDef (mdef : Tree , tparams : List [TypeDef ], leadingParams : List [ValDef ],
932- givenParamss : List [List [ValDef ]])(using ctx : Context ): Tree = {
933- val allowed = " allowed here, since collective parameters are given"
934- mdef match {
935- case mdef : DefDef =>
936- if (mdef.mods.is(Extension )) {
937- ctx.error(em " No extension method $allowed" , mdef.sourcePos)
946+ def collectiveExtensionBody (stats : List [Tree ],
947+ tparams : List [TypeDef ], vparamss : List [List [ValDef ]])(using ctx : Context ): List [Tree ] =
948+ val methodNames : Set [Name ] =
949+ stats.collect { case stat : DefDef => stat.name }.toSet
950+
951+ object linkMethods extends UntypedTreeMap :
952+ private val paramName = vparamss.head.head.name
953+ private var prefixName = paramName
954+
955+ override def transform (tree : Tree )(using Context ): Tree = tree match
956+ case tree : NamedDefTree if tree.name == paramName =>
957+ prefixName = UniqueName .fresh()
958+ super .transform(tree)
959+ case tree : Ident if methodNames.contains(tree.name) =>
960+ cpy.Select (tree)(Ident (prefixName), tree.name)
961+ case _ =>
962+ super .transform(tree)
963+
964+ def apply (rhs : Tree ): Tree =
965+ val rhs1 = transform(rhs)
966+ if prefixName == paramName then rhs1
967+ else Block (ValDef (prefixName, TypeTree (), Ident (paramName)), rhs1)
968+ end linkMethods
969+
970+ for stat <- stats yield
971+ stat match
972+ case mdef : DefDef =>
973+ cpy.DefDef (mdef)(
974+ tparams = tparams ++ mdef.tparams,
975+ vparamss = vparamss ::: mdef.vparamss,
976+ rhs = linkMethods(mdef.rhs)
977+ ).withMods(mdef.mods | Extension )
978+ case mdef =>
938979 mdef
939- }
940- else cpy.DefDef (mdef)(
941- tparams = tparams ++ mdef.tparams,
942- vparamss = leadingParams :: givenParamss ::: mdef.vparamss
943- ).withMods(mdef.mods | Extension )
944- case mdef : Import =>
945- mdef
946- case mdef =>
947- ctx.error(em " Only methods $allowed" , mdef.sourcePos)
948- mdef
949- }
950- }
980+ end collectiveExtensionBody
951981
952982 /** Transforms
953983 *
0 commit comments