Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class _Process(multiprocessing.Process):
and send finished samples using shared memory.
"""

def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
super().__init__(daemon=True, name=name)
self._msg_pipe = msg_pipe
self._step_method = step_method
Expand Down Expand Up @@ -173,6 +173,10 @@ def _start_loop(self):
raise ValueError("Unexpected msg " + msg[0])

while True:
if draw == self._tune:
self._step_method.stop_tuning()
tuning = False

if draw < self._draws + self._tune:
try:
point, stats = self._compute_point()
Expand All @@ -183,10 +187,6 @@ def _start_loop(self):
else:
return

if draw == self._tune:
self._step_method.stop_tuning()
tuning = False

msg = self._recv_msg()
if msg[0] == "abort":
raise KeyboardInterrupt()
Expand Down Expand Up @@ -222,7 +222,7 @@ def _collect_warnings(self):
class ProcessAdapter:
"""Control a Chain process from the main thread."""

def __init__(self, draws, tune, step_method, chain, seed, start):
def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
self.chain = chain
process_name = "worker_chain_%s" % chain
self._msg_pipe, remote_conn = multiprocessing.Pipe()
Expand Down Expand Up @@ -353,15 +353,15 @@ def terminate_all(processes, patience=2):
class ParallelSampler:
def __init__(
self,
draws,
tune,
chains,
cores,
seeds,
start_points,
draws:int,
tune:int,
chains:int,
cores:int,
seeds:list,
start_points:list,
step_method,
start_chain_num=0,
progressbar=True,
start_chain_num:int=0,
progressbar:bool=True,
):

if any(len(arg) != chains for arg in [seeds, start_points]):
Expand Down
12 changes: 12 additions & 0 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,18 @@ def test_sample_tune_len(self):
trace = pm.sample(draws=100, tune=50, cores=4)
assert len(trace) == 100

@pytest.mark.parametrize('cores', [1, 2])
def test_sampler_stat_tune(self, cores):
with self.model:
tune_stat = pm.sample(
tune=5, draws=7, cores=cores,
discard_tuned_samples=False,
step=pm.Metropolis()
).get_sampler_stats('tune', chains=1)
assert list(tune_stat).count(True) == 5
assert list(tune_stat).count(False) == 7
pass

@pytest.mark.parametrize(
"start, error",
[
Expand Down