@@ -334,12 +334,19 @@ object desugar {
334334
335335 // a reference to the class type bound by `cdef`, with type parameters coming from the constructor
336336 val classTypeRef = appliedRef(classTycon)
337- // a refereence to `enumClass`, with type parameters coming from the constructor
337+ // a reference to `enumClass`, with type parameters coming from the constructor
338338 lazy val enumClassTypeRef = appliedRef(enumClassRef)
339339
340340 // new C[Ts](paramss)
341341 lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
342342
343+ // The return type of the `apply` and `copy` methods
344+ val applyResultTpt =
345+ if (isEnumCase)
346+ if (parents.isEmpty) enumClassTypeRef
347+ else parents.head
348+ else TypeTree ()
349+
343350 // Methods to add to a case class C[..](p1: T1, ..., pN: Tn)(moreParams)
344351 // def isDefined = true
345352 // def productArity = N
@@ -380,7 +387,7 @@ object desugar {
380387 cpy.ValDef (vparam)(rhs = copyDefault(vparam)))
381388 val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
382389 cpy.ValDef (vparam)(rhs = EmptyTree ))
383- DefDef (nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree () , creatorExpr)
390+ DefDef (nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, applyResultTpt , creatorExpr)
384391 .withMods(synthetic) :: Nil
385392 }
386393
@@ -430,15 +437,15 @@ object desugar {
430437 constrVparamss.length > 1 ||
431438 mods.is(Abstract ) ||
432439 constr.mods.is(Private )) anyRef
440+ else
433441 // todo: also use anyRef if constructor has a dependent method type (or rule that out)!
434- else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function (vparams map (_.tpt), restpe))
442+ (constrVparamss :\ (if (isEnumCase) applyResultTpt else classTypeRef)) (
443+ (vparams, restpe) => Function (vparams map (_.tpt), restpe))
435444 val applyMeths =
436445 if (mods is Abstract ) Nil
437- else {
438- val restpe = if (isEnumCase) enumClassTypeRef else TypeTree ()
439- DefDef (nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr)
446+ else
447+ DefDef (nme.apply, derivedTparams, derivedVparamss, applyResultTpt, creatorExpr)
440448 .withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized )) :: Nil
441- }
442449 val unapplyMeth = {
443450 val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
444451 val unapplyRHS = if (arity == 0 ) Literal (Constant (true )) else Ident (unapplyParam.name)
@@ -505,7 +512,10 @@ object desugar {
505512 case _ =>
506513 }
507514
508- flatTree(cdef1 :: companions ::: implicitWrappers)
515+ val result = val flatTree(cdef1 :: companions ::: implicitWrappers)
516+ // if (isEnum) println(i"enum $cdef\n --->\n$result")
517+ // if (isEnumCase) println(i"enum case $cdef\n --->\n$result")
518+ result
509519 }
510520
511521 val AccessOrSynthetic = AccessFlags | Synthetic
0 commit comments