@@ -66,21 +66,28 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
6666
6767 recur(refs)
6868 end hidden
69+
70+ /** Deduct the footprint of `sym` and `sym*` from `refs` */
71+ private def deductSym (sym : Symbol )(using Context ) =
72+ val ref = sym.termRef
73+ if ref.isTrackableRef then refs -- CaptureSet (ref, ref.reach).elems.footprint
74+ else refs
75+
76+ /** Deduct the footprint of all captures of `deps` from `refs` */
77+ private def deductCapturesOf (deps : List [Tree ])(using Context ): Refs =
78+ deps.foldLeft(refs): (refs, dep) =>
79+ refs -- captures(dep).footprint
6980 end extension
7081
7182 /** The captures of an argument or prefix widened to the formal parameter, if
7283 * the latter contains a cap.
7384 */
7485 private def formalCaptures (arg : Tree )(using Context ): Refs =
75- val argType = arg.formalType.orElse(arg.nuType)
76- (if argType.hasUseAnnot then argType.deepCaptureSet else argType.captureSet)
77- .elems
86+ arg.formalType.orElse(arg.nuType).deepCaptureSet.elems
7887
7988 /** The captures of a node */
8089 private def captures (tree : Tree )(using Context ): Refs =
81- val tpe = tree.nuType
82- (if tree.formalType.hasUseAnnot then tpe.deepCaptureSet else tpe.captureSet)
83- .elems
90+ tree.nuType.deepCaptureSet.elems
8491
8592 private def sepApplyError (fn : Tree , args : List [Tree ], argIdx : Int ,
8693 overlap : Refs , hiddenInArg : Refs , footprints : List [(Refs , Int )],
@@ -144,7 +151,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
144151
145152 def sepUseError (tree : Tree , used : Refs , globalOverlap : Refs )(using Context ): Unit =
146153 val individualChecks = for mdefs <- previousDefs.iterator; mdef <- mdefs.iterator yield
147- val hiddenByDef = captures(mdef.tpt).hidden
154+ val hiddenByDef = captures(mdef.tpt).hidden.footprint
148155 val overlap = defUseOverlap(hiddenByDef, used, tree.symbol)
149156 if ! overlap.isEmpty then
150157 def resultStr = if mdef.isInstanceOf [DefDef ] then " result" else " "
@@ -172,20 +179,16 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
172179 val footprints = mutable.ListBuffer [(Refs , Int )]((footprint, 0 ))
173180 val indexedArgs = args.zipWithIndex
174181
175- def subtractDeps (elems : Refs , arg : Tree ): Refs =
176- deps(arg).foldLeft(elems): (elems, dep) =>
177- elems -- captures(dep).footprint
178-
179182 for (arg, idx) <- indexedArgs do
180183 if ! arg.needsSepCheck then
181- footprint = footprint ++ subtractDeps( captures(arg).footprint, arg)
184+ footprint = footprint ++ captures(arg).footprint.deductCapturesOf(deps( arg) )
182185 footprints += ((footprint, idx + 1 ))
183186 for (arg, idx) <- indexedArgs do
184187 if arg.needsSepCheck then
185188 val ac = formalCaptures(arg)
186189 val hiddenInArg = ac.hidden.footprint
187190 // println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
188- val overlap = subtractDeps( hiddenInArg.overlapWith(footprint), arg)
191+ val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps( arg) )
189192 if ! overlap.isEmpty then
190193 sepApplyError(fn, args, idx, overlap, hiddenInArg, footprints.toList, deps)
191194 footprint ++= captures(arg).footprint
@@ -267,7 +270,8 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
267270 case tree : ValOrDefDef =>
268271 traverseChildren(tree)
269272 if previousDefs.nonEmpty && ! tree.symbol.isOneOf(TermParamOrAccessor ) then
270- defsShadow ++= captures(tree.tpt).hidden.footprint
273+ capt.println(i " sep check def ${tree.symbol}: ${tree.tpt} with ${captures(tree.tpt).hidden.footprint}" )
274+ defsShadow ++= captures(tree.tpt).hidden.footprint.deductSym(tree.symbol)
271275 resultType(tree.symbol) = tree.tpt.nuType
272276 previousDefs.head += tree
273277 case _ =>
0 commit comments