From 0210cdff26df6b60e8bdf63d22324e4dea919bc9 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 1 Mar 2020 11:00:48 +0100 Subject: [PATCH 1/3] add regression test for #3819 --- pymc3/tests/test_sampling.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index cfa45906d3..17bac486ae 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -127,6 +127,18 @@ def test_sample_tune_len(self): trace = pm.sample(draws=100, tune=50, cores=4) assert len(trace) == 100 + @pytest.mark.parametrize('cores', [1, 2]) + def test_sampler_stat_tune(self, cores): + with self.model: + tune_stat = pm.sample( + tune=5, draws=7, cores=cores, + discard_tuned_samples=False, + step=pm.Metropolis() + ).get_sampler_stats('tune', chains=1) + assert list(tune_stat).count(True) == 5 + assert list(tune_stat).count(False) == 7 + pass + @pytest.mark.parametrize( "start, error", [ From 0f51ea13b73ec7edab86aee96d7184ec4c5c6c64 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 1 Mar 2020 11:30:40 +0100 Subject: [PATCH 2/3] check for tune stop before making a draw closes #3819 --- pymc3/parallel_sampling.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 2048e5c15f..9427760cfe 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -173,6 +173,10 @@ def _start_loop(self): raise ValueError("Unexpected msg " + msg[0]) while True: + if draw == self._tune: + self._step_method.stop_tuning() + tuning = False + if draw < self._draws + self._tune: try: point, stats = self._compute_point() @@ -183,10 +187,6 @@ def _start_loop(self): else: return - if draw == self._tune: - self._step_method.stop_tuning() - tuning = False - msg = self._recv_msg() if msg[0] == "abort": raise KeyboardInterrupt() From 2f2a7cddd1a427e7cf37a2ecf3c3bf894940b53d Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sun, 1 Mar 2020 11:31:20 +0100 Subject: [PATCH 3/3] add some type annotations --- pymc3/parallel_sampling.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 9427760cfe..3caa4ff543 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -110,7 +110,7 @@ class _Process(multiprocessing.Process): and send finished samples using shared memory. """ - def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed): + def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed): super().__init__(daemon=True, name=name) self._msg_pipe = msg_pipe self._step_method = step_method @@ -222,7 +222,7 @@ def _collect_warnings(self): class ProcessAdapter: """Control a Chain process from the main thread.""" - def __init__(self, draws, tune, step_method, chain, seed, start): + def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start): self.chain = chain process_name = "worker_chain_%s" % chain self._msg_pipe, remote_conn = multiprocessing.Pipe() @@ -353,15 +353,15 @@ def terminate_all(processes, patience=2): class ParallelSampler: def __init__( self, - draws, - tune, - chains, - cores, - seeds, - start_points, + draws:int, + tune:int, + chains:int, + cores:int, + seeds:list, + start_points:list, step_method, - start_chain_num=0, - progressbar=True, + start_chain_num:int=0, + progressbar:bool=True, ): if any(len(arg) != chains for arg in [seeds, start_points]):