Skip to content

Commit 11f565c

Browse files
committed
Update progressbar managers with existing fit results from ZarrTrace
1 parent 6b08e89 commit 11f565c

File tree

4 files changed

+35
-7
lines changed

4 files changed

+35
-7
lines changed

pymc/progress_bar.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,7 @@ def __init__(
284284
self.update_stats_functions = step_method._make_progressbar_update_functions()
285285

286286
self._show_progress = show_progress
287+
self.draws = 0
287288
self.completed_draws = 0
288289
self.total_draws = draws + tune
289290
self.desc = "Sampling chain"
@@ -299,13 +300,18 @@ def __enter__(self):
299300
def __exit__(self, exc_type, exc_val, exc_tb):
300301
return self._progress.__exit__(exc_type, exc_val, exc_tb)
301302

303+
def set_initial_state(self, draws: int = 0, divergences: int = 0):
304+
self.draws = draws
305+
self.completed_draws += draws
306+
self.divergences += divergences
307+
302308
def _initialize_tasks(self):
303309
if self.combined_progress:
304310
self.tasks = [
305311
self._progress.add_task(
306312
self.desc.format(self),
307-
completed=0,
308-
draws=0,
313+
completed=self.completed_draws,
314+
draws=self.completed_draws,
309315
total=self.total_draws * self.chains - 1,
310316
chain_idx=0,
311317
sampling_speed=0,
@@ -319,14 +325,17 @@ def _initialize_tasks(self):
319325
self.tasks = [
320326
self._progress.add_task(
321327
self.desc.format(self),
322-
completed=0,
323-
draws=0,
328+
completed=self.completed_draws,
329+
draws=self.draws,
324330
total=self.total_draws - 1,
325331
chain_idx=chain_idx,
326332
sampling_speed=0,
327333
speed_unit="draws/s",
328334
failing=False,
329-
**{stat: value[chain_idx] for stat, value in self.progress_stats.items()},
335+
**{
336+
stat: value[0] if stat != "diverging" else self.divergences
337+
for stat, value in self.progress_stats.items()
338+
},
330339
)
331340
for chain_idx in range(self.chains)
332341
]

pymc/sampling/mcmc.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,13 +1150,25 @@ def _sample_many(
11501150

11511151
with progress_manager:
11521152
for i in range(chains):
1153+
trace = traces[i]
1154+
progress_manager.set_initial_state(
1155+
*trace.completed_draws_and_divergences(chain_specific=True)
1156+
)
1157+
progress_manager._progress.update(
1158+
progress_manager.tasks[i],
1159+
draws=progress_manager.completed_draws
1160+
if progress_manager.combined_progress
1161+
else progress_manager.draws,
1162+
divergences=progress_manager.divergences,
1163+
refresh=True,
1164+
)
11531165
step.sampling_state = initial_step_state
11541166
_sample(
11551167
draws=draws,
11561168
chain=i,
11571169
start=start[i],
11581170
step=step,
1159-
trace=traces[i],
1171+
trace=trace,
11601172
rng=rngs[i],
11611173
callback=callback,
11621174
progress_manager=progress_manager,

pymc/sampling/parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,9 @@ def __init__(
511511
progressbar=progressbar,
512512
progressbar_theme=progressbar_theme,
513513
)
514+
if traces is not None:
515+
for trace in traces:
516+
self._progress.set_initial_state(*trace.completed_draws_and_divergences())
514517

515518
def _make_active(self):
516519
while self._inactive and len(self._active) < self._max_active:

pymc/sampling/population.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from rich.progress import BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn
2828

2929
from pymc.backends.base import BaseTrace
30-
from pymc.backends.zarr import ZarrChain
3130
from pymc.initial_point import PointType
3231
from pymc.model import Model, modelcontext
3332
from pymc.progress_bar import CustomProgress
@@ -110,6 +109,10 @@ def _sample_population(
110109

111110
with CustomProgress(disable=not progressbar) as progress:
112111
task = progress.add_task("[red]Sampling...", total=draws)
112+
for trace in traces:
113+
progress.update(
114+
task, completed=trace.completed_draws_and_divergences(chain_specific=True)[0]
115+
)
113116
for _ in sampling:
114117
progress.update(task)
115118

@@ -197,6 +200,7 @@ def __init__(
197200
# enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
198201
# ):
199202
task = self._progress.add_task(description=f"Chain {c}")
203+
self._progress.update(task, completed=first_draw_idx)
200204
secondary_end, primary_end = multiprocessing.Pipe()
201205
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
202206
process = multiprocessing.Process(

0 commit comments

Comments
 (0)