@@ -12,7 +12,9 @@ import typer.Typer
1212import typer .ImportInfo .withRootImports
1313import Decorators ._
1414import io .AbstractFile
15- import Phases .unfusedPhases
15+ import Phases .{unfusedPhases , Phase }
16+
17+ import sbt .interfaces .ProgressCallback
1618
1719import util ._
1820import reporting .{Suppression , Action , Profile , ActiveProfile , NoProfile }
@@ -32,6 +34,10 @@ import scala.collection.mutable
3234import scala .util .control .NonFatal
3335import scala .io .Codec
3436
37+ import Run .Progress
38+ import scala .compiletime .uninitialized
39+ import dotty .tools .dotc .transform .MegaPhase
40+
3541/** A compiler run. Exports various methods to compile source files */
3642class Run (comp : Compiler , ictx : Context ) extends ImplicitRunInfo with ConstraintRunInfo {
3743
@@ -155,14 +161,75 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
155161 }
156162
157163 /** The source files of all late entered symbols, as a set */
158- private var lateFiles = mutable.Set [AbstractFile ]()
164+ private val lateFiles = mutable.Set [AbstractFile ]()
159165
160166 /** A cache for static references to packages and classes */
161167 val staticRefs = util.EqHashMap [Name , Denotation ](initialCapacity = 1024 )
162168
163169 /** Actions that need to be performed at the end of the current compilation run */
164170 private var finalizeActions = mutable.ListBuffer [() => Unit ]()
165171
172+ private var _progress : Progress | Null = null // Set if progress reporting is enabled
173+
174+ private inline def trackProgress (using Context )(inline op : Context ?=> Progress => Unit ): Unit =
175+ foldProgress(())(op)
176+
177+ private inline def foldProgress [T ](using Context )(inline default : T )(inline op : Context ?=> Progress => T ): T =
178+ val local = _progress
179+ if local != null then
180+ op(using ctx)(local)
181+ else
182+ default
183+
184+ def didEnterUnit (unit : CompilationUnit )(using Context ): Boolean =
185+ foldProgress(true /* should progress by default */ )(_.tryEnterUnit(unit))
186+
187+ def canProgress ()(using Context ): Boolean =
188+ foldProgress(true /* not cancelled by default */ )(p => ! p.checkCancellation())
189+
190+ def doAdvanceUnit ()(using Context ): Unit =
191+ trackProgress : progress =>
192+ progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase
193+ progress.refreshProgress()
194+
195+ def doAdvanceLate ()(using Context ): Unit =
196+ trackProgress : progress =>
197+ progress.currentLateUnitCount += 1 // trace that we completed a late compilation
198+ progress.refreshProgress()
199+
200+ private def doEnterPhase (currentPhase : Phase )(using Context ): Unit =
201+ trackProgress : progress =>
202+ progress.enterPhase(currentPhase)
203+
204+ /** interrupt the thread and set cancellation state */
205+ private def cancelInterrupted (): Unit =
206+ try
207+ trackProgress(_.cancel())
208+ finally
209+ Thread .currentThread().nn.interrupt()
210+
211+ private def doAdvancePhase (currentPhase : Phase , wasRan : Boolean )(using Context ): Unit =
212+ trackProgress : progress =>
213+ progress.currentUnitCount = 0 // reset unit count in current (sub)phase
214+ progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial
215+ progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
216+ if wasRan then
217+ // add an extra traversal now that we completed a (sub)phase
218+ progress.completedTraversalCount += 1
219+ else
220+ // no subphases were ran, remove traversals from expected total
221+ progress.totalTraversals -= currentPhase.traversals
222+
223+ private def tryAdvanceSubPhase ()(using Context ): Unit =
224+ trackProgress : progress =>
225+ if progress.canAdvanceSubPhase then
226+ progress.currentUnitCount = 0 // reset unit count in current (sub)phase
227+ progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
228+ progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
229+ progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
230+ if ! progress.isCancelled() then
231+ progress.tickSubphase()
232+
166233 /** Will be set to true if any of the compiled compilation units contains
167234 * a pureFunctions language import.
168235 */
@@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
233300 if ctx.settings.YnoDoubleBindings .value then
234301 ctx.base.checkNoDoubleBindings = true
235302
236- def runPhases (using Context ) = {
303+ def runPhases (allPhases : Array [ Phase ])( using Context ) = {
237304 var lastPrintedTree : PrintedTree = NoPrintedTree
238305 val profiler = ctx.profiler
239306 var phasesWereAdjusted = false
240307
241- for (phase <- ctx.base.allPhases)
242- if (phase.isRunnable)
308+ for phase <- allPhases do
309+ doEnterPhase(phase)
310+ val phaseWillRun = phase.isRunnable
311+ if phaseWillRun then
243312 Stats .trackTime(s " phase time ms/ $phase" ) {
244313 val start = System .currentTimeMillis
245314 val profileBefore = profiler.beforePhase(phase)
246- units = phase.runOn(units)
315+ try units = phase.runOn(units)
316+ catch case _ : InterruptedException => cancelInterrupted()
247317 profiler.afterPhase(phase, profileBefore)
248318 if (ctx.settings.Xprint .value.containsPhase(phase))
249319 for (unit <- units)
@@ -261,18 +331,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
261331 if ! Feature .ccEnabledSomewhere then
262332 ctx.base.unlinkPhaseAsDenotTransformer(Phases .checkCapturesPhase.prev)
263333 ctx.base.unlinkPhaseAsDenotTransformer(Phases .checkCapturesPhase)
264-
334+ end if
335+ end if
336+ end if
337+ doAdvancePhase(phase, wasRan = phaseWillRun)
338+ end for
265339 profiler.finished()
266340 }
267341
268342 val runCtx = ctx.fresh
269343 runCtx.setProfiler(Profiler ())
270344 unfusedPhases.foreach(_.initContext(runCtx))
271- runPhases(using runCtx)
345+ val fusedPhases = runCtx.base.allPhases
346+ runCtx.withProgressCallback: cb =>
347+ _progress = Progress (cb, this , fusedPhases.map(_.traversals).sum)
348+ runPhases(allPhases = fusedPhases)(using runCtx)
272349 if (! ctx.reporter.hasErrors)
273350 Rewrites .writeBack()
274351 suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
275- while (finalizeActions.nonEmpty) {
352+ while (finalizeActions.nonEmpty && canProgress() ) {
276353 val action = finalizeActions.remove(0 )
277354 action()
278355 }
@@ -294,10 +371,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
294371 .withRootImports
295372
296373 def process ()(using Context ) =
297- ctx.typer.lateEnterUnit(doTypeCheck =>
298- if typeCheck then
299- if compiling then finalizeActions += doTypeCheck
300- else doTypeCheck()
374+ ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
375+ if compiling then finalizeActions += doTypeCheck
376+ else doTypeCheck()
301377 )
302378
303379 process()(using unitCtx)
@@ -400,7 +476,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
400476}
401477
402478object Run {
479+
480+ case class SubPhase (val name : String ):
481+ override def toString : String = name
482+
483+ class SubPhases (val phase : Phase ):
484+ require(phase.exists)
485+
486+ private def baseName : String = phase match
487+ case phase : MegaPhase => phase.shortPhaseName
488+ case phase => phase.phaseName
489+
490+ val all = IArray .from(phase.subPhases.map(sub => s " $baseName[ $sub] " ))
491+
492+ def next (using Context ): Option [SubPhases ] =
493+ val next0 = phase.megaPhase.next.megaPhase
494+ if next0.exists then Some (SubPhases (next0))
495+ else None
496+
497+ def size : Int = all.size
498+
499+ def subPhase (index : Int ) =
500+ if index < all.size then all(index)
501+ else baseName
502+
503+
504+ private class Progress (cb : ProgressCallback , private val run : Run , val initialTraversals : Int ):
505+ export cb .{cancel , isCancelled }
506+
507+ var totalTraversals : Int = initialTraversals // track how many phases we expect to run
508+ var currentUnitCount : Int = 0 // current unit count in the current (sub)phase
509+ var currentLateUnitCount : Int = 0 // current late unit count
510+ var completedTraversalCount : Int = 0 // completed traversals over all files
511+ var currentCompletedSubtraversalCount : Int = 0 // completed subphases in the current phase
512+ var seenPhaseCount : Int = 0 // how many phases we've seen so far
513+
514+ private var currPhase : Phase = uninitialized // initialized by enterPhase
515+ private var subPhases : SubPhases = uninitialized // initialized by enterPhase
516+ private var currPhaseName : String = uninitialized // initialized by enterPhase
517+ private var nextPhaseName : String = uninitialized // initialized by enterPhase
518+
519+ /** Enter into a new real phase, setting the current and next (sub)phases */
520+ def enterPhase (newPhase : Phase )(using Context ): Unit =
521+ if newPhase ne currPhase then
522+ currPhase = newPhase
523+ subPhases = SubPhases (newPhase)
524+ tickSubphase()
525+
526+ def canAdvanceSubPhase : Boolean =
527+ currentCompletedSubtraversalCount + 1 < subPhases.size
528+
529+ /** Compute the current (sub)phase name and next (sub)phase name */
530+ def tickSubphase ()(using Context ): Unit =
531+ val index = currentCompletedSubtraversalCount
532+ val s = subPhases
533+ currPhaseName = s.subPhase(index)
534+ nextPhaseName =
535+ if index + 1 < s.all.size then s.subPhase(index + 1 )
536+ else s.next match
537+ case None => " <end>"
538+ case Some (next0) => next0.subPhase(0 )
539+ if seenPhaseCount > 0 then
540+ refreshProgress()
541+
542+
543+ /** Counts the number of completed full traversals over files, plus the number of units in the current phase */
544+ private def currentProgress (): Int =
545+ completedTraversalCount * work() + currentUnitCount + currentLateUnitCount
546+
547+ /** Total progress is computed as the sum of
548+ * - the number of traversals we expect to make over all files
549+ * - the number of late compilations
550+ */
551+ private def totalProgress (): Int =
552+ totalTraversals * work() + run.lateFiles.size
553+
554+ private def work (): Int = run.files.size
555+
556+ private def requireInitialized (): Unit =
557+ require((currPhase : Phase | Null ) != null , " enterPhase was not called" )
558+
559+ def checkCancellation (): Boolean =
560+ if Thread .interrupted() then cancel()
561+ isCancelled()
562+
563+ /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
564+ def tryEnterUnit (unit : CompilationUnit ): Boolean =
565+ if checkCancellation() then false
566+ else
567+ requireInitialized()
568+ cb.informUnitStarting(currPhaseName, unit)
569+ true
570+
571+ /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
572+ def refreshProgress ()(using Context ): Unit =
573+ requireInitialized()
574+ val total = totalProgress()
575+ if total > 0 && ! cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
576+ cancel()
577+
403578 extension (run : Run | Null )
579+
580+ /** record that the current phase has begun for the compilation unit of the current Context */
581+ def enterUnit (unit : CompilationUnit )(using Context ): Boolean =
582+ if run != null then run.didEnterUnit(unit)
583+ else true // don't check cancellation if we're not tracking progress
584+
585+ /** check progress cancellation, true if not cancelled */
586+ def enterRegion ()(using Context ): Boolean =
587+ if run != null then run.canProgress()
588+ else true // don't check cancellation if we're not tracking progress
589+
590+ /** advance the unit count and record progress in the current phase */
591+ def advanceUnit ()(using Context ): Unit =
592+ if run != null then run.doAdvanceUnit()
593+
594+ /** if there exists another subphase, switch to it and record progress */
595+ def enterNextSubphase ()(using Context ): Unit =
596+ if run != null then run.tryAdvanceSubPhase()
597+
598+ /** advance the late count and record progress in the current phase */
599+ def advanceLate ()(using Context ): Unit =
600+ if run != null then run.doAdvanceLate()
601+
404602 def enrichedErrorMessage : Boolean = if run == null then false else run.myEnrichedErrorMessage
405603 def enrichErrorMessage (errorMessage : String )(using Context ): String =
406604 if run == null then
0 commit comments