From a3df22505abd9beeb054e02d7c99a61ef2651221 Mon Sep 17 00:00:00 2001
From: Michael Osthege
Date: Sat, 25 Feb 2023 15:02:34 +0100
Subject: [PATCH 1/4] Extract return part of `pm.sample`

---
 pymc/sampling/mcmc.py | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py
index cba911aed9..243b03eb5e 100644
--- a/pymc/sampling/mcmc.py
+++ b/pymc/sampling/mcmc.py
@@ -333,7 +333,7 @@ def sample(
     compute_convergence_checks: bool = True,
     keep_warning_stat: bool = False,
     return_inferencedata: bool = True,
-    idata_kwargs: dict = None,
+    idata_kwargs: Optional[Dict[str, Any]] = None,
     callback=None,
     mp_ctx=None,
     model: Optional[Model] = None,
@@ -687,7 +687,36 @@ def sample(
 
     t_sampling = time.time() - t_start
 
-    # Wrap chain traces in a MultiTrace
+    # Packaging, validating and returning the result was extracted
+    # into a function to make it easier to test and refactor.
+    return _sample_return(
+        traces=traces,
+        tune=tune,
+        t_sampling=t_sampling,
+        discard_tuned_samples=discard_tuned_samples,
+        compute_convergence_checks=compute_convergence_checks,
+        return_inferencedata=return_inferencedata,
+        keep_warning_stat=keep_warning_stat,
+        idata_kwargs=idata_kwargs or {},
+        model=model,
+    )
+
+
+def _sample_return(
+    *,
+    traces: Sequence[IBaseTrace],
+    tune: int,
+    t_sampling: float,
+    discard_tuned_samples: bool,
+    compute_convergence_checks: bool,
+    return_inferencedata: bool,
+    keep_warning_stat: bool,
+    idata_kwargs: Dict[str, Any],
+    model: Model,
+) -> Union[InferenceData, MultiTrace]:
+    """Final step of `pm.sample` that picks/slices chains,
+    runs diagnostics and converts to the desired return type."""
+    # Pick and slice chains to keep the maximum number of samples
     if discard_tuned_samples:
         traces, length = _choose_chains(traces, tune)
     else:
@@ -725,8 +754,7 @@ def sample(
     idata = None
     if compute_convergence_checks or return_inferencedata:
         ikwargs: Dict[str, Any] = dict(model=model, save_warmup=not discard_tuned_samples)
-        if idata_kwargs:
-            ikwargs.update(idata_kwargs)
+        ikwargs.update(idata_kwargs)
         idata = pm.to_inference_data(mtrace, **ikwargs)
 
     if compute_convergence_checks:

From 62cb70995274bca974728064e00ec08db350e54d Mon Sep 17 00:00:00 2001
From: Michael Osthege
Date: Sat, 25 Feb 2023 17:40:19 +0100
Subject: [PATCH 2/4] Speed up `test_mcmc`

---
 tests/backends/test_ndarray.py |  6 ++++++
 tests/sampling/test_mcmc.py    | 36 ++++------------------------------
 2 files changed, 10 insertions(+), 32 deletions(-)

diff --git a/tests/backends/test_ndarray.py b/tests/backends/test_ndarray.py
index 4b6ef2f0b2..f1404d91cb 100644
--- a/tests/backends/test_ndarray.py
+++ b/tests/backends/test_ndarray.py
@@ -123,6 +123,12 @@ def test_multitrace_nonunique(self):
         with pytest.raises(ValueError):
             base.MultiTrace([self.strace0, self.strace1])
 
+    def test_multitrace_iter_notimplemented(self):
+        mtrace = base.MultiTrace([self.strace0])
+        with pytest.raises(NotImplementedError):
+            for _ in mtrace:
+                pass
+
 
 class TestSqueezeCat:
     def setup_method(self):
diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py
index 950694e646..e2cefc70ef 100644
--- a/tests/sampling/test_mcmc.py
+++ b/tests/sampling/test_mcmc.py
@@ -151,24 +151,6 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
         assert np.all(idata12["x"] != idata22["x"])
         assert np.all(idata13["x"] != idata23["x"])
 
-    def test_sample(self):
-        test_cores = [1]
-        with self.model:
-            for cores in test_cores:
-                for steps in [1, 10, 300]:
-                    with 
warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - warnings.filterwarnings( - "ignore", "More chains .* than draws .*", UserWarning - ) - pm.sample( - steps, - tune=0, - step=self.step, - cores=cores, - random_seed=self.random_seed, - ) - def test_sample_init(self): with self.model: for init in ( @@ -199,11 +181,11 @@ def test_sample_init(self): def test_sample_args(self): with self.model: with pytest.raises(ValueError) as excinfo: - pm.sample(50, tune=0, foo=1) + pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo=1) assert "'foo'" in str(excinfo.value) with pytest.raises(ValueError) as excinfo: - pm.sample(50, tune=0, foo={}) + pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={}) assert "foo" in str(excinfo.value) def test_parallel_start(self): @@ -232,6 +214,7 @@ def test_sample_tune_len(self): draws=100, tune=50, cores=1, + step=pm.Metropolis(), return_inferencedata=False, discard_tuned_samples=False, ) @@ -387,18 +370,7 @@ def test_sequential_backend(self): backend = NDArray() with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - pm.sample(10, cores=1, chains=2, trace=backend) - - def test_exceptions(self): - # Test iteration over MultiTrace NotImplementedError - with pm.Model() as model: - mu = pm.Normal("mu", 0.0, 1.0) - a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2])) - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - trace = pm.sample(tune=0, draws=10, chains=2, return_inferencedata=False) - with pytest.raises(NotImplementedError): - xvars = [t["mu"] for t in trace] + pm.sample(10, tune=5, cores=1, chains=2, step=pm.Metropolis(), trace=backend) @pytest.mark.parametrize("symbolic_rv", (False, True)) def test_deterministic_of_unobserved(self, symbolic_rv): From a386095fb97edd3c2f5193074c2a016ecfb3b1d2 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 25 Feb 2023 18:14:12 +0100 Subject: [PATCH 3/4] Group tests related to `pm.sample` return parameters --- tests/sampling/test_mcmc.py | 353 +++++++++++++++++++----------------- 1 file changed, 182 insertions(+), 171 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index e2cefc70ef..7e14af7956 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -203,25 +203,6 @@ def test_parallel_start(self): assert idata.warmup_posterior["x"].sel(chain=0, draw=0).values[0] > 0 assert idata.warmup_posterior["x"].sel(chain=1, draw=0).values[0] < 0 - def test_sample_tune_len(self): - with self.model: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) - trace = pm.sample(draws=100, tune=50, cores=1, return_inferencedata=False) - assert len(trace) == 100 - trace = pm.sample( - draws=100, - tune=50, - cores=1, - step=pm.Metropolis(), - return_inferencedata=False, - discard_tuned_samples=False, - ) - assert len(trace) == 150 - trace = pm.sample(draws=100, tune=50, cores=4, return_inferencedata=False) - assert len(trace) == 100 - def test_reset_tuning(self): with self.model: tune = 50 @@ -233,80 +214,6 @@ def test_reset_tuning(self): assert step.potential._n_samples == tune assert step.step_adapt._count == tune + 1 - @pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice]) - @pytest.mark.parametrize("discard", [True, False]) - def 
test_trace_report(self, step_cls, discard): - with self.model: - # add more variables, because stats are 2D with CompoundStep! - pm.Uniform("uni") - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", ".*Tuning samples will be included.*", UserWarning - ) - trace = pm.sample( - draws=100, - tune=50, - cores=1, - discard_tuned_samples=discard, - step=step_cls(), - compute_convergence_checks=False, - return_inferencedata=False, - ) - assert trace.report.n_tune == 50 - assert trace.report.n_draws == 100 - assert isinstance(trace.report.t_sampling, float) - - def test_return_inferencedata(self): - with self.model: - kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis()) - - # trace with tuning - with pytest.warns(UserWarning, match="will be included"): - result = pm.sample( - **kwargs, return_inferencedata=False, discard_tuned_samples=False - ) - assert isinstance(result, pm.backends.base.MultiTrace) - assert len(result) == 150 - - # inferencedata with tuning - result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False) - assert isinstance(result, InferenceData) - assert result.posterior.sizes["draw"] == 100 - assert result.posterior.sizes["chain"] == 2 - assert len(result._groups_warmup) > 0 - - # inferencedata without tuning, with idata_kwargs - prior = pm.sample_prior_predictive(return_inferencedata=False) - result = pm.sample( - **kwargs, - return_inferencedata=True, - discard_tuned_samples=True, - idata_kwargs={"prior": prior}, - random_seed=-1, - ) - assert "prior" in result - assert isinstance(result, InferenceData) - assert result.posterior.sizes["draw"] == 100 - assert result.posterior.sizes["chain"] == 2 - assert len(result._groups_warmup) == 0 - - @pytest.mark.parametrize("cores", [1, 2]) - def test_sampler_stat_tune(self, cores): - with self.model: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) - tune_stat = pm.sample( - tune=5, - draws=7, - cores=cores, - discard_tuned_samples=False, - return_inferencedata=False, - step=pm.Metropolis(), - ).get_sampler_stats("tune", chains=1) - assert list(tune_stat).count(True) == 5 - assert list(tune_stat).count(False) == 7 - @pytest.mark.parametrize( "start, error", [ @@ -411,6 +318,188 @@ def test_transform_with_rv_dependency(self, symbolic_rv): assert np.allclose(scipy.special.expit(trace["y_interval__"]), trace["y"]) +class ApocalypticMetropolis(pm.Metropolis): + """A stepper that warns in every iteration.""" + + stats_dtypes_shapes = { + **pm.Metropolis.stats_dtypes_shapes, + "warning": (SamplerWarning, None), + } + + def astep(self, q0): + draw, stats = super().astep(q0) + stats[0]["warning"] = SamplerWarning( + WarningType.BAD_ENERGY, + "Asteroid incoming!", + "warn", + ) + return draw, stats + + +class TestSampleReturn: + """Tests related to kwargs that parametrize how `pm.sample` results are returned.""" + + def test_sample_tune_len(self): + with pm.Model(): + pm.Normal("n") + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) + trace = pm.sample( + draws=100, tune=50, cores=1, step=pm.Metropolis(), return_inferencedata=False + ) + assert len(trace) == 100 + trace = pm.sample( + draws=100, + tune=50, + cores=1, + step=pm.Metropolis(), + return_inferencedata=False, + 
discard_tuned_samples=False,
+                )
+                assert len(trace) == 150
+                trace = pm.sample(
+                    draws=100,
+                    tune=50,
+                    cores=4,
+                    step=pm.Metropolis(),
+                    return_inferencedata=False,
+                )
+                assert len(trace) == 100
+
+    @pytest.mark.parametrize("discard", [True, False])
+    def test_trace_report(self, discard):
+        with pm.Model():
+            pm.Uniform("uni")
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore", ".*Tuning samples will be included.*", UserWarning
+                )
+                trace = pm.sample(
+                    draws=100,
+                    tune=50,
+                    cores=1,
+                    discard_tuned_samples=discard,
+                    step=pm.Metropolis(),
+                    compute_convergence_checks=False,
+                    return_inferencedata=False,
+                )
+            assert trace.report.n_tune == 50
+            assert trace.report.n_draws == 100
+            assert isinstance(trace.report.t_sampling, float)
+
+    def test_return_inferencedata(self):
+        model, _, step, _ = simple_init()
+        with model:
+            kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=step)
+
+            # trace with tuning
+            with pytest.warns(UserWarning, match="will be included"):
+                result = pm.sample(
+                    **kwargs, return_inferencedata=False, discard_tuned_samples=False
+                )
+            assert isinstance(result, pm.backends.base.MultiTrace)
+            assert len(result) == 150
+
+            # inferencedata with tuning
+            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
+            assert isinstance(result, InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) > 0
+
+            # inferencedata without tuning, with idata_kwargs
+            prior = pm.sample_prior_predictive(return_inferencedata=False)
+            result = pm.sample(
+                **kwargs,
+                return_inferencedata=True,
+                discard_tuned_samples=True,
+                idata_kwargs={"prior": prior},
+                random_seed=-1,
+            )
+            assert "prior" in result
+            assert isinstance(result, InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) == 0
+
+    @pytest.mark.parametrize("cores", [1, 2])
+    def test_sampler_stat_tune(self, cores):
+        with pm.Model():
+            pm.Normal("n")
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
+                tune_stat = pm.sample(
+                    tune=5,
+                    draws=7,
+                    cores=cores,
+                    discard_tuned_samples=False,
+                    return_inferencedata=False,
+                    step=pm.Metropolis(),
+                ).get_sampler_stats("tune", chains=1)
+            assert list(tune_stat).count(True) == 5
+            assert list(tune_stat).count(False) == 7
+
+    @pytest.mark.parametrize("cores", [1, 2])
+    def test_logs_sampler_warnings(self, caplog, cores):
+        """Asserts that "warning" sampler stats are logged during sampling."""
+        with pm.Model():
+            pm.Normal("n")
+            with caplog.at_level(logging.WARNING):
+                idata = pm.sample(
+                    tune=2,
+                    draws=3,
+                    cores=cores,
+                    chains=cores,
+                    step=ApocalypticMetropolis(),
+                    compute_convergence_checks=False,
+                    discard_tuned_samples=False,
+                    keep_warning_stat=True,
+                )
+
+        # Sampler warnings should be logged
+        nwarns = sum("Asteroid" in rec.message for rec in caplog.records)
+        assert nwarns == (2 + 3) * cores
+
+    @pytest.mark.parametrize("keep_warning_stat", [None, True])
+    def test_keep_warning_stat_setting(self, keep_warning_stat):
+        """The ``keep_warning_stat`` kwarg (aka "Adrian's kwarg") enables users
+        to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
+        This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
+        """
+        sample_kwargs = dict(
+            tune=2,
+            draws=3,
+            chains=1,
+            compute_convergence_checks=False,
+            discard_tuned_samples=False,
+            keep_warning_stat=keep_warning_stat,
+        )
+        if keep_warning_stat:
+            sample_kwargs["keep_warning_stat"] = True
+        with pm.Model():
+            pm.Normal("n")
+            idata = pm.sample(step=ApocalypticMetropolis(), **sample_kwargs)
+
+        if keep_warning_stat:
+            assert "warning" in idata.warmup_sample_stats
+            assert "warning" in idata.sample_stats
+            # And end up in the InferenceData
+            assert "warning" in idata.sample_stats
+            # NOTE: The stats are squeezed by default but this does not always work.
+            # This test flattens so we don't have to be exact in accessing (non-)squeezed items.
+            # Also see https://github.com/pymc-devs/pymc/issues/6207.
+            warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten())
+            assert any(isinstance(w, SamplerWarning) for w in warn_objs)
+            assert any("Asteroid" in w.message for w in warn_objs)
+        else:
+            assert "warning" not in idata.warmup_sample_stats
+            assert "warning" not in idata.sample_stats
+            assert "warning_dim_0" not in idata.warmup_sample_stats
+            assert "warning_dim_0" not in idata.sample_stats
+
+
 def test_sample_find_MAP_does_not_modify_start():
     # see https://github.com/pymc-devs/pymc/pull/4458
     with pm.Model():
@@ -603,84 +692,6 @@ def test_step_args():
     npt.assert_allclose(idata1.sample_stats.scaling, 0)
 
 
-class ApocalypticMetropolis(pm.Metropolis):
-    """A stepper that warns in every iteration."""
-
-    stats_dtypes_shapes = {
-        **pm.Metropolis.stats_dtypes_shapes,
-        "warning": (SamplerWarning, None),
-    }
-
-    def astep(self, q0):
-        draw, stats = super().astep(q0)
-        stats[0]["warning"] = SamplerWarning(
-            WarningType.BAD_ENERGY,
-            "Asteroid incoming!",
-            "warn",
-        )
-        return draw, stats
-
-
-@pytest.mark.parametrize("cores", [1, 2])
-def test_logs_sampler_warnings(caplog, cores):
-    """Asserts that "warning" sampler stats are logged during sampling."""
-    with pm.Model():
-        pm.Normal("n")
-        with caplog.at_level(logging.WARNING):
-            idata = pm.sample(
-                tune=2,
-                draws=3,
-                cores=cores,
-                chains=cores,
-                step=ApocalypticMetropolis(),
-                compute_convergence_checks=False,
-                discard_tuned_samples=False,
-                keep_warning_stat=True,
-            )
-
-    # Sampler warnings should be logged
-    nwarns = sum("Asteroid" in rec.message for rec in caplog.records)
-    assert nwarns == (2 + 3) * cores
-
-
-@pytest.mark.parametrize("keep_warning_stat", [None, True])
-def test_keep_warning_stat_setting(keep_warning_stat):
-    """The ``keep_warning_stat`` stat (aka "Adrian's kwarg) enables users
-    to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
-    This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
-    """
-    sample_kwargs = dict(
-        tune=2,
-        draws=3,
-        chains=1,
-        compute_convergence_checks=False,
-        discard_tuned_samples=False,
-        keep_warning_stat=keep_warning_stat,
-    )
-    if keep_warning_stat:
-        sample_kwargs["keep_warning_stat"] = True
-    with pm.Model():
-        pm.Normal("n")
-        idata = pm.sample(step=ApocalypticMetropolis(), **sample_kwargs)
-
-    if keep_warning_stat:
-        assert "warning" in idata.warmup_sample_stats
-        assert "warning" in idata.sample_stats
-        # And end up in the InferenceData
-        assert "warning" in idata.sample_stats
-        # NOTE: The stats are squeezed by default but this does not always work.
-        # This tests flattens so we don't have to be exact in accessing (non-)squeezed items.
-        # Also see https://github.com/pymc-devs/pymc/issues/6207.
- warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten()) - assert any(isinstance(w, SamplerWarning) for w in warn_objs) - assert any("Asteroid" in w.message for w in warn_objs) - else: - assert "warning" not in idata.warmup_sample_stats - assert "warning" not in idata.sample_stats - assert "warning_dim_0" not in idata.warmup_sample_stats - assert "warning_dim_0" not in idata.sample_stats - - def test_init_nuts(caplog): with pm.Model() as model: a = pm.Normal("a") From 6a2ed47cb4eb794cc396ce4e8bec2a0e70b84f39 Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 25 Feb 2023 20:20:38 +0100 Subject: [PATCH 4/4] Consolidate tests of return options Removes a regression test added in #3821 because it took 14 seconds. --- tests/sampling/test_mcmc.py | 186 +++++++++++++++--------------------- 1 file changed, 79 insertions(+), 107 deletions(-) diff --git a/tests/sampling/test_mcmc.py b/tests/sampling/test_mcmc.py index 7e14af7956..c87eb6d456 100644 --- a/tests/sampling/test_mcmc.py +++ b/tests/sampling/test_mcmc.py @@ -233,22 +233,20 @@ def test_sample_start_good_shape(self, start): def test_sample_callback(self): callback = mock.Mock() test_cores = [1, 2] - test_chains = [1, 2] with self.model: for cores in test_cores: - for chain in test_chains: - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - pm.sample( - 10, - tune=0, - chains=chain, - step=self.step, - cores=cores, - random_seed=self.random_seed, - callback=callback, - ) - assert callback.called + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) + pm.sample( + 10, + tune=0, + chains=2, + step=self.step, + cores=cores, + random_seed=self.random_seed, + callback=callback, + ) + assert callback.called def test_callback_can_cancel(self): trace_cancel_length = 5 @@ -339,107 +337,81 @@ def astep(self, q0): class TestSampleReturn: """Tests related to kwargs that parametrize how `pm.sample` results are returned.""" - def test_sample_tune_len(self): - with pm.Model(): + def test_sample_return_lengths(self): + with pm.Model() as model: pm.Normal("n") - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) - trace = pm.sample( - draws=100, tune=50, cores=1, step=pm.Metropolis(), return_inferencedata=False - ) - assert len(trace) == 100 - trace = pm.sample( - draws=100, - tune=50, - cores=1, - step=pm.Metropolis(), - return_inferencedata=False, - discard_tuned_samples=False, - ) - assert len(trace) == 150 - trace = pm.sample( - draws=100, - tune=50, - cores=4, - step=pm.Metropolis(), - return_inferencedata=False, - ) - assert len(trace) == 100 - @pytest.mark.parametrize("discard", [True, False]) - def test_trace_report(self, discard): - with pm.Model(): - pm.Uniform("uni") - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", ".*Tuning samples will be included.*", UserWarning - ) - trace = pm.sample( + # Get a MultiTrace with warmup + with pytest.warns(UserWarning, match="will be included"): + mtrace = pm.sample( draws=100, tune=50, cores=1, - discard_tuned_samples=discard, + chains=3, step=pm.Metropolis(), - compute_convergence_checks=False, return_inferencedata=False, - ) - assert trace.report.n_tune == 50 - assert trace.report.n_draws == 100 - assert isinstance(trace.report.t_sampling, float) - - def test_return_inferencedata(self): - model, _, step, _ = 
simple_init() - with model: - kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=step) - - # trace with tuning - with pytest.warns(UserWarning, match="will be included"): - result = pm.sample( - **kwargs, return_inferencedata=False, discard_tuned_samples=False - ) - assert isinstance(result, pm.backends.base.MultiTrace) - assert len(result) == 150 - - # inferencedata with tuning - result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False) - assert isinstance(result, InferenceData) - assert result.posterior.sizes["draw"] == 100 - assert result.posterior.sizes["chain"] == 2 - assert len(result._groups_warmup) > 0 - - # inferencedata without tuning, with idata_kwargs - prior = pm.sample_prior_predictive(return_inferencedata=False) - result = pm.sample( - **kwargs, - return_inferencedata=True, - discard_tuned_samples=True, - idata_kwargs={"prior": prior}, - random_seed=-1, - ) - assert "prior" in result - assert isinstance(result, InferenceData) - assert result.posterior.sizes["draw"] == 100 - assert result.posterior.sizes["chain"] == 2 - assert len(result._groups_warmup) == 0 - - @pytest.mark.parametrize("cores", [1, 2]) - def test_sampler_stat_tune(self, cores): - with pm.Model(): - pm.Normal("n") - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) - warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning) - tune_stat = pm.sample( - tune=5, - draws=7, - cores=cores, discard_tuned_samples=False, - return_inferencedata=False, - step=pm.Metropolis(), - ).get_sampler_stats("tune", chains=1) - assert list(tune_stat).count(True) == 5 - assert list(tune_stat).count(False) == 7 + ) + assert isinstance(mtrace, pm.backends.base.MultiTrace) + assert len(mtrace) == 150 + + # Now instead of running more MCMCs, we'll test the other return + # options using the basetraces inside the MultiTrace. 
+ traces = list(mtrace._straces.values()) + assert len(traces) == 3 + + # MultiTrace without warmup + mtrace_pst = pm.sampling.mcmc._sample_return( + traces=traces, + tune=50, + t_sampling=123.4, + discard_tuned_samples=True, + return_inferencedata=False, + compute_convergence_checks=False, + keep_warning_stat=True, + idata_kwargs={}, + model=model, + ) + assert isinstance(mtrace_pst, pm.backends.base.MultiTrace) + assert len(mtrace_pst) == 100 + assert mtrace_pst.report.t_sampling == 123.4 + assert mtrace_pst.report.n_tune == 50 + assert mtrace_pst.report.n_draws == 100 + + # InferenceData with warmup + idata_w = pm.sampling.mcmc._sample_return( + traces=traces, + tune=50, + t_sampling=123.4, + discard_tuned_samples=False, + compute_convergence_checks=False, + return_inferencedata=True, + keep_warning_stat=True, + idata_kwargs={}, + model=model, + ) + assert isinstance(idata_w, InferenceData) + assert hasattr(idata_w, "warmup_posterior") + assert idata_w.warmup_posterior.sizes["draw"] == 50 + assert idata_w.posterior.sizes["draw"] == 100 + assert idata_w.posterior.sizes["chain"] == 3 + + # InferenceData without warmup + idata = pm.sampling.mcmc._sample_return( + traces=traces, + tune=50, + t_sampling=123.4, + discard_tuned_samples=True, + compute_convergence_checks=False, + return_inferencedata=True, + keep_warning_stat=False, + idata_kwargs={}, + model=model, + ) + assert isinstance(idata, InferenceData) + assert not hasattr(idata, "warmup_posterior") + assert idata.posterior.sizes["draw"] == 100 + assert idata.posterior.sizes["chain"] == 3 @pytest.mark.parametrize("cores", [1, 2]) def test_logs_sampler_warnings(self, caplog, cores):
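A minimal sketch of how the extracted helper can be driven directly, mirroring the consolidated `test_sample_return_lengths` above. It assumes all four patches are applied; `_sample_return` and `MultiTrace._straces` are private internals used by these tests, so this is illustration only, not public API:

    import pymc as pm

    with pm.Model() as model:
        pm.Normal("n")
        # Short MCMC run; keep the raw MultiTrace including the warmup draws.
        mtrace = pm.sample(
            tune=5,
            draws=10,
            chains=2,
            cores=1,
            step=pm.Metropolis(),
            return_inferencedata=False,
            discard_tuned_samples=False,
        )

    # Re-package the same chain traces as InferenceData, keeping the warmup.
    idata = pm.sampling.mcmc._sample_return(
        traces=list(mtrace._straces.values()),
        tune=5,
        t_sampling=0.0,  # placeholder; pm.sample measures this itself
        discard_tuned_samples=False,
        compute_convergence_checks=False,
        return_inferencedata=True,
        keep_warning_stat=False,
        idata_kwargs={},
        model=model,
    )
    assert idata.warmup_posterior.sizes["draw"] == 5
    assert idata.posterior.sizes["draw"] == 10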