Skip to content

Commit fb156af

Browse files
Fix number of tune steps with cores>1 (#3821)
* add regression test for #3819 * check for tune stop before making a draw closes #3819 * add some type annotations
1 parent 906a01b commit fb156af

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

pymc3/parallel_sampling.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ class _Process(multiprocessing.Process):
110110
and send finished samples using shared memory.
111111
"""
112112

113-
def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
113+
def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
114114
super().__init__(daemon=True, name=name)
115115
self._msg_pipe = msg_pipe
116116
self._step_method = step_method
@@ -173,6 +173,10 @@ def _start_loop(self):
173173
raise ValueError("Unexpected msg " + msg[0])
174174

175175
while True:
176+
if draw == self._tune:
177+
self._step_method.stop_tuning()
178+
tuning = False
179+
176180
if draw < self._draws + self._tune:
177181
try:
178182
point, stats = self._compute_point()
@@ -183,10 +187,6 @@ def _start_loop(self):
183187
else:
184188
return
185189

186-
if draw == self._tune:
187-
self._step_method.stop_tuning()
188-
tuning = False
189-
190190
msg = self._recv_msg()
191191
if msg[0] == "abort":
192192
raise KeyboardInterrupt()
@@ -222,7 +222,7 @@ def _collect_warnings(self):
222222
class ProcessAdapter:
223223
"""Control a Chain process from the main thread."""
224224

225-
def __init__(self, draws, tune, step_method, chain, seed, start):
225+
def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
226226
self.chain = chain
227227
process_name = "worker_chain_%s" % chain
228228
self._msg_pipe, remote_conn = multiprocessing.Pipe()
@@ -353,15 +353,15 @@ def terminate_all(processes, patience=2):
353353
class ParallelSampler:
354354
def __init__(
355355
self,
356-
draws,
357-
tune,
358-
chains,
359-
cores,
360-
seeds,
361-
start_points,
356+
draws:int,
357+
tune:int,
358+
chains:int,
359+
cores:int,
360+
seeds:list,
361+
start_points:list,
362362
step_method,
363-
start_chain_num=0,
364-
progressbar=True,
363+
start_chain_num:int=0,
364+
progressbar:bool=True,
365365
):
366366

367367
if any(len(arg) != chains for arg in [seeds, start_points]):

pymc3/tests/test_sampling.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,18 @@ def test_sample_tune_len(self):
127127
trace = pm.sample(draws=100, tune=50, cores=4)
128128
assert len(trace) == 100
129129

130+
@pytest.mark.parametrize('cores', [1, 2])
131+
def test_sampler_stat_tune(self, cores):
132+
with self.model:
133+
tune_stat = pm.sample(
134+
tune=5, draws=7, cores=cores,
135+
discard_tuned_samples=False,
136+
step=pm.Metropolis()
137+
).get_sampler_stats('tune', chains=1)
138+
assert list(tune_stat).count(True) == 5
139+
assert list(tune_stat).count(False) == 7
140+
pass
141+
130142
@pytest.mark.parametrize(
131143
"start, error",
132144
[

0 commit comments

Comments
 (0)