diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 1baeaef6ff82..91627864ac64 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -824,12 +824,37 @@ object desugar { * object name extends parents { self => body } * * to: + * * val name: name$ = New(name$) * final class name$ extends parents { self: name.type => body } + * + * Special case for extension methods with collective parameters. Expand: + * + * given object name[tparams](x: T) extends parents { self => bpdy } + * + * to: + * + * given object name extends parents { self => body' } + * + * where every definition in `body` is expanded to an extension method + * taking type parameters `tparams` and a leading paramter `(x: T)`. + * See: makeExtensionDef */ def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = { val impl = mdef.impl val mods = mdef.mods + impl.constr match { + case DefDef(_, tparams, (vparams @ (vparam :: Nil)) :: _, _, _) => + assert(mods.is(Given)) + return moduleDef( + cpy.ModuleDef(mdef)( + mdef.name, + cpy.Template(impl)( + constr = emptyConstructor, + body = impl.body.map(makeExtensionDef(_, tparams, vparams))))) + case _ => + } + val moduleName = normalizeName(mdef, impl).asTermName def isEnumCase = mods.isEnumCase @@ -869,6 +894,36 @@ object desugar { } } + /** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`, + * map a method definition + * + * def foo [Us] paramss ... + * + * to + * + * def foo[Ts ++ Us](x: T) parammss ... + * + * If the given member `mdef` is not of this form, flag it as an error. + */ + + def makeExtensionDef(mdef: Tree, tparams: List[TypeDef], leadingParams: List[ValDef]) given (ctx: Context): Tree = { + val allowed = "allowed here, since collective parameters are given" + mdef match { + case mdef: DefDef => + if (mdef.mods.is(Extension)) { + ctx.error(em"No extension method $allowed", mdef.sourcePos) + mdef + } + else cpy.DefDef(mdef)(tparams = tparams ++ mdef.tparams, vparamss = leadingParams :: mdef.vparamss) + .withFlags(Extension) + case mdef: Import => + mdef + case mdef => + ctx.error(em"Only methods $allowed", mdef.sourcePos) + mdef + } + } + /** Transforms * * type $T >: Low <: Hi diff --git a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala index e79a8cc55df7..7dd5ccf429e0 100644 --- a/compiler/src/dotty/tools/dotc/parsing/Parsers.scala +++ b/compiler/src/dotty/tools/dotc/parsing/Parsers.scala @@ -2869,11 +2869,20 @@ object Parsers { /** GivenDef ::= [id] [DefTypeParamClause] GivenBody * GivenBody ::= [‘as ConstrApp {‘,’ ConstrApp }] {GivenParamClause} [TemplateBody] * | ‘as’ Type {GivenParamClause} ‘=’ Expr + * | ‘(’ DefParam ‘)’ TemplateBody */ def instanceDef(newStyle: Boolean, start: Offset, mods: Modifiers, instanceMod: Mod) = atSpan(start, nameStart) { var mods1 = addMod(mods, instanceMod) val name = if (isIdent && (!newStyle || in.name != nme.as)) ident() else EmptyTermName val tparams = typeParamClauseOpt(ParamOwner.Def) + var leadingParamss = + if (in.token == LPAREN) + try paramClause(prefix = true) :: Nil + finally { + newLineOptWhenFollowedBy(LBRACE) + if (in.token != LBRACE) syntaxErrorOrIncomplete("`{' expected") + } + else Nil val parents = if (!newStyle && in.token == FOR || isIdent(nme.as)) { // for the moment, accept both `given for` and `given as` in.nextToken() @@ -2889,11 +2898,15 @@ object Parsers { } else { newLineOptWhenFollowedBy(LBRACE) - val tparams1 = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal)) - val vparamss1 = vparamss.map(_.map(vparam => - vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal))) + val (tparams1, vparamss1) = + if (leadingParamss.nonEmpty) + (tparams, leadingParamss) + else + (tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal)), + vparamss.map(_.map(vparam => + vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal)))) val templ = templateBodyOpt(makeConstructor(tparams1, vparamss1), parents, Nil) - if (tparams.isEmpty && vparamss.isEmpty) ModuleDef(name, templ) + if (tparams.isEmpty && vparamss1.isEmpty || leadingParamss.nonEmpty) ModuleDef(name, templ) else TypeDef(name.toTypeName, templ) } finalizeDef(instDef, mods1, start) diff --git a/docs/docs/internals/syntax.md b/docs/docs/internals/syntax.md index 0078638044a2..bf0b5bf9892c 100644 --- a/docs/docs/internals/syntax.md +++ b/docs/docs/internals/syntax.md @@ -387,6 +387,7 @@ EnumDef ::= id ClassConstr InheritClauses EnumBody GivenDef ::= [id] [DefTypeParamClause] GivenBody GivenBody ::= [‘as ConstrApp {‘,’ ConstrApp }] {GivenParamClause} [TemplateBody] | ‘as’ Type {GivenParamClause} ‘=’ Expr + | ‘(’ DefParam ‘)’ TemplateBody Template ::= InheritClauses [TemplateBody] Template(constr, parents, self, stats) InheritClauses ::= [‘extends’ ConstrApps] [‘derives’ QualId {‘,’ QualId}] ConstrApps ::= ConstrApp {‘with’ ConstrApp} diff --git a/docs/docs/reference/contextual/extension-methods.md b/docs/docs/reference/contextual/extension-methods.md index 3c64f2416ea9..da7d39ad0c17 100644 --- a/docs/docs/reference/contextual/extension-methods.md +++ b/docs/docs/reference/contextual/extension-methods.md @@ -80,7 +80,7 @@ So `circle.circumference` translates to `CircleOps.circumference(circle)`, provi ### Given Instances for Extension Methods -Given instances that define extension methods can also be defined without a `for` clause. E.g., +Given instances that define extension methods can also be defined without an `as` clause. E.g., ```scala given StringOps { @@ -94,8 +94,33 @@ given { def (xs: List[T]) second[T] = xs.tail.head } ``` -If such given instances are anonymous (as in the second clause), their name is synthesized from the name -of the first defined extension method. +If such given instances are anonymous (as in the second clause), their name is synthesized from the name of the first defined extension method. + +### Given Instances with Collective Parameters + +If a given instance has several extension methods one can pull out the left parameter section +as well as any type parameters of these extension methods into the given instance itself. +For instance, here is a given instance with two extension methods. +```scala +given ListOps { + def (xs: List[T]) second[T]: T = xs.tail.head + def (xs: List[T]) third[T]: T = xs.tail.tail.head +} +``` +The repetition in the parameters can be avoided by moving the parameters into the given instance itself. The following version is a shorthand for the code above. +```scala +given ListOps[T](xs: List[T]) { + def second: T = xs.tail.head + def third: T = xs.tail.tail.head +} +``` +This syntax just adds convenience at the definition site. Applications of such extension methods are exactly the same as if their parameters were repeated in each extension method. +Examples: +```scala +val xs = List(1, 2, 3) +xs.second[Int] +ListOps.third[T](xs) +``` ### Operators @@ -143,4 +168,6 @@ to the [current syntax](../../internals/syntax.md). ``` DefSig ::= ... | ‘(’ DefParam ‘)’ [nl] id [DefTypeParamClause] DefParamClauses +GivenBody ::= ... + | ‘(’ DefParam ‘)’ TemplateBody ``` diff --git a/tests/neg/extension-methods.scala b/tests/neg/extension-methods.scala index a16fa8fee751..095d9148b466 100644 --- a/tests/neg/extension-methods.scala +++ b/tests/neg/extension-methods.scala @@ -10,5 +10,9 @@ object Test { "".l2 // error 1.l1 // error - + given [T](xs: List[T]) { + def (x: Int) f1: T = ??? // error: No extension method allowed here, since collective parameters are given + def f2[T]: T = ??? // error: T is already defined as type T + def f3(xs: List[T]) = ??? // error: xs is already defined as value xs + } } \ No newline at end of file diff --git a/tests/run/extmethods2.scala b/tests/run/extmethods2.scala index d2977a6a25e9..28b1937cabfe 100644 --- a/tests/run/extmethods2.scala +++ b/tests/run/extmethods2.scala @@ -14,4 +14,32 @@ object Test extends App { } test given TC() -} \ No newline at end of file + + object A { + given ListOps[T](xs: List[T]) { + def second: T = xs.tail.head + def third: T = xs.tail.tail.head + def concat(ys: List[T]) = xs ++ ys + def zipp[U](ys: List[U]): List[(T, U)] = xs.zip(ys) + } + given (xs: List[Int]) { + def prod = (1 /: xs)(_ * _) + } + + + } + + object B { + import given A._ + val xs = List(1, 2, 3) + assert(xs.second[Int] == 2) + assert(xs.third == 3) + assert(A.ListOps.second[Int](xs) == 2) + assert(A.ListOps.third(xs) == 3) + assert(xs.prod == 6) + assert(xs.concat(xs).length == 6) + assert(xs.zipp(xs).map(_ + _).prod == 36) + assert(xs.zipp[Int, Int](xs).map(_ + _).prod == 36) + } +} +