diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 2048e5c15f..3caa4ff543 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -110,7 +110,7 @@ class _Process(multiprocessing.Process): and send finished samples using shared memory. """ - def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed): + def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed): super().__init__(daemon=True, name=name) self._msg_pipe = msg_pipe self._step_method = step_method @@ -173,6 +173,10 @@ def _start_loop(self): raise ValueError("Unexpected msg " + msg[0]) while True: + if draw == self._tune: + self._step_method.stop_tuning() + tuning = False + if draw < self._draws + self._tune: try: point, stats = self._compute_point() @@ -183,10 +187,6 @@ def _start_loop(self): else: return - if draw == self._tune: - self._step_method.stop_tuning() - tuning = False - msg = self._recv_msg() if msg[0] == "abort": raise KeyboardInterrupt() @@ -222,7 +222,7 @@ def _collect_warnings(self): class ProcessAdapter: """Control a Chain process from the main thread.""" - def __init__(self, draws, tune, step_method, chain, seed, start): + def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start): self.chain = chain process_name = "worker_chain_%s" % chain self._msg_pipe, remote_conn = multiprocessing.Pipe() @@ -353,15 +353,15 @@ def terminate_all(processes, patience=2): class ParallelSampler: def __init__( self, - draws, - tune, - chains, - cores, - seeds, - start_points, + draws:int, + tune:int, + chains:int, + cores:int, + seeds:list, + start_points:list, step_method, - start_chain_num=0, - progressbar=True, + start_chain_num:int=0, + progressbar:bool=True, ): if any(len(arg) != chains for arg in [seeds, start_points]): diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index cfa45906d3..17bac486ae 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -127,6 +127,18 @@ def test_sample_tune_len(self): trace = pm.sample(draws=100, tune=50, cores=4) assert len(trace) == 100 + @pytest.mark.parametrize('cores', [1, 2]) + def test_sampler_stat_tune(self, cores): + with self.model: + tune_stat = pm.sample( + tune=5, draws=7, cores=cores, + discard_tuned_samples=False, + step=pm.Metropolis() + ).get_sampler_stats('tune', chains=1) + assert list(tune_stat).count(True) == 5 + assert list(tune_stat).count(False) == 7 + pass + @pytest.mark.parametrize( "start, error", [