diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 17f05d642a..1d88c2cd4d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,12 +39,10 @@ jobs: pymc3/tests/test_distributions_timeseries.py pymc3/tests/test_parallel_sampling.py pymc3/tests/test_random.py - pymc3/tests/test_sampling.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py - | pymc3/tests/test_examples.py - pymc3/tests/test_gp.py pymc3/tests/test_mixture.py pymc3/tests/test_posteriors.py pymc3/tests/test_quadpotential.py @@ -54,6 +52,8 @@ jobs: pymc3/tests/test_variational_inference.py - | pymc3/tests/test_distributions.py + pymc3/tests/test_gp.py + pymc3/tests/test_sampling.py runs-on: ${{ matrix.os }} env: TEST_SUBSET: ${{ matrix.test-subset }} diff --git a/pymc3/sampling.py b/pymc3/sampling.py index a1b66d5118..a9771d3e55 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -14,10 +14,7 @@ """Functions for MCMC sampling.""" -from typing import Dict, List, Optional, TYPE_CHECKING, cast, Union, Any - -if TYPE_CHECKING: - from typing import Tuple +from typing import Dict, List, Optional, cast, Union, Any from typing import Iterable as TIterable from collections.abc import Iterable from collections import defaultdict @@ -218,11 +215,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None def _print_step_hierarchy(s, level=0): - if isinstance(s, (list, tuple)): - _log.info(">" * level + "list") - for i in s: - _print_step_hierarchy(i, level + 1) - elif isinstance(s, CompoundStep): + if isinstance(s, CompoundStep): _log.info(">" * level + "CompoundStep") for i in s.methods: _print_step_hierarchy(i, level + 1) @@ -458,7 +451,7 @@ def sample( if return_inferencedata is None: v = packaging.version.parse(pm.__version__) - if v.release[0] > 3 or v.release[1] >= 10: + if v.release[0] > 3 or v.release[1] >= 10: # type: ignore warnings.warn( "In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. " "You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.", @@ -585,7 +578,7 @@ def sample( UserWarning, ) _print_step_hierarchy(step) - trace = _sample_population(**sample_args, parallelize=cores > 1) + trace = _sample_population(parallelize=cores > 1, **sample_args) else: _log.info(f"Sequential sampling ({chains} chains in 1 job)") _print_step_hierarchy(step) @@ -770,11 +763,9 @@ def _sample_population( trace : MultiTrace Contains samples of all chains """ - # create the generator that iterates all chains in parallel - chains = [chain + c for c in range(chains)] sampling = _prepare_iter_population( draws, - chains, + [chain + c for c in range(chains)], step, start, parallelize, @@ -1582,10 +1573,7 @@ def insert(self, k: str, v, idx: int): ids: int The index of the sample we are inserting into the trace. """ - if hasattr(v, "shape"): - value_shape = tuple(v.shape) # type: Tuple[int, ...] - else: - value_shape = () + value_shape = np.shape(v) # initialize if necessary if k not in self.trace_dict: diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index be3c7b4f60..2185542f17 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -13,14 +13,9 @@ # limitations under the License. from itertools import combinations -import packaging from typing import Tuple import numpy as np - -try: - import unittest.mock as mock # py3 -except ImportError: - from unittest import mock +import unittest.mock as mock import numpy.testing as npt import arviz as az @@ -180,13 +175,9 @@ def test_trace_report_bart(self): assert var_imp[0] > var_imp[1:].sum() npt.assert_almost_equal(var_imp.sum(), 1) - def test_return_inferencedata(self): + def test_return_inferencedata(self, monkeypatch): with self.model: kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis()) - v = packaging.version.parse(pm.__version__) - if v.major > 3 or v.minor >= 10: - with pytest.warns(FutureWarning, match="pass return_inferencedata"): - result = pm.sample(**kwargs) # trace with tuning with pytest.warns(UserWarning, match="will be included"): @@ -203,12 +194,25 @@ def test_return_inferencedata(self): assert result.posterior.sizes["chain"] == 2 assert len(result._groups_warmup) > 0 - # inferencedata without tuning - result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True) + # inferencedata without tuning, with idata_kwargs + prior = pm.sample_prior_predictive() + result = pm.sample( + **kwargs, + return_inferencedata=True, + discard_tuned_samples=True, + idata_kwargs={"prior": prior}, + random_seed=-1 + ) + assert "prior" in result assert isinstance(result, az.InferenceData) assert result.posterior.sizes["draw"] == 100 assert result.posterior.sizes["chain"] == 2 assert len(result._groups_warmup) == 0 + + # check warning for version 3.10 onwards + monkeypatch.setattr("pymc3.__version__", "3.10") + with pytest.warns(FutureWarning, match="pass return_inferencedata"): + result = pm.sample(**kwargs) pass @pytest.mark.parametrize("cores", [1, 2]) diff --git a/pyproject.toml b/pyproject.toml index b2c1083e72..c75a682c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,13 @@ [tool.black] line-length = 100 +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "if TYPE_CHECKING:", +] + [tool.nbqa.mutate] isort = 1 black = 1