Skip to content

Commit 440ca46

Browse files
committed
Adapt to updates in upstream behavior
1 parent a742fa0 commit 440ca46

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

pymc/backends/ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,8 @@ def completed_draws_and_divergences(self, chain_specific: bool = True) -> tuple[
232232
for sampler_stats in self._stats:
233233
for key, data in sampler_stats.items():
234234
if "divergence" in key:
235-
divergent_draws += np.asarray(data)
236-
divergences = sum(divergent_draws > 0)
235+
divergent_draws += np.asarray(data[: len(self)])
236+
divergences = int(sum(divergent_draws > 0, start=0))
237237
return len(self), divergences
238238

239239

pymc/sampling/mcmc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1155,7 +1155,9 @@ def _sample_many(
11551155
*trace.completed_draws_and_divergences(chain_specific=True)
11561156
)
11571157
progress_manager._progress.update(
1158-
progress_manager.tasks[i],
1158+
progress_manager.tasks[0]
1159+
if progress_manager.combined_progress
1160+
else progress_manager.tasks[i],
11591161
draws=progress_manager.completed_draws
11601162
if progress_manager.combined_progress
11611163
else progress_manager.draws,

pymc/step_methods/hmc/base_hmc.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class BaseHMCState(StepMethodState):
6969
tune: bool
7070
potential: PotentialState
7171
_num_divs_sample: int
72+
divergences: int
7273

7374

7475
class BaseHMC(GradientSharedStep):

0 commit comments

Comments
 (0)