From 1f2c2946637e64de64289d44a5edbbc8ecdfbe09 Mon Sep 17 00:00:00 2001 From: MarcoGorelli Date: Fri, 27 Nov 2020 16:44:17 +0000 Subject: [PATCH 1/3] improve coverage --- pymc3/sampling.py | 21 ++++++--------------- pymc3/tests/test_sampling.py | 31 +++++++++++++++++-------------- pyproject.toml | 7 +++++++ 3 files changed, 30 insertions(+), 29 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index a1b66d5118..14cf3b3911 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -14,10 +14,7 @@ """Functions for MCMC sampling.""" -from typing import Dict, List, Optional, TYPE_CHECKING, cast, Union, Any - -if TYPE_CHECKING: - from typing import Tuple +from typing import Dict, List, Optional, cast, Union, Any, Tuple from typing import Iterable as TIterable from collections.abc import Iterable from collections import defaultdict @@ -218,11 +215,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None def _print_step_hierarchy(s, level=0): - if isinstance(s, (list, tuple)): - _log.info(">" * level + "list") - for i in s: - _print_step_hierarchy(i, level + 1) - elif isinstance(s, CompoundStep): + if isinstance(s, CompoundStep): _log.info(">" * level + "CompoundStep") for i in s.methods: _print_step_hierarchy(i, level + 1) @@ -458,7 +451,7 @@ def sample( if return_inferencedata is None: v = packaging.version.parse(pm.__version__) - if v.release[0] > 3 or v.release[1] >= 10: + if v.release[0] > 3 or v.release[1] >= 10: # type: ignore warnings.warn( "In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. " "You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.", @@ -585,7 +578,7 @@ def sample( UserWarning, ) _print_step_hierarchy(step) - trace = _sample_population(**sample_args, parallelize=cores > 1) + trace = _sample_population(parallelize=cores > 1, **sample_args) else: _log.info(f"Sequential sampling ({chains} chains in 1 job)") _print_step_hierarchy(step) @@ -770,11 +763,9 @@ def _sample_population( trace : MultiTrace Contains samples of all chains """ - # create the generator that iterates all chains in parallel - chains = [chain + c for c in range(chains)] sampling = _prepare_iter_population( draws, - chains, + [chain + c for c in range(chains)], step, start, parallelize, @@ -1583,7 +1574,7 @@ def insert(self, k: str, v, idx: int): The index of the sample we are inserting into the trace. """ if hasattr(v, "shape"): - value_shape = tuple(v.shape) # type: Tuple[int, ...] + value_shape: Tuple[int, ...] = tuple(v.shape) else: value_shape = () diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index be3c7b4f60..21a69928fd 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -13,14 +13,9 @@ # limitations under the License. from itertools import combinations -import packaging from typing import Tuple import numpy as np - -try: - import unittest.mock as mock # py3 -except ImportError: - from unittest import mock +import unittest.mock as mock import numpy.testing as npt import arviz as az @@ -180,13 +175,9 @@ def test_trace_report_bart(self): assert var_imp[0] > var_imp[1:].sum() npt.assert_almost_equal(var_imp.sum(), 1) - def test_return_inferencedata(self): + def test_return_inferencedata(self, monkeypatch): with self.model: kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis()) - v = packaging.version.parse(pm.__version__) - if v.major > 3 or v.minor >= 10: - with pytest.warns(FutureWarning, match="pass return_inferencedata"): - result = pm.sample(**kwargs) # trace with tuning with pytest.warns(UserWarning, match="will be included"): @@ -203,13 +194,25 @@ def test_return_inferencedata(self): assert result.posterior.sizes["chain"] == 2 assert len(result._groups_warmup) > 0 - # inferencedata without tuning - result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True) + # inferencedata without tuning, with idata_kwargs + prior = pm.sample_prior_predictive() + result = pm.sample( + **kwargs, + return_inferencedata=True, + discard_tuned_samples=True, + idata_kwargs={"prior": prior}, + random_seed=-1 + ) + assert "prior" in result assert isinstance(result, az.InferenceData) assert result.posterior.sizes["draw"] == 100 assert result.posterior.sizes["chain"] == 2 assert len(result._groups_warmup) == 0 - pass + + # check warning for version 3.10 onwards + monkeypatch.setattr("pymc3.__version__", "3.10") + with pytest.warns(FutureWarning, match="pass return_inferencedata"): + result = pm.sample(**kwargs) @pytest.mark.parametrize("cores", [1, 2]) def test_sampler_stat_tune(self, cores): diff --git a/pyproject.toml b/pyproject.toml index b2c1083e72..c75a682c1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,13 @@ [tool.black] line-length = 100 +[tool.coverage.report] +exclude_lines = [ + "pragma: nocover", + "raise NotImplementedError", + "if TYPE_CHECKING:", +] + [tool.nbqa.mutate] isort = 1 black = 1 From 48f7c55006efbedb62fcc7a6c8b556bc2ede85a2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sat, 28 Nov 2020 20:33:21 +0000 Subject: [PATCH 2/3] redistribute tests --- .github/workflows/pytest.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 17f05d642a..1d88c2cd4d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,12 +39,10 @@ jobs: pymc3/tests/test_distributions_timeseries.py pymc3/tests/test_parallel_sampling.py pymc3/tests/test_random.py - pymc3/tests/test_sampling.py pymc3/tests/test_shared.py pymc3/tests/test_smc.py - | pymc3/tests/test_examples.py - pymc3/tests/test_gp.py pymc3/tests/test_mixture.py pymc3/tests/test_posteriors.py pymc3/tests/test_quadpotential.py @@ -54,6 +52,8 @@ jobs: pymc3/tests/test_variational_inference.py - | pymc3/tests/test_distributions.py + pymc3/tests/test_gp.py + pymc3/tests/test_sampling.py runs-on: ${{ matrix.os }} env: TEST_SUBSET: ${{ matrix.test-subset }} From fa42e7e6bf333fd122d3009f73ad9f0dcf8332d4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sun, 29 Nov 2020 12:14:18 +0000 Subject: [PATCH 3/3] use np.shape --- pymc3/sampling.py | 7 ++----- pymc3/tests/test_sampling.py | 1 + 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 14cf3b3911..a9771d3e55 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -14,7 +14,7 @@ """Functions for MCMC sampling.""" -from typing import Dict, List, Optional, cast, Union, Any, Tuple +from typing import Dict, List, Optional, cast, Union, Any from typing import Iterable as TIterable from collections.abc import Iterable from collections import defaultdict @@ -1573,10 +1573,7 @@ def insert(self, k: str, v, idx: int): ids: int The index of the sample we are inserting into the trace. """ - if hasattr(v, "shape"): - value_shape: Tuple[int, ...] = tuple(v.shape) - else: - value_shape = () + value_shape = np.shape(v) # initialize if necessary if k not in self.trace_dict: diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 21a69928fd..2185542f17 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -213,6 +213,7 @@ def test_return_inferencedata(self, monkeypatch): monkeypatch.setattr("pymc3.__version__", "3.10") with pytest.warns(FutureWarning, match="pass return_inferencedata"): result = pm.sample(**kwargs) + pass @pytest.mark.parametrize("cores", [1, 2]) def test_sampler_stat_tune(self, cores):