@@ -32,6 +32,7 @@ import java.lang.AssertionError
3232import dotty .tools .dotc .util .Positions .Position
3333import Decorators ._
3434import tpd ._
35+ import Flags ._
3536import StdNames .nme
3637
3738/**
@@ -80,54 +81,68 @@ class LabelDefs extends MiniPhaseTransform {
8081
8182 val queue = new ArrayBuffer [Tree ]()
8283
83-
84-
85- override def transformBlock (tree : tpd.Block )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree = {
86- collectLabelDefs.clear
87- val newStats = collectLabelDefs.transformStats(tree.stats)
88- val newExpr = collectLabelDefs.transform(tree.expr)
89- val labelCalls = collectLabelDefs.labelCalls
90- val entryPoints = collectLabelDefs.parentLabelCalls
91- val labelDefs = collectLabelDefs.labelDefs
92-
93- // make sure that for every label there's a single location it should return and single entry point
94- // if theres already a location that it returns to that's a failure
95- val disallowed = new mutable.HashMap [Symbol , Tree ]()
96- queue.sizeHint(labelCalls.size + entryPoints.size)
97- def moveLabels (entryPoint : Tree ): List [Tree ] = {
98- if ((entryPoint.symbol is Flags .Label ) && labelDefs.contains(entryPoint.symbol)) {
99- val visitedNow = new mutable.HashMap [Symbol , Tree ]()
100- val treesToAppend = new ArrayBuffer [Tree ]() // order matters. parents should go first
101- queue.clear()
102-
103- var visited = 0
104- queue += entryPoint
105- while (visited < queue.size) {
106- val owningLabelDefSym = queue(visited).symbol
107- val owningLabelDef = labelDefs(owningLabelDefSym)
108- for (call <- labelCalls(owningLabelDefSym))
109- if (disallowed.contains(call.symbol)) {
110- val oldCall = disallowed(call.symbol)
111- ctx.error(s " Multiple return locations for Label $oldCall and $call" , call.symbol.pos)
112- } else {
113- if ((! visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) {
114- val df = labelDefs(call.symbol)
115- visitedNow.put(call.symbol, labelDefs(call.symbol))
116- queue += call
84+ override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree = {
85+ if (tree.symbol is Flags .Label ) tree
86+ else {
87+ collectLabelDefs.clear
88+ val newRhs = collectLabelDefs.transform(tree.rhs)
89+ val labelCalls = collectLabelDefs.labelCalls
90+ var entryPoints = collectLabelDefs.parentLabelCalls
91+ var labelDefs = collectLabelDefs.labelDefs
92+
93+ // make sure that for every label there's a single location it should return and single entry point
94+ // if theres already a location that it returns to that's a failure
95+ val disallowed = new mutable.HashMap [Symbol , Tree ]()
96+ queue.sizeHint(labelCalls.size + entryPoints.size)
97+ def moveLabels (entryPoint : Tree ): List [Tree ] = {
98+ if ((entryPoint.symbol is Flags .Label ) && labelDefs.contains(entryPoint.symbol)) {
99+ val visitedNow = new mutable.HashMap [Symbol , Tree ]()
100+ val treesToAppend = new ArrayBuffer [Tree ]() // order matters. parents should go first
101+ queue.clear()
102+
103+ var visited = 0
104+ queue += entryPoint
105+ while (visited < queue.size) {
106+ val owningLabelDefSym = queue(visited).symbol
107+ val owningLabelDef = labelDefs(owningLabelDefSym)
108+ for (call <- labelCalls(owningLabelDefSym))
109+ if (disallowed.contains(call.symbol)) {
110+ val oldCall = disallowed(call.symbol)
111+ ctx.error(s " Multiple return locations for Label $oldCall and $call" , call.symbol.pos)
112+ } else {
113+ if ((! visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) {
114+ visitedNow.put(call.symbol, labelDefs(call.symbol))
115+ queue += call
116+ }
117117 }
118+ if (! treesToAppend.contains(owningLabelDef)) {
119+ treesToAppend += owningLabelDef
118120 }
119- if (! treesToAppend.contains(owningLabelDef))
120- treesToAppend += owningLabelDef
121- visited += 1
121+ visited += 1
122+ }
123+ disallowed ++= visitedNow
124+
125+ treesToAppend.toList
126+ } else Nil
127+ }
128+
129+ val putLabelDefsNearCallees = new TreeMap () {
130+
131+ override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
132+ tree match {
133+ case t : Apply if (entryPoints.contains(t)) =>
134+ entryPoints = entryPoints - t
135+ Block (moveLabels(t), t)
136+ case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super .transform(tree) else tree
137+ }
122138 }
123- disallowed ++= visitedNow
139+ }
124140
125- treesToAppend.toList
126- } else Nil
127- }
128141
129- cpy.Block (tree)(entryPoints.flatMap(moveLabels).toList ++ newStats, newExpr )
142+ val res = cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(newRhs) )
130143
144+ res
145+ }
131146 }
132147
133148 val collectLabelDefs = new TreeMap () {
@@ -137,13 +152,12 @@ class LabelDefs extends MiniPhaseTransform {
137152 var isInsideLabel = false
138153 var isInsideBlock = false
139154
140- def shouldMoveLabel = ! isInsideBlock
155+ def shouldMoveLabel = true
141156
142157 // labelSymbol -> Defining tree
143158 val labelDefs = new mutable.HashMap [Symbol , Tree ]()
144159 // owner -> all calls by this owner
145160 val labelCalls = new mutable.HashMap [Symbol , mutable.Set [Tree ]]()
146- val labelCallCounts = new mutable.HashMap [Symbol , Int ]()
147161
148162 def clear = {
149163 parentLabelCalls.clear()
@@ -175,7 +189,6 @@ class LabelDefs extends MiniPhaseTransform {
175189 } else r
176190 case t : Apply if t.symbol is Flags .Label =>
177191 parentLabelCalls = parentLabelCalls + t
178- labelCallCounts.get(t.symbol)
179192 super .transform(tree)
180193 case _ =>
181194 super .transform(tree)
0 commit comments