From d88748563ae8ff3d903bf1c3b47bc93bce2b408a Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 5 Nov 2022 12:47:26 +0100 Subject: [PATCH 1/2] Move sampling code into sampling submodule This is a follow-up to #6257 where we split the `sampling.py` into two files. --- .github/workflows/tests.yml | 18 +++++----- docs/source/api/samplers.rst | 4 +-- docs/source/contributing/build_docs.md | 2 +- pymc/__init__.py | 1 - pymc/sampling/__init__.py | 16 +++++++++ .../forward.py} | 0 pymc/{sampling_jax.py => sampling/jax.py} | 2 +- pymc/{sampling.py => sampling/mcmc.py} | 4 +-- .../parallel.py} | 0 pymc/smc/kernels.py | 2 +- pymc/smc/sampling.py | 2 +- pymc/tests/distributions/test_mixture.py | 4 +-- pymc/tests/distributions/test_multivariate.py | 2 +- pymc/tests/distributions/test_timeseries.py | 4 +-- pymc/tests/sampler_fixtures.py | 2 +- pymc/tests/sampling/__init__.py | 0 .../test_forward.py} | 2 +- .../test_jax.py} | 2 +- .../test_mcmc.py} | 34 +++++++++---------- .../test_parallel.py} | 2 +- pymc/variational/opvi.py | 5 +-- scripts/run_mypy.py | 13 +++---- 22 files changed, 69 insertions(+), 52 deletions(-) create mode 100644 pymc/sampling/__init__.py rename pymc/{sampling_forward.py => sampling/forward.py} (100%) rename pymc/{sampling_jax.py => sampling/jax.py} (99%) rename pymc/{sampling.py => sampling/mcmc.py} (99%) rename pymc/{parallel_sampling.py => sampling/parallel.py} (100%) create mode 100644 pymc/tests/sampling/__init__.py rename pymc/tests/{test_sampling_forward.py => sampling/test_forward.py} (99%) rename pymc/tests/{test_sampling_jax.py => sampling/test_jax.py} (99%) rename pymc/tests/{test_sampling.py => sampling/test_mcmc.py} (97%) rename pymc/tests/{test_parallel_sampling.py => sampling/test_parallel.py} (99%) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 31ee9a085f..807a0928f3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -59,16 +59,16 @@ jobs: pymc/tests/distributions/test_censored.py pymc/tests/distributions/test_simulator.py pymc/tests/distributions/test_truncated.py - pymc/tests/test_sampling_forward.py + pymc/tests/sampling/test_forward.py pymc/tests/stats/test_convergence.py - | pymc/tests/tuning/test_scaling.py pymc/tests/tuning/test_starting.py - pymc/tests/test_sampling.py pymc/tests/distributions/test_dist_math.py pymc/tests/distributions/test_transform.py - pymc/tests/test_parallel_sampling.py + pymc/tests/sampling/test_mcmc.py + pymc/tests/sampling/test_parallel.py pymc/tests/test_printing.py - | @@ -150,8 +150,8 @@ jobs: test-subset: - pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py - pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py - - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py - - pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py + - pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/sampling/test_parallel.py + - pymc/tests/sampling/test_mcmc.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py fail-fast: false runs-on: ${{ matrix.os }} @@ -221,12 +221,12 @@ jobs: python-version: ["3.9"] test-subset: - | - pymc/tests/test_parallel_sampling.py + pymc/tests/sampling/test_parallel.py pymc/tests/test_data.py pymc/tests/test_model.py - | - pymc/tests/test_sampling.py + pymc/tests/sampling/test_mcmc.py - | pymc/tests/backends/test_arviz.py @@ -294,7 +294,7 @@ jobs: floatx: [float64] python-version: ["3.9"] test-subset: - - pymc/tests/test_sampling_jax.py + - pymc/tests/sampling/test_jax.py fail-fast: false runs-on: ${{ matrix.os }} env: @@ -363,7 +363,7 @@ jobs: floatx: [float32] python-version: ["3.10"] test-subset: - - pymc/tests/test_sampling.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py + - pymc/tests/sampling/test_mcmc.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py fail-fast: false runs-on: ${{ matrix.os }} env: diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index 0acbc9df28..265462e10f 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -13,8 +13,8 @@ This submodule contains functions for MCMC and forward sampling. sample_prior_predictive sample_posterior_predictive sample_posterior_predictive_w - sampling_jax.sample_blackjax_nuts - sampling_jax.sample_numpyro_nuts + sampling.jax.sample_blackjax_nuts + sampling.jax.sample_numpyro_nuts iter_sample init_nuts draw diff --git a/docs/source/contributing/build_docs.md b/docs/source/contributing/build_docs.md index 09e37328f4..d0d4044ec8 100644 --- a/docs/source/contributing/build_docs.md +++ b/docs/source/contributing/build_docs.md @@ -9,7 +9,7 @@ To build the docs, run these commands at PyMC repository root: ```bash pip install -r requirements-dev.txt # Make sure the dev requirements are installed -pip install numpyro # Make sure `sampling_jax` docs can be built +pip install numpyro # Make sure `sampling/jax` docs can be built pip install -e . # Install local pymc version as installable package make clean # clean built docs from previous runs and intermediate outputs make html # Build docs diff --git a/pymc/__init__.py b/pymc/__init__.py index 27cdf6e2bb..09314aa5c3 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -68,7 +68,6 @@ def __set_compiler_flags(): from pymc.plots import * from pymc.printing import * from pymc.sampling import * -from pymc.sampling_forward import * from pymc.smc import * from pymc.stats import * from pymc.step_methods import * diff --git a/pymc/sampling/__init__.py b/pymc/sampling/__init__.py new file mode 100644 index 0000000000..4a5d2a57a8 --- /dev/null +++ b/pymc/sampling/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2022 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pymc.sampling.forward import * +from pymc.sampling.mcmc import * diff --git a/pymc/sampling_forward.py b/pymc/sampling/forward.py similarity index 100% rename from pymc/sampling_forward.py rename to pymc/sampling/forward.py diff --git a/pymc/sampling_jax.py b/pymc/sampling/jax.py similarity index 99% rename from pymc/sampling_jax.py rename to pymc/sampling/jax.py index 3c435d3d27..c284b2fba1 100644 --- a/pymc/sampling_jax.py +++ b/pymc/sampling/jax.py @@ -7,7 +7,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Union from pymc.initial_point import StartDict -from pymc.sampling import _init_jitter +from pymc.sampling.mcmc import _init_jitter xla_flags = os.getenv("XLA_FLAGS", "") xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split() diff --git a/pymc/sampling.py b/pymc/sampling/mcmc.py similarity index 99% rename from pymc/sampling.py rename to pymc/sampling/mcmc.py index 9b9b59c7ca..4a497ff468 100644 --- a/pymc/sampling.py +++ b/pymc/sampling/mcmc.py @@ -45,7 +45,7 @@ make_initial_point_fns_per_chain, ) from pymc.model import Model, modelcontext -from pymc.parallel_sampling import Draw, _cpu_count +from pymc.sampling.parallel import Draw, _cpu_count from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks from pymc.step_methods import NUTS, CompoundStep, DEMetropolis from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared @@ -1404,7 +1404,7 @@ def _mp_sample( mtrace : pymc.backends.base.MultiTrace A ``MultiTrace`` object that contains the samples for all chains. """ - import pymc.parallel_sampling as ps + import pymc.sampling.parallel as ps # We did draws += tune in pm.sample draws -= tune diff --git a/pymc/parallel_sampling.py b/pymc/sampling/parallel.py similarity index 100% rename from pymc/parallel_sampling.py rename to pymc/sampling/parallel.py diff --git a/pymc/smc/kernels.py b/pymc/smc/kernels.py index 8059ec54b4..43d060da50 100644 --- a/pymc/smc/kernels.py +++ b/pymc/smc/kernels.py @@ -35,7 +35,7 @@ from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.model import Point, modelcontext -from pymc.sampling_forward import sample_prior_predictive +from pymc.sampling.forward import sample_prior_predictive from pymc.step_methods.metropolis import MultivariateNormalProposal from pymc.vartypes import discrete_types diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index 87d198dcd4..ddeed25a16 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -32,7 +32,7 @@ from pymc.backends.arviz import dict_to_dataset, to_inference_data from pymc.backends.base import MultiTrace from pymc.model import modelcontext -from pymc.parallel_sampling import _cpu_count +from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH diff --git a/pymc/tests/distributions/test_mixture.py b/pymc/tests/distributions/test_mixture.py index 113f3e37b1..940aa21911 100644 --- a/pymc/tests/distributions/test_mixture.py +++ b/pymc/tests/distributions/test_mixture.py @@ -55,12 +55,12 @@ from pymc.distributions.transforms import _default_transform from pymc.math import expand_packed_triangular from pymc.model import Model -from pymc.sampling import sample -from pymc.sampling_forward import ( +from pymc.sampling.forward import ( draw, sample_posterior_predictive, sample_prior_predictive, ) +from pymc.sampling.mcmc import sample from pymc.step_methods import Metropolis from pymc.tests.distributions.util import ( Domain, diff --git a/pymc/tests/distributions/test_multivariate.py b/pymc/tests/distributions/test_multivariate.py index d023912ce2..e7cb828696 100644 --- a/pymc/tests/distributions/test_multivariate.py +++ b/pymc/tests/distributions/test_multivariate.py @@ -41,7 +41,7 @@ ) from pymc.distributions.shape_utils import change_dist_size, to_tuple from pymc.math import kronecker -from pymc.sampling_forward import draw +from pymc.sampling.forward import draw from pymc.tests.distributions.util import ( BaseTestDistributionRandom, Domain, diff --git a/pymc/tests/distributions/test_timeseries.py b/pymc/tests/distributions/test_timeseries.py index 735b84b707..31e5f07acc 100644 --- a/pymc/tests/distributions/test_timeseries.py +++ b/pymc/tests/distributions/test_timeseries.py @@ -42,8 +42,8 @@ RandomWalk, ) from pymc.model import Model -from pymc.sampling import sample -from pymc.sampling_forward import draw, sample_posterior_predictive +from pymc.sampling.forward import draw, sample_posterior_predictive +from pymc.sampling.mcmc import sample from pymc.tests.distributions.util import assert_moment_is_expected from pymc.tests.helpers import select_by_precision diff --git a/pymc/tests/sampler_fixtures.py b/pymc/tests/sampler_fixtures.py index dabb3466e6..db66784adc 100644 --- a/pymc/tests/sampler_fixtures.py +++ b/pymc/tests/sampler_fixtures.py @@ -178,7 +178,7 @@ def make_step(cls): if hasattr(cls, "step_args"): args.update(cls.step_args) if "scaling" not in args: - _, step = pm.sampling.init_nuts(n_init=10000, **args) + _, step = pm.sampling.mcmc.init_nuts(n_init=10000, **args) else: step = pm.NUTS(**args) return step diff --git a/pymc/tests/sampling/__init__.py b/pymc/tests/sampling/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pymc/tests/test_sampling_forward.py b/pymc/tests/sampling/test_forward.py similarity index 99% rename from pymc/tests/test_sampling_forward.py rename to pymc/tests/sampling/test_forward.py index d02d4df1f3..28a33c418d 100644 --- a/pymc/tests/test_sampling_forward.py +++ b/pymc/tests/sampling/test_forward.py @@ -35,7 +35,7 @@ from pymc.aesaraf import compile_pymc from pymc.backends.base import MultiTrace -from pymc.sampling_forward import ( +from pymc.sampling.forward import ( compile_forward_sampling_function, get_vars_in_point_list, ) diff --git a/pymc/tests/test_sampling_jax.py b/pymc/tests/sampling/test_jax.py similarity index 99% rename from pymc/tests/test_sampling_jax.py rename to pymc/tests/sampling/test_jax.py index 9b0c0ab909..8bf23af2ef 100644 --- a/pymc/tests/test_sampling_jax.py +++ b/pymc/tests/sampling/test_jax.py @@ -17,7 +17,7 @@ import pymc as pm with pytest.warns(UserWarning, match="module is experimental"): - from pymc.sampling_jax import ( + from pymc.sampling.jax import ( _get_batched_jittered_initial_points, _get_log_likelihood, _numpyro_nuts_defaults, diff --git a/pymc/tests/test_sampling.py b/pymc/tests/sampling/test_mcmc.py similarity index 97% rename from pymc/tests/test_sampling.py rename to pymc/tests/sampling/test_mcmc.py index 64fdb2a40f..540850bd23 100644 --- a/pymc/tests/test_sampling.py +++ b/pymc/tests/sampling/test_mcmc.py @@ -35,7 +35,7 @@ from pymc.backends.ndarray import NDArray from pymc.distributions import transforms from pymc.exceptions import SamplingError -from pymc.sampling import assign_step_methods +from pymc.sampling.mcmc import assign_step_methods from pymc.stats.convergence import SamplerWarning, WarningType from pymc.step_methods import ( NUTS, @@ -57,7 +57,7 @@ def setup_method(self): def test_checks_seeds_kwarg(self): with self.model: with pytest.raises(ValueError, match="Number of seeds"): - pm.sampling.init_nuts(chains=2, random_seed=[1]) + pm.sampling.mcmc.init_nuts(chains=2, random_seed=[1]) class TestSample(SeededTest): @@ -208,7 +208,7 @@ def test_sample_args(self): def test_iter_sample(self): with self.model: - samps = pm.sampling.iter_sample( + samps = pm.sampling.mcmc.iter_sample( draws=5, step=self.step, start=self.start, @@ -255,7 +255,7 @@ def test_reset_tuning(self): with self.model: tune = 50 chains = 2 - start, step = pm.sampling.init_nuts(chains=chains, random_seed=[1, 2]) + start, step = pm.sampling.mcmc.init_nuts(chains=chains, random_seed=[1, 2]) with warnings.catch_warnings(): warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning) pm.sample(draws=2, tune=tune, chains=chains, step=step, initvals=start, cores=1) @@ -346,11 +346,11 @@ def test_sampler_stat_tune(self, cores): ) def test_sample_start_bad_shape(self, start, error): with pytest.raises(error): - pm.sampling._check_start_shape(self.model, start) + pm.sampling.mcmc._check_start_shape(self.model, start) @pytest.mark.parametrize("start", [{"x": np.array([1, 1])}, {"x": [10, 10]}, {"x": [-10, -10]}]) def test_sample_start_good_shape(self, start): - pm.sampling._check_start_shape(self.model, start) + pm.sampling.mcmc._check_start_shape(self.model, start) def test_sample_callback(self): callback = mock.Mock() @@ -515,7 +515,7 @@ def test_choose_chains(n_points, tune, expected_length, expected_n_traces): trace_1.record({"a": 0}) for _ in range(n_points[2]): trace_2.record({"a": 0}) - traces, length = pm.sampling._choose_chains([trace_0, trace_1, trace_2], tune=tune) + traces, length = pm.sampling.mcmc._choose_chains([trace_0, trace_1, trace_2], tune=tune) assert length == expected_length assert expected_n_traces == len(traces) @@ -575,29 +575,29 @@ def test_constant_named(self): class TestChooseBackend: def test_choose_backend_none(self): - with mock.patch("pymc.sampling.NDArray") as nd: - pm.sampling._choose_backend(None) + with mock.patch("pymc.sampling.mcmc.NDArray") as nd: + pm.sampling.mcmc._choose_backend(None) assert nd.called def test_choose_backend_list_of_variables(self): - with mock.patch("pymc.sampling.NDArray") as nd: - pm.sampling._choose_backend(["var1", "var2"]) + with mock.patch("pymc.sampling.mcmc.NDArray") as nd: + pm.sampling.mcmc._choose_backend(["var1", "var2"]) nd.assert_called_with(vars=["var1", "var2"]) def test_errors_and_warnings(self): with pm.Model(): A = pm.Normal("A") B = pm.Uniform("B") - strace = pm.sampling.NDArray(vars=[A, B]) + strace = pm.backends.ndarray.NDArray(vars=[A, B]) strace.setup(10, 0) with pytest.raises(ValueError, match="from existing MultiTrace"): - pm.sampling._choose_backend(trace=MultiTrace([strace])) + pm.sampling.mcmc._choose_backend(trace=MultiTrace([strace])) strace.record({"A": 2, "B_interval__": 0.1}) assert len(strace) == 1 with pytest.raises(ValueError, match="Continuation of traces"): - pm.sampling._choose_backend(trace=strace) + pm.sampling.mcmc._choose_backend(trace=strace) def check_exec_nuts_init(method): @@ -657,7 +657,7 @@ def test_init_jitter(initval, jitter_max_retries, expectation): # Starting value is negative (invalid) when np.random.rand returns 0 (jitter = -1) # and positive (valid) when it returns 1 (jitter = 1) with mock.patch("numpy.random.Generator.uniform", side_effect=[-1, -1, -1, 1, -1]): - start = pm.sampling._init_jitter( + start = pm.sampling.mcmc._init_jitter( model=m, initvals=None, seeds=[1], @@ -704,7 +704,7 @@ def test_log_warning_stats(caplog): stats = [s1, s2] with caplog.at_level(logging.WARNING): - pm.sampling.log_warning_stats(stats) + pm.sampling.mcmc.log_warning_stats(stats) # We have a list of stats dicts, because there might be several samplers involved. assert "too low" in caplog.records[0].message @@ -716,7 +716,7 @@ def test_log_warning_stats_knows_SamplerWarning(caplog): stats = [dict(warning=SamplerWarning(WarningType.BAD_ENERGY, "Not that interesting", "debug"))] with caplog.at_level(logging.DEBUG, logger="pymc"): - pm.sampling.log_warning_stats(stats) + pm.sampling.mcmc.log_warning_stats(stats) assert "Not that interesting" in caplog.records[0].message diff --git a/pymc/tests/test_parallel_sampling.py b/pymc/tests/sampling/test_parallel.py similarity index 99% rename from pymc/tests/test_parallel_sampling.py rename to pymc/tests/sampling/test_parallel.py index 2883acd297..77ba9f48b7 100644 --- a/pymc/tests/test_parallel_sampling.py +++ b/pymc/tests/sampling/test_parallel.py @@ -27,7 +27,7 @@ from aesara.tensor.type import TensorType import pymc as pm -import pymc.parallel_sampling as ps +import pymc.sampling.parallel as ps from pymc.aesaraf import floatX diff --git a/pymc/variational/opvi.py b/pymc/variational/opvi.py index d3ec0b8b2b..be9e430462 100644 --- a/pymc/variational/opvi.py +++ b/pymc/variational/opvi.py @@ -66,7 +66,8 @@ reseed_rngs, rvs_to_value_vars, ) -from pymc.backends import NDArray +from pymc.backends.base import MultiTrace +from pymc.backends.ndarray import NDArray from pymc.blocking import DictToArrayBijection from pymc.initial_point import make_initial_point_fn from pymc.model import modelcontext @@ -1477,7 +1478,7 @@ def sample( finally: trace.close() - trace = pm.sampling.MultiTrace([trace]) + trace = MultiTrace([trace]) if not return_inferencedata: return trace else: diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index 2b55afdddf..f997b5d289 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -49,10 +49,11 @@ pymc/ode/__init__.py pymc/ode/ode.py pymc/ode/utils.py -pymc/parallel_sampling.py pymc/plots/__init__.py -pymc/sampling.py -pymc/sampling_forward.py +pymc/sampling/__init__.py +pymc/sampling/forward.py +pymc/sampling/mcmc.py +pymc/sampling/parallel.py pymc/smc/__init__.py pymc/smc/sampling.py pymc/smc/kernels.py @@ -167,10 +168,10 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]): print("You can run `python scripts/run_mypy.py --verbose` to reproduce this test locally.") sys.exit(1) - if unexpected_passing == {"pymc/sampling_jax.py"}: - print("Letting you know that 'pymc/sampling_jax.py' unexpectedly passed.") + if unexpected_passing == {"pymc/sampling/jax.py"}: + print("Letting you know that 'pymc/sampling/jax.py' unexpectedly passed.") print("But this file is known to sometimes pass and sometimes not.") - print("Unless you tried to resolve problems in sampling_jax.py just ignore this message.") + print("Unless you tried to resolve problems in sampling/jax.py just ignore this message.") elif unexpected_passing: print("!!!!!!!!!") print(f"{len(unexpected_passing)} files unexpectedly passed the type checks:") From f18e61cf56a51f8b0a5e6fc4cb4bae0d8502a85b Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Sat, 5 Nov 2022 13:36:28 +0100 Subject: [PATCH 2/2] Reintroduce sampling_jax.py for backward compatibility This is a separate commit to make sure that git tracks the rename of the old `sampling_jax.py` to `sampling/jax.py` correctly. --- pymc/sampling/jax.py | 8 ++++++++ pymc/sampling_jax.py | 7 +++++++ pymc/tests/sampling/test_jax.py | 8 ++++++++ scripts/run_mypy.py | 1 + 4 files changed, 24 insertions(+) create mode 100644 pymc/sampling_jax.py diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py index c284b2fba1..c47f1da2f8 100644 --- a/pymc/sampling/jax.py +++ b/pymc/sampling/jax.py @@ -37,6 +37,14 @@ warnings.warn("This module is experimental.") +__all__ = ( + "get_jaxified_graph", + "get_jaxified_logp", + "sample_blackjax_nuts", + "sample_numpyro_nuts", +) + + @jax_funcify.register(Assert) @jax_funcify.register(CheckParameterValue) @jax_funcify.register(SpecifyShape) diff --git a/pymc/sampling_jax.py b/pymc/sampling_jax.py new file mode 100644 index 0000000000..6afeb9e1d1 --- /dev/null +++ b/pymc/sampling_jax.py @@ -0,0 +1,7 @@ +# This file exists only for backward-compatibility with imports like +# `import pymc.sampling_jax` or `from pymc import sampling_jax`. + +# pylint: disable=wildcard-import +# pylint: disable=unused-wildcard-import + +from pymc.sampling.jax import * diff --git a/pymc/tests/sampling/test_jax.py b/pymc/tests/sampling/test_jax.py index 8bf23af2ef..28e4f4c9b3 100644 --- a/pymc/tests/sampling/test_jax.py +++ b/pymc/tests/sampling/test_jax.py @@ -16,6 +16,14 @@ import pymc as pm + +def test_old_import_route(): + import pymc.sampling.jax as new_sj + import pymc.sampling_jax as old_sj + + assert set(new_sj.__all__) <= set(dir(old_sj)) + + with pytest.warns(UserWarning, match="module is experimental"): from pymc.sampling.jax import ( _get_batched_jittered_initial_points, diff --git a/scripts/run_mypy.py b/scripts/run_mypy.py index f997b5d289..310f26fbdc 100644 --- a/scripts/run_mypy.py +++ b/scripts/run_mypy.py @@ -50,6 +50,7 @@ pymc/ode/ode.py pymc/ode/utils.py pymc/plots/__init__.py +pymc/sampling_jax.py pymc/sampling/__init__.py pymc/sampling/forward.py pymc/sampling/mcmc.py