@@ -1102,21 +1102,31 @@ trait Checking {
11021102 report.error(ClassCannotExtendEnum (cls, firstParent), cdef.sourcePos)
11031103 }
11041104
1105- /** Check that the firstParent derives from the declaring enum class.
1105+ /** Check that the firstParent derives from the declaring enum class, if not, adds it as a parent after emitting an
1106+ * error.
11061107 */
1107- def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Boolean = {
1108+ def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Unit =
1109+
1110+ extension (sym : Symbol ) def typeRefApplied (using Context ): Type =
1111+ typeRef.appliedTo(typeParams.map(_.info.loBound))
1112+
1113+ def ensureParentDerivesFrom (enumCase : Symbol )(using Context ) =
1114+ val enumCls = enumCase.owner.linkedClass
1115+ if ! firstParent.derivesFrom(enumCls) then
1116+ report.error(i " enum case does not extend its enum $enumCls" , enumCase.sourcePos)
1117+ cls.info match
1118+ case info : ClassInfo =>
1119+ cls.info = info.derivedClassInfo(classParents = enumCls.typeRefApplied :: info.classParents)
1120+ case _ =>
1121+
11081122 val enumCase =
11091123 if cls.isAllOf(EnumCase ) then cls
11101124 else if cls.isAnonymousClass && cls.owner.isAllOf(EnumCase ) then cls.owner
11111125 else NoSymbol
1112- def parentDerivesFrom (enumCls : Symbol )(using Context ) =
1113- if ! firstParent.derivesFrom(enumCls) then
1114- report.error(i " enum case does not extend its enum $enumCls" , enumCase.sourcePos)
1115- false
1116- else
1117- true
1118- ! enumCase.exists || parentDerivesFrom(enumCase.owner.linkedClass)
1119- }
1126+ if enumCase.exists then
1127+ ensureParentDerivesFrom(enumCase)
1128+
1129+ end checkEnumParent
11201130
11211131 /** Check that all references coming from enum cases in an enum companion object
11221132 * are legal.
@@ -1211,7 +1221,7 @@ trait Checking {
12111221
12121222trait ReChecking extends Checking {
12131223 import tpd ._
1214- override def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Boolean = true
1224+ override def checkEnumParent (cls : Symbol , firstParent : Symbol )(using Context ): Unit = ()
12151225 override def checkEnum (cdef : untpd.TypeDef , cls : Symbol , firstParent : Symbol )(using Context ): Unit = ()
12161226 override def checkRefsLegal (tree : tpd.Tree , badOwner : Symbol , allowed : (Name , Symbol ) => Boolean , where : String )(using Context ): Unit = ()
12171227 override def checkFullyAppliedType (tree : Tree )(using Context ): Unit = ()
0 commit comments