@@ -12,7 +12,9 @@ import typer.Typer
1212import typer .ImportInfo .withRootImports
1313import Decorators ._
1414import io .AbstractFile
15- import Phases .unfusedPhases
15+ import Phases .{unfusedPhases , Phase }
16+
17+ import sbt .interfaces .ProgressCallback
1618
1719import util ._
1820import reporting .{Suppression , Action , Profile , ActiveProfile , NoProfile }
@@ -32,6 +34,10 @@ import scala.collection.mutable
3234import scala .util .control .NonFatal
3335import scala .io .Codec
3436
37+ import Run .Progress
38+ import scala .compiletime .uninitialized
39+ import dotty .tools .dotc .transform .MegaPhase
40+
3541/** A compiler run. Exports various methods to compile source files */
3642class Run (comp : Compiler , ictx : Context ) extends ImplicitRunInfo with ConstraintRunInfo {
3743
@@ -155,14 +161,75 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
155161 }
156162
157163 /** The source files of all late entered symbols, as a set */
158- private var lateFiles = mutable.Set [AbstractFile ]()
164+ private val lateFiles = mutable.Set [AbstractFile ]()
159165
160166 /** A cache for static references to packages and classes */
161167 val staticRefs = util.EqHashMap [Name , Denotation ](initialCapacity = 1024 )
162168
163169 /** Actions that need to be performed at the end of the current compilation run */
164170 private var finalizeActions = mutable.ListBuffer [() => Unit ]()
165171
172+ private var _progress : Progress | Null = null // Set if progress reporting is enabled
173+
174+ private inline def trackProgress (using Context )(inline op : Context ?=> Progress => Unit ): Unit =
175+ foldProgress(())(op)
176+
177+ private inline def foldProgress [T ](using Context )(inline default : T )(inline op : Context ?=> Progress => T ): T =
178+ val local = _progress
179+ if local != null then
180+ op(using ctx)(local)
181+ else
182+ default
183+
184+ def didEnterUnit (unit : CompilationUnit )(using Context ): Boolean =
185+ foldProgress(true /* should progress by default */ )(_.tryEnterUnit(unit))
186+
187+ def canProgress ()(using Context ): Boolean =
188+ foldProgress(true /* not cancelled by default */ )(p => ! p.checkCancellation())
189+
190+ def doAdvanceUnit ()(using Context ): Unit =
191+ trackProgress : progress =>
192+ progress.currentUnitCount += 1 // trace that we completed a unit in the current (sub)phase
193+ progress.refreshProgress()
194+
195+ def doAdvanceLate ()(using Context ): Unit =
196+ trackProgress : progress =>
197+ progress.currentLateUnitCount += 1 // trace that we completed a late compilation
198+ progress.refreshProgress()
199+
200+ private def doEnterPhase (currentPhase : Phase )(using Context ): Unit =
201+ trackProgress : progress =>
202+ progress.enterPhase(currentPhase)
203+
204+ /** interrupt the thread and set cancellation state */
205+ private def cancelInterrupted (): Unit =
206+ try
207+ trackProgress(_.cancel())
208+ finally
209+ Thread .currentThread().nn.interrupt()
210+
211+ private def doAdvancePhase (currentPhase : Phase , wasRan : Boolean )(using Context ): Unit =
212+ trackProgress : progress =>
213+ progress.currentUnitCount = 0 // reset unit count in current (sub)phase
214+ progress.currentCompletedSubtraversalCount = 0 // reset subphase index to initial
215+ progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
216+ if wasRan then
217+ // add an extra traversal now that we completed a (sub)phase
218+ progress.completedTraversalCount += 1
219+ else
220+ // no subphases were ran, remove traversals from expected total
221+ progress.totalTraversals -= currentPhase.traversals
222+
223+ private def tryAdvanceSubPhase ()(using Context ): Unit =
224+ trackProgress : progress =>
225+ if progress.canAdvanceSubPhase then
226+ progress.currentUnitCount = 0 // reset unit count in current (sub)phase
227+ progress.seenPhaseCount += 1 // trace that we've seen a (sub)phase
228+ progress.completedTraversalCount += 1 // add an extra traversal now that we completed a (sub)phase
229+ progress.currentCompletedSubtraversalCount += 1 // record that we've seen a subphase
230+ if ! progress.isCancelled() then
231+ progress.tickSubphase()
232+
166233 /** Will be set to true if any of the compiled compilation units contains
167234 * a pureFunctions language import.
168235 */
@@ -233,17 +300,20 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
233300 if ctx.settings.YnoDoubleBindings .value then
234301 ctx.base.checkNoDoubleBindings = true
235302
236- def runPhases (using Context ) = {
303+ def runPhases (allPhases : Array [ Phase ])( using Context ) = {
237304 var lastPrintedTree : PrintedTree = NoPrintedTree
238305 val profiler = ctx.profiler
239306 var phasesWereAdjusted = false
240307
241- for (phase <- ctx.base.allPhases)
242- if (phase.isRunnable)
308+ for phase <- allPhases do
309+ doEnterPhase(phase)
310+ val phaseWillRun = phase.isRunnable
311+ if phaseWillRun then
243312 Stats .trackTime(s " phase time ms/ $phase" ) {
244313 val start = System .currentTimeMillis
245314 val profileBefore = profiler.beforePhase(phase)
246- units = phase.runOn(units)
315+ try units = phase.runOn(units)
316+ catch case _ : InterruptedException => cancelInterrupted()
247317 profiler.afterPhase(phase, profileBefore)
248318 if (ctx.settings.Xprint .value.containsPhase(phase))
249319 for (unit <- units)
@@ -260,18 +330,25 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
260330 if ! Feature .ccEnabledSomewhere then
261331 ctx.base.unlinkPhaseAsDenotTransformer(Phases .checkCapturesPhase.prev)
262332 ctx.base.unlinkPhaseAsDenotTransformer(Phases .checkCapturesPhase)
263-
333+ end if
334+ end if
335+ end if
336+ doAdvancePhase(phase, wasRan = phaseWillRun)
337+ end for
264338 profiler.finished()
265339 }
266340
267341 val runCtx = ctx.fresh
268342 runCtx.setProfiler(Profiler ())
269343 unfusedPhases.foreach(_.initContext(runCtx))
270- runPhases(using runCtx)
344+ val fusedPhases = runCtx.base.allPhases
345+ runCtx.withProgressCallback: cb =>
346+ _progress = Progress (cb, this , fusedPhases.map(_.traversals).sum)
347+ runPhases(allPhases = fusedPhases)(using runCtx)
271348 if (! ctx.reporter.hasErrors)
272349 Rewrites .writeBack()
273350 suppressions.runFinished(hasErrors = ctx.reporter.hasErrors)
274- while (finalizeActions.nonEmpty) {
351+ while (finalizeActions.nonEmpty && canProgress() ) {
275352 val action = finalizeActions.remove(0 )
276353 action()
277354 }
@@ -293,10 +370,9 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
293370 .withRootImports
294371
295372 def process ()(using Context ) =
296- ctx.typer.lateEnterUnit(doTypeCheck =>
297- if typeCheck then
298- if compiling then finalizeActions += doTypeCheck
299- else doTypeCheck()
373+ ctx.typer.lateEnterUnit(typeCheck)(doTypeCheck =>
374+ if compiling then finalizeActions += doTypeCheck
375+ else doTypeCheck()
300376 )
301377
302378 process()(using unitCtx)
@@ -399,7 +475,129 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
399475}
400476
401477object Run {
478+
479+ case class SubPhase (val name : String ):
480+ override def toString : String = name
481+
482+ class SubPhases (val phase : Phase ):
483+ require(phase.exists)
484+
485+ private def baseName : String = phase match
486+ case phase : MegaPhase => phase.shortPhaseName
487+ case phase => phase.phaseName
488+
489+ val all = IArray .from(phase.subPhases.map(sub => s " $baseName[ $sub] " ))
490+
491+ def next (using Context ): Option [SubPhases ] =
492+ val next0 = phase.megaPhase.next.megaPhase
493+ if next0.exists then Some (SubPhases (next0))
494+ else None
495+
496+ def size : Int = all.size
497+
498+ def subPhase (index : Int ) =
499+ if index < all.size then all(index)
500+ else baseName
501+
502+
503+ private class Progress (cb : ProgressCallback , private val run : Run , val initialTraversals : Int ):
504+ export cb .{cancel , isCancelled }
505+
506+ var totalTraversals : Int = initialTraversals // track how many phases we expect to run
507+ var currentUnitCount : Int = 0 // current unit count in the current (sub)phase
508+ var currentLateUnitCount : Int = 0 // current late unit count
509+ var completedTraversalCount : Int = 0 // completed traversals over all files
510+ var currentCompletedSubtraversalCount : Int = 0 // completed subphases in the current phase
511+ var seenPhaseCount : Int = 0 // how many phases we've seen so far
512+
513+ private var currPhase : Phase = uninitialized // initialized by enterPhase
514+ private var subPhases : SubPhases = uninitialized // initialized by enterPhase
515+ private var currPhaseName : String = uninitialized // initialized by enterPhase
516+ private var nextPhaseName : String = uninitialized // initialized by enterPhase
517+
518+ /** Enter into a new real phase, setting the current and next (sub)phases */
519+ def enterPhase (newPhase : Phase )(using Context ): Unit =
520+ if newPhase ne currPhase then
521+ currPhase = newPhase
522+ subPhases = SubPhases (newPhase)
523+ tickSubphase()
524+
525+ def canAdvanceSubPhase : Boolean =
526+ currentCompletedSubtraversalCount + 1 < subPhases.size
527+
528+ /** Compute the current (sub)phase name and next (sub)phase name */
529+ def tickSubphase ()(using Context ): Unit =
530+ val index = currentCompletedSubtraversalCount
531+ val s = subPhases
532+ currPhaseName = s.subPhase(index)
533+ nextPhaseName =
534+ if index + 1 < s.all.size then s.subPhase(index + 1 )
535+ else s.next match
536+ case None => " <end>"
537+ case Some (next0) => next0.subPhase(0 )
538+ if seenPhaseCount > 0 then
539+ refreshProgress()
540+
541+
542+ /** Counts the number of completed full traversals over files, plus the number of units in the current phase */
543+ private def currentProgress (): Int =
544+ completedTraversalCount * work() + currentUnitCount + currentLateUnitCount
545+
546+ /** Total progress is computed as the sum of
547+ * - the number of traversals we expect to make over all files
548+ * - the number of late compilations
549+ */
550+ private def totalProgress (): Int =
551+ totalTraversals * work() + run.lateFiles.size
552+
553+ private def work (): Int = run.files.size
554+
555+ private def requireInitialized (): Unit =
556+ require((currPhase : Phase | Null ) != null , " enterPhase was not called" )
557+
558+ def checkCancellation (): Boolean =
559+ if Thread .interrupted() then cancel()
560+ isCancelled()
561+
562+ /** trace that we are beginning a unit in the current (sub)phase, unless cancelled */
563+ def tryEnterUnit (unit : CompilationUnit ): Boolean =
564+ if checkCancellation() then false
565+ else
566+ requireInitialized()
567+ cb.informUnitStarting(currPhaseName, unit)
568+ true
569+
570+ /** trace the current progress out of the total, in the current (sub)phase, reporting the next (sub)phase */
571+ def refreshProgress ()(using Context ): Unit =
572+ requireInitialized()
573+ val total = totalProgress()
574+ if total > 0 && ! cb.progress(currentProgress(), total, currPhaseName, nextPhaseName) then
575+ cancel()
576+
402577 extension (run : Run | Null )
578+
579+ /** record that the current phase has begun for the compilation unit of the current Context */
580+ def enterUnit (unit : CompilationUnit )(using Context ): Boolean =
581+ if run != null then run.didEnterUnit(unit)
582+ else true // don't check cancellation if we're not tracking progress
583+
584+ /** check progress cancellation, true if not cancelled */
585+ def enterRegion ()(using Context ): Boolean =
586+ if run != null then run.canProgress()
587+ else true // don't check cancellation if we're not tracking progress
588+
589+ /** advance the unit count and record progress in the current phase */
590+ def advanceUnit ()(using Context ): Unit =
591+ if run != null then run.doAdvanceUnit()
592+
593+ /** if there exists another subphase, switch to it and record progress */
594+ def enterNextSubphase ()(using Context ): Unit =
595+ if run != null then run.tryAdvanceSubPhase()
596+
597+ /** advance the late count and record progress in the current phase */
598+ def advanceLate ()(using Context ): Unit =
599+ if run != null then run.doAdvanceLate()
600+
403601 def enrichedErrorMessage : Boolean = if run == null then false else run.myEnrichedErrorMessage
404602 def enrichErrorMessage (errorMessage : String )(using Context ): String =
405603 if run == null then
0 commit comments