@@ -28,6 +28,7 @@ import Decorators._
2828import Uniques ._
2929import ErrorReporting .{err , errorType }
3030import config .Printers .typr
31+ import NameKinds .DefaultGetterName
3132
3233import collection .mutable
3334import SymDenotations .NoCompleter
@@ -741,9 +742,100 @@ trait Checking {
741742 tp.foreachPart(check, stopAtStatic = true )
742743 tp
743744 }
745+
746+ /** Check that all non-synthetic references of the form `<ident>` or
747+ * `this.<ident>` in `tree` that refer to a member of `badOwner` are
748+ * `allowed`.
749+ */
750+ def checkRefsLegal (tree : tpd.Tree , badOwner : Symbol , allowed : (Name , Symbol ) => Boolean , where : String )(implicit ctx : Context ): Unit = {
751+ tree.foreachSubTree { tree =>
752+ tree match {
753+ case Ident (_) | Select (This (_), _) if tree.pos.isSourceDerived =>
754+ val sym = tree.symbol
755+ if (sym.maybeOwner == badOwner && ! allowed(tree.asInstanceOf [RefTree ].name, sym))
756+ ctx.error(i " illegal reference to $sym from $where: $tree // ${tree.toString}" , tree.pos)
757+ case _ =>
758+ }
759+ }
760+ }
761+
762+ /** Check that all case classes that extend `scala.Enum` are `enum` cases */
763+ def checkEnum (cdef : untpd.TypeDef , cls : Symbol )(implicit ctx : Context ): Unit = {
764+ import untpd .modsDeco
765+ def isEnumAnonCls =
766+ cls.isAnonymousClass &&
767+ cls.owner.isTerm &&
768+ (cls.owner.flagsUNSAFE.is(Case ) || cls.owner.name == nme.DOLLAR_NEW )
769+ if (! cdef.mods.hasMod[untpd.Mod .EnumCase ] && ! isEnumAnonCls)
770+ ctx.error(em " normal case $cls in ${cls.owner} cannot extend an enum " , cdef.pos)
771+ }
772+
773+ /** Check that all references coming from enum cases in an enum companion object
774+ * are legal.
775+ * @param cdef the enum companion object class
776+ * @param enumCtx the context immediately enclosing the corresponding enum
777+ */
778+ private def checkEnumCaseRefsLegal (cdef : TypeDef , enumCtx : Context )(implicit ctx : Context ): Unit = {
779+ def check (tree : Tree ) = {
780+ // allow access to `sym` if a typedIdent just outside the enclosing enum
781+ // would have produced the same symbol without errors
782+ def allowAccess (name : Name , sym : Symbol ): Boolean = {
783+ val testCtx = enumCtx.fresh.setNewTyperState()
784+ val ref = ctx.typer.typedIdent(untpd.Ident (name), WildcardType )(testCtx)
785+ ref.symbol == sym && ! testCtx.reporter.hasErrors
786+ }
787+ checkRefsLegal(tree, cdef.symbol, allowAccess, " enum case" )
788+ }
789+ cdef.rhs match {
790+ case impl : Template =>
791+ for (stat <- impl.body)
792+ if (stat.symbol.is(Case ))
793+ stat match {
794+ case TypeDef (_, Template (DefDef (_, tparams, vparamss, _, _), parents, _, _)) =>
795+ tparams.foreach(check)
796+ vparamss.foreach(_.foreach(check))
797+ parents.foreach(check)
798+ case vdef : ValDef =>
799+ vdef.rhs match {
800+ case Block ((clsDef @ TypeDef (_, impl : Template )) :: Nil , _)
801+ if clsDef.symbol.isAnonymousClass =>
802+ impl.parents.foreach(check)
803+ case _ =>
804+ }
805+ case _ =>
806+ }
807+ else if (stat.symbol.is(Module ) && stat.symbol.linkedClass.is(Case ))
808+ stat match {
809+ case TypeDef (_, impl : Template ) =>
810+ for ((defaultGetter @
811+ DefDef (DefaultGetterName (nme.CONSTRUCTOR , _), _, _, _, _)) <- impl.body)
812+ check(defaultGetter.rhs)
813+ case _ =>
814+ }
815+ case _ =>
816+ }
817+ }
818+
819+ /** Check all enum cases in all enum companions in `stats` for legal accesses.
820+ * @param enumContexts a map from`enum` symbols to the contexts enclosing their definitions
821+ */
822+ def checkEnumCompanions (stats : List [Tree ], enumContexts : collection.Map [Symbol , Context ])(implicit ctx : Context ): List [Tree ] = {
823+ for (stat @ TypeDef (_, _) <- stats)
824+ if (stat.symbol.is(Module ))
825+ for (enumContext <- enumContexts.get(stat.symbol.linkedClass))
826+ checkEnumCaseRefsLegal(stat, enumContext)
827+ stats
828+ }
829+ }
830+
831+ trait ReChecking extends Checking {
832+ import tpd ._
833+ override def checkEnum (cdef : untpd.TypeDef , cls : Symbol )(implicit ctx : Context ): Unit = ()
834+ override def checkRefsLegal (tree : tpd.Tree , badOwner : Symbol , allowed : (Name , Symbol ) => Boolean , where : String )(implicit ctx : Context ): Unit = ()
835+ override def checkEnumCompanions (stats : List [Tree ], enumContexts : collection.Map [Symbol , Context ])(implicit ctx : Context ): List [Tree ] = stats
744836}
745837
746- trait NoChecking extends Checking {
838+ trait NoChecking extends ReChecking {
747839 import tpd ._
748840 override def checkNonCyclic (sym : Symbol , info : TypeBounds , reportErrors : Boolean )(implicit ctx : Context ): Type = info
749841 override def checkValue (tree : Tree , proto : Type )(implicit ctx : Context ): tree.type = tree
0 commit comments