@@ -10,11 +10,28 @@ import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
1010import CaptureSet .{Refs , emptySet , HiddenSet }
1111import config .Printers .capt
1212import StdNames .nme
13- import util .{SimpleIdentitySet , EqHashMap }
13+ import util .{SimpleIdentitySet , EqHashMap , SrcPos }
14+
15+ object SepChecker :
16+
17+ /** Enumerates kinds of captures encountered so far */
18+ enum Captures :
19+ case None
20+ case Explicit // one or more explicitly declared captures
21+ case Hidden // exacttly one hidden captures
22+ case NeedsCheck // one hidden capture and one other capture (hidden or declared)
23+
24+ def add (that : Captures ): Captures =
25+ if this == None then that
26+ else if that == None then this
27+ else if this == Explicit && that == Explicit then Explicit
28+ else NeedsCheck
29+ end Captures
1430
1531class SepChecker (checker : CheckCaptures .CheckerAPI ) extends tpd.TreeTraverser :
1632 import tpd .*
1733 import checker .*
34+ import SepChecker .*
1835
1936 /** The set of capabilities that are hidden by a polymorphic result type
2037 * of some previous definition.
@@ -52,21 +69,17 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
5269
5370 private def hidden (using Context ): Refs =
5471 val seen : util.EqHashSet [CaptureRef ] = new util.EqHashSet
55-
56- def hiddenByElem (elem : CaptureRef ): Refs =
57- if seen.add(elem) then elem match
58- case Fresh .Cap (hcs) => hcs.elems.filter(! _.isRootCapability) ++ recur(hcs.elems)
59- case ReadOnlyCapability (ref) => hiddenByElem(ref).map(_.readOnly)
60- case _ => emptySet
61- else emptySet
62-
6372 def recur (cs : Refs ): Refs =
6473 (emptySet /: cs): (elems, elem) =>
65- elems ++ hiddenByElem(elem)
66-
74+ if seen.add(elem) then elems ++ hiddenByElem(elem, recur )
75+ else elems
6776 recur(refs)
6877 end hidden
6978
79+ private def containsHidden (using Context ): Boolean =
80+ refs.exists: ref =>
81+ ! hiddenByElem(ref, _ => emptySet).isEmpty
82+
7083 /** Deduct the footprint of `sym` and `sym*` from `refs` */
7184 private def deductSym (sym : Symbol )(using Context ) =
7285 val ref = sym.termRef
@@ -79,6 +92,11 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
7992 refs -- captures(dep).footprint
8093 end extension
8194
95+ private def hiddenByElem (ref : CaptureRef , recur : Refs => Refs )(using Context ): Refs = ref match
96+ case Fresh .Cap (hcs) => hcs.elems.filter(! _.isRootCapability) ++ recur(hcs.elems)
97+ case ReadOnlyCapability (ref1) => hiddenByElem(ref1, recur).map(_.readOnly)
98+ case _ => emptySet
99+
82100 /** The captures of an argument or prefix widened to the formal parameter, if
83101 * the latter contains a cap.
84102 */
@@ -186,6 +204,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
186204 for (arg, idx) <- indexedArgs do
187205 if arg.needsSepCheck then
188206 val ac = formalCaptures(arg)
207+ checkType(arg.formalType, arg.srcPos, NoSymbol , " the argument's adapted type" )
189208 val hiddenInArg = ac.hidden.footprint
190209 // println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
191210 val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -212,6 +231,105 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
212231 if ! overlap.isEmpty then
213232 sepUseError(tree, usedFootprint, overlap)
214233
234+ def checkType (tpt : Tree , sym : Symbol )(using Context ): Unit =
235+ checkType(tpt.nuType, tpt.srcPos, sym, " " )
236+
237+ /** Check that all parts of type `tpe` are separated.
238+ * @param tpe the type to check
239+ * @param pos position for error reporting
240+ * @param sym if `tpe` is the (result-) type of a val or def, the symbol of
241+ * this definition, otherwise NoSymbol. If `sym` exists we
242+ * deduct its associated direct and reach capabilities everywhere
243+ * from the capture sets we check.
244+ * @param what a string describing what kind of type it is
245+ */
246+ def checkType (tpe : Type , pos : SrcPos , sym : Symbol , what : String )(using Context ): Unit =
247+
248+ def checkParts (parts : List [Type ]): Unit =
249+ var footprint : Refs = emptySet
250+ var hiddenSet : Refs = emptySet
251+ var checked = 0
252+ for part <- parts do
253+
254+ /** Report an error if `current` and `next` overlap.
255+ * @param current the footprint or hidden set seen so far
256+ * @param next the footprint or hidden set of the next part
257+ * @param mapRefs a function over the capture set elements of the next part
258+ * that returns the references of the same kind as `current`
259+ * (i.e. the part's footprint or hidden set)
260+ * @param prevRel a verbal description of current ("references or "hides")
261+ * @param nextRel a verbal descriiption of next
262+ */
263+ def checkSep (current : Refs , next : Refs , mapRefs : Refs => Refs , prevRel : String , nextRel : String ): Unit =
264+ val globalOverlap = current.overlapWith(next)
265+ if ! globalOverlap.isEmpty then
266+ val (prevStr, prevRefs, overlap) = parts.iterator.take(checked)
267+ .map: prev =>
268+ val prevRefs = mapRefs(prev.deepCaptureSet.elems).footprint.deductSym(sym)
269+ (i " , $prev , " , prevRefs, prevRefs.overlapWith(next))
270+ .dropWhile(_._3.isEmpty)
271+ .nextOption
272+ .getOrElse((" " , current, globalOverlap))
273+ report.error(
274+ em """ Separation failure in $what type $tpe.
275+ |One part, $part , $nextRel ${CaptureSet (next)}.
276+ |A previous part $prevStr $prevRel ${CaptureSet (prevRefs)}.
277+ |The two sets overlap at ${CaptureSet (overlap)}. """ ,
278+ pos)
279+
280+ val partRefs = part.deepCaptureSet.elems
281+ val partFootprint = partRefs.footprint.deductSym(sym)
282+ val partHidden = partRefs.hidden.footprint.deductSym(sym) -- partFootprint
283+
284+ checkSep(footprint, partHidden, identity, " references" , " hides" )
285+ checkSep(hiddenSet, partHidden, _.hidden, " also hides" , " hides" )
286+ checkSep(hiddenSet, partFootprint, _.hidden, " hides" , " references" )
287+
288+ footprint ++= partFootprint
289+ hiddenSet ++= partHidden
290+ checked += 1
291+ end for
292+ end checkParts
293+
294+ object traverse extends TypeAccumulator [Captures ]:
295+
296+ /** A stack of part lists to check. We maintain this since immediately
297+ * checking parts when traversing the type would check innermost to oputermost.
298+ * But we want to check outermost parts first since this prioritized errors
299+ * that are more obvious.
300+ */
301+ var toCheck : List [List [Type ]] = Nil
302+
303+ private val seen = util.HashSet [Symbol ]()
304+
305+ def apply (c : Captures , t : Type ) =
306+ if variance < 0 then c
307+ else
308+ val t1 = t.dealias
309+ t1 match
310+ case t @ AppliedType (tycon, args) =>
311+ val c1 = foldOver(Captures .None , t)
312+ if c1 == Captures .NeedsCheck then
313+ toCheck = (tycon :: args) :: toCheck
314+ c.add(c1)
315+ case t @ CapturingType (parent, cs) =>
316+ val c1 = this (c, parent)
317+ if cs.elems.containsHidden then c1.add(Captures .Hidden )
318+ else if ! cs.elems.isEmpty then c1.add(Captures .Explicit )
319+ else c1
320+ case t : TypeRef if t.symbol.isAbstractOrParamType =>
321+ if seen.contains(t.symbol) then c
322+ else
323+ seen += t.symbol
324+ apply(apply(c, t.prefix), t.info.bounds.hi)
325+ case t =>
326+ foldOver(c, t)
327+
328+ if ! tpe.hasAnnotation(defn.UntrackedCapturesAnnot ) then
329+ traverse(Captures .None , tpe)
330+ traverse.toCheck.foreach(checkParts)
331+ end checkType
332+
215333 private def collectMethodTypes (tp : Type ): List [TermLambda ] = tp match
216334 case tp : MethodType => tp :: collectMethodTypes(tp.resType)
217335 case tp : PolyType => collectMethodTypes(tp.resType)
@@ -231,7 +349,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
231349 (formal, arg) <- mt.paramInfos.zip(args)
232350 dep <- formal.captureSet.elems.toList
233351 do
234- val referred = dep match
352+ val referred = dep.stripReach match
235353 case dep : TermParamRef =>
236354 argMap(dep.binder)(dep.paramNum) :: Nil
237355 case dep : ThisType if dep.cls == fn.symbol.owner =>
@@ -269,11 +387,13 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
269387 defsShadow = saved
270388 case tree : ValOrDefDef =>
271389 traverseChildren(tree)
272- if previousDefs.nonEmpty && ! tree.symbol.isOneOf(TermParamOrAccessor ) then
273- capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
274- defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
275- resultType(tree.symbol) = tree.tpt.nuType
276- previousDefs.head += tree
390+ if ! tree.symbol.isOneOf(TermParamOrAccessor ) then
391+ checkType(tree.tpt, tree.symbol)
392+ if previousDefs.nonEmpty then
393+ capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
394+ defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
395+ resultType(tree.symbol) = tree.tpt.nuType
396+ previousDefs.head += tree
277397 case _ =>
278398 traverseChildren(tree)
279399end SepChecker
0 commit comments