@@ -319,11 +319,29 @@ object Nullables:
319319 if ! info.isEmpty then tree.putAttachment(NNInfo , info)
320320 tree
321321
322+ /* Collect the nullability info from parts of `tree` */
323+ def collectNotNullInfo (using Context ): NotNullInfo = tree match
324+ case Typed (expr, _) =>
325+ expr.notNullInfo
326+ case Apply (fn, args) =>
327+ val argsInfo = args.map(_.notNullInfo)
328+ val fnInfo = fn.notNullInfo
329+ argsInfo.foldLeft(fnInfo)(_ seq _)
330+ case TypeApply (fn, _) =>
331+ fn.notNullInfo
332+ case _ =>
333+ // Other cases are handled specially in typer.
334+ NotNullInfo .empty
335+
322336 /* The nullability info of `tree` */
323337 def notNullInfo (using Context ): NotNullInfo =
324- stripInlined(tree).getAttachment(NNInfo ) match
338+ val tree1 = stripInlined(tree)
339+ tree1.getAttachment(NNInfo ) match
325340 case Some (info) if ! ctx.erasedTypes => info
326- case _ => NotNullInfo .empty
341+ case _ =>
342+ val nnInfo = tree1.collectNotNullInfo
343+ tree1.withNotNullInfo(nnInfo)
344+ nnInfo
327345
328346 /* The nullability info of `tree`, assuming it is a condition that evaluates to `c` */
329347 def notNullInfoIf (c : Boolean )(using Context ): NotNullInfo =
@@ -404,21 +422,23 @@ object Nullables:
404422 end extension
405423
406424 extension (tree : Assign )
407- def computeAssignNullable ()(using Context ): tree.type = tree.lhs match
408- case TrackedRef (ref) =>
409- val rhstp = tree.rhs.typeOpt
410- if ctx.explicitNulls && ref.isNullableUnion then
411- if rhstp.isNullType || rhstp.isNullableUnion then
412- // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
413- // lhs variable is no longer trackable. We don't need to check whether the type `T`
414- // is correct here, as typer will check it.
415- tree.withNotNullInfo(NotNullInfo (Set (), Set (ref)))
416- else
417- // If the initial type is nullable and the assigned value is non-null,
418- // we add it to the NotNull.
419- tree.withNotNullInfo(NotNullInfo (Set (ref), Set ()))
420- else tree
421- case _ => tree
425+ def computeAssignNullable ()(using Context ): tree.type =
426+ var nnInfo = tree.rhs.notNullInfo
427+ tree.lhs match
428+ case TrackedRef (ref) if ctx.explicitNulls && ref.isNullableUnion =>
429+ nnInfo = nnInfo.seq:
430+ val rhstp = tree.rhs.typeOpt
431+ if rhstp.isNullType || rhstp.isNullableUnion then
432+ // If the type of rhs is nullable (`T|Null` or `Null`), then the nullability of the
433+ // lhs variable is no longer trackable. We don't need to check whether the type `T`
434+ // is correct here, as typer will check it.
435+ NotNullInfo (Set (), Set (ref))
436+ else
437+ // If the initial type is nullable and the assigned value is non-null,
438+ // we add it to the NotNull.
439+ NotNullInfo (Set (ref), Set ())
440+ case _ =>
441+ tree.withNotNullInfo(nnInfo)
422442 end extension
423443
424444 private val analyzedOps = Set (nme.EQ , nme.NE , nme.eq, nme.ne, nme.ZAND , nme.ZOR , nme.UNARY_! )
0 commit comments