@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
1414
1515object desugar {
1616 import untpd ._
17+ import DesugarEnums ._
1718
1819 /** Tags a .withFilter call generated by desugaring a for expression.
1920 * Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
263264 val className = checkNotReservedName(cdef).asTypeName
264265 val impl @ Template (constr0, parents, self, _) = cdef.rhs
265266 val mods = cdef.mods
266- val companionMods = mods.withFlags((mods.flags & AccessFlags ).toCommonFlags)
267+ val companionMods = mods
268+ .withFlags((mods.flags & AccessFlags ).toCommonFlags)
269+ .withMods(mods.mods.filter(! _.isInstanceOf [Mod .EnumCase ]))
267270
268271 val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true ) match {
269272 case meth : DefDef => (meth, Nil )
@@ -288,17 +291,31 @@ object desugar {
288291 }
289292
290293 val isCaseClass = mods.is(Case ) && ! mods.is(Module )
294+ val isEnum = mods.hasMod[Mod .Enum ]
295+ val isEnumCase = isLegalEnumCase(cdef)
291296 val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
292297 // This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
293298
294- val constrTparams = constr1.tparams map toDefParam
299+ lazy val reconstitutedTypeParams = reconstitutedEnumTypeParams(cdef.pos.startPos)
300+
301+ val originalTparams =
302+ if (isEnumCase && parents.isEmpty) {
303+ if (constr1.tparams.nonEmpty) {
304+ if (reconstitutedTypeParams.nonEmpty)
305+ ctx.error(em " case with type parameters needs extends clause " , constr1.tparams.head.pos)
306+ constr1.tparams
307+ }
308+ else reconstitutedTypeParams
309+ }
310+ else constr1.tparams
311+ val originalVparamss = constr1.vparamss
312+ val constrTparams = originalTparams.map(toDefParam)
295313 val constrVparamss =
296- if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
297- if (isCaseClass)
298- ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
314+ if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
315+ if (isCaseClass) ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
299316 ListOfNil
300317 }
301- else constr1.vparamss .nestedMap(toDefParam)
318+ else originalVparamss .nestedMap(toDefParam)
302319 val constr = cpy.DefDef (constr1)(tparams = constrTparams, vparamss = constrVparamss)
303320
304321 // Add constructor type parameters and evidence implicit parameters
@@ -312,21 +329,24 @@ object desugar {
312329 stat
313330 }
314331
315- val derivedTparams = constrTparams map derivedTypeParam
332+ val derivedTparams =
333+ if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
316334 val derivedVparamss = constrVparamss nestedMap derivedTermParam
317335 val arity = constrVparamss.head.length
318336
319- var classTycon : Tree = EmptyTree
337+ val classTycon : Tree = new TypeRefTree // watching is set at end of method
320338
321- // a reference to the class type, with all parameters given.
322- val classTypeRef /* : Tree*/ = {
323- // -language:keepUnions difference: classTypeRef needs type annotation, otherwise
324- // infers Ident | AppliedTypeTree, which
325- // renders the :\ in companions below untypable.
326- classTycon = (new TypeRefTree ) withPos cdef.pos.startPos // watching is set at end of method
327- val tparams = impl.constr.tparams
328- if (tparams.isEmpty) classTycon else AppliedTypeTree (classTycon, tparams map refOfDef)
329- }
339+ def appliedRef (tycon : Tree ) =
340+ (if (constrTparams.isEmpty) tycon
341+ else AppliedTypeTree (tycon, constrTparams map refOfDef))
342+ .withPos(cdef.pos.startPos)
343+
344+ // a reference to the class type bound by `cdef`, with type parameters coming from the constructor
345+ val classTypeRef = appliedRef(classTycon)
346+ // a reference to `enumClass`, with type parameters coming from the constructor
347+ lazy val enumClassTypeRef =
348+ if (reconstitutedTypeParams.isEmpty) enumClassRef
349+ else appliedRef(enumClassRef)
330350
331351 // new C[Ts](paramss)
332352 lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +394,9 @@ object desugar {
374394 DefDef (nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree (), creatorExpr)
375395 .withMods(synthetic) :: Nil
376396 }
377- copyMeths ::: productElemMeths.toList
397+
398+ val enumTagMeths = if (isEnumCase) enumTagMeth(CaseKind .Class )._1 :: Nil else Nil
399+ copyMeths ::: enumTagMeths ::: productElemMeths.toList
378400 }
379401 else Nil
380402
@@ -387,8 +409,12 @@ object desugar {
387409
388410 // Case classes and case objects get a ProductN parent
389411 var parents1 = parents
412+ if (isEnumCase && parents.isEmpty)
413+ parents1 = enumClassTypeRef :: Nil
390414 if (mods.is(Case ) && arity <= Definitions .MaxTupleArity )
391- parents1 = parents1 :+ productConstr(arity)
415+ parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
416+ if (isEnum)
417+ parents1 = parents1 :+ ref(defn.EnumType )
392418
393419 // The thicket which is the desugared version of the companion object
394420 // synthetic object C extends parentTpt { defs }
@@ -410,17 +436,26 @@ object desugar {
410436 // For all other classes, the parent is AnyRef.
411437 val companions =
412438 if (isCaseClass) {
439+ // The return type of the `apply` method
440+ val applyResultTpt =
441+ if (isEnumCase)
442+ if (parents.isEmpty) enumClassTypeRef
443+ else parents.reduceLeft(AndTypeTree )
444+ else TypeTree ()
445+
413446 val parent =
414447 if (constrTparams.nonEmpty ||
415448 constrVparamss.length > 1 ||
416449 mods.is(Abstract ) ||
417450 constr.mods.is(Private )) anyRef
451+ else
418452 // todo: also use anyRef if constructor has a dependent method type (or rule that out)!
419- else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function (vparams map (_.tpt), restpe))
453+ (constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
454+ (vparams, restpe) => Function (vparams map (_.tpt), restpe))
420455 val applyMeths =
421456 if (mods is Abstract ) Nil
422457 else
423- DefDef (nme.apply, derivedTparams, derivedVparamss, TypeTree () , creatorExpr)
458+ DefDef (nme.apply, derivedTparams, derivedVparamss, applyResultTpt , creatorExpr)
424459 .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized )) :: Nil
425460 val unapplyMeth = {
426461 val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
@@ -464,15 +499,15 @@ object desugar {
464499 else cpy.ValDef (self)(tpt = selfType).withMods(self.mods | SelfName )
465500 }
466501
467- val cdef1 = {
468- val originalTparams = constr1.tparams .toIterator
469- val originalVparams = constr1.vparamss .toIterator.flatten
470- val tparamAccessors = derivedTparams.map(_.withMods(originalTparams .next.mods))
502+ val cdef1 = addEnumFlags {
503+ val originalTparamsIt = originalTparams .toIterator
504+ val originalVparamsIt = originalVparamss .toIterator.flatten
505+ val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt .next.mods))
471506 val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
472507 val vparamAccessors = derivedVparamss match {
473508 case first :: rest =>
474- first.map(_.withMods(originalVparams .next.mods | caseAccessor)) ++
475- rest.flatten.map(_.withMods(originalVparams .next.mods))
509+ first.map(_.withMods(originalVparamsIt .next.mods | caseAccessor)) ++
510+ rest.flatten.map(_.withMods(originalVparamsIt .next.mods))
476511 case _ =>
477512 Nil
478513 }
@@ -503,23 +538,26 @@ object desugar {
503538 */
504539 def moduleDef (mdef : ModuleDef )(implicit ctx : Context ): Tree = {
505540 val moduleName = checkNotReservedName(mdef).asTermName
506- val tmpl = mdef.impl
541+ val impl = mdef.impl
507542 val mods = mdef.mods
543+ lazy val isEnumCase = isLegalEnumCase(mdef)
508544 if (mods is Package )
509- PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , tmpl).withMods(mods &~ Package ) :: Nil )
545+ PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , impl).withMods(mods &~ Package ) :: Nil )
546+ else if (isEnumCase)
547+ expandEnumModule(moduleName, impl, mods, mdef.pos)
510548 else {
511549 val clsName = moduleName.moduleClassName
512550 val clsRef = Ident (clsName)
513551 val modul = ValDef (moduleName, clsRef, New (clsRef, Nil ))
514552 .withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags )
515553 .withPos(mdef.pos)
516- val ValDef (selfName, selfTpt, _) = tmpl .self
517- val selfMods = tmpl .self.mods
518- if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), tmpl .self.pos)
519- val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), tmpl .self.rhs)
554+ val ValDef (selfName, selfTpt, _) = impl .self
555+ val selfMods = impl .self.mods
556+ if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), impl .self.pos)
557+ val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), impl .self.rhs)
520558 .withMods(selfMods)
521- .withPos(tmpl .self.pos orElse tmpl .pos.startPos)
522- val clsTmpl = cpy.Template (tmpl )(self = clsSelf, body = tmpl .body)
559+ .withPos(impl .self.pos orElse impl .pos.startPos)
560+ val clsTmpl = cpy.Template (impl )(self = clsSelf, body = impl .body)
523561 val cls = TypeDef (clsName, clsTmpl)
524562 .withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags )
525563 Thicket (modul, classDef(cls).withPos(mdef.pos))
@@ -542,11 +580,23 @@ object desugar {
542580 /** val p1, ..., pN: T = E
543581 * ==>
544582 * makePatDef[[val p1: T1 = E ]]; ...; makePatDef[[val pN: TN = E ]]
583+ *
584+ * case e1, ..., eN
585+ * ==>
586+ * expandSimpleEnumCase([case e1]); ...; expandSimpleEnumCase([case eN])
545587 */
546- def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = {
588+ def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = flatTree {
547589 val PatDef (mods, pats, tpt, rhs) = pdef
548- val pats1 = if (tpt.isEmpty) pats else pats map (Typed (_, tpt))
549- flatTree(pats1 map (makePatDef(pdef, mods, _, rhs)))
590+ if (mods.hasMod[Mod .EnumCase ] && enumCaseIsLegal(pdef))
591+ pats map {
592+ case id : Ident =>
593+ expandSimpleEnumCase(id.name.asTermName, mods,
594+ Position (pdef.pos.start, id.pos.end, id.pos.start))
595+ }
596+ else {
597+ val pats1 = if (tpt.isEmpty) pats else pats map (Typed (_, tpt))
598+ pats1 map (makePatDef(pdef, mods, _, rhs))
599+ }
550600 }
551601
552602 /** If `pat` is a variable pattern,
@@ -923,7 +973,7 @@ object desugar {
923973 case (gen : GenFrom ) :: (rest @ (GenFrom (_, _) :: _)) =>
924974 val cont = makeFor(mapName, flatMapName, rest, body)
925975 Apply (rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
926- case (enum @ GenFrom (pat, rhs)) :: (rest @ GenAlias (_, _) :: _) =>
976+ case (GenFrom (pat, rhs)) :: (rest @ GenAlias (_, _) :: _) =>
927977 val (valeqs, rest1) = rest.span(_.isInstanceOf [GenAlias ])
928978 val pats = valeqs map { case GenAlias (pat, _) => pat }
929979 val rhss = valeqs map { case GenAlias (_, rhs) => rhs }
@@ -1024,7 +1074,6 @@ object desugar {
10241074 List (CaseDef (Ident (nme.DEFAULT_EXCEPTION_NAME ), EmptyTree , Apply (handler, Ident (nme.DEFAULT_EXCEPTION_NAME )))),
10251075 finalizer)
10261076 }
1027-
10281077 }
10291078 }.withPos(tree.pos)
10301079
0 commit comments