Skip to content

Commit 0997d3c

Browse files
Split sampling into three modules
Closes #6141
1 parent e57d1d7 commit 0997d3c

File tree

15 files changed

+2581
-2440
lines changed

15 files changed

+2581
-2440
lines changed

.github/workflows/tests.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ jobs:
5959
pymc/tests/distributions/test_censored.py
6060
pymc/tests/distributions/test_simulator.py
6161
pymc/tests/distributions/test_truncated.py
62+
pymc/tests/test_sampling_predictive.py
63+
pymc/tests/stats/test_convergence.py
6264
6365
- |
6466
pymc/tests/tuning/test_scaling.py
@@ -147,7 +149,7 @@ jobs:
147149
python-version: ["3.8"]
148150
test-subset:
149151
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
150-
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
152+
- pymc/tests/test_model.py pymc/tests/test_sampling_utils.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
151153
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py
152154
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py
153155

pymc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ def __set_compiler_flags():
6868
from pymc.plots import *
6969
from pymc.printing import *
7070
from pymc.sampling import *
71+
from pymc.sampling_predictive import *
72+
from pymc.sampling_utils import *
7173
from pymc.smc import *
7274
from pymc.stats import *
7375
from pymc.step_methods import *

pymc/sampling.py

Lines changed: 9 additions & 760 deletions
Large diffs are not rendered by default.

pymc/sampling_jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
88

99
from pymc.initial_point import StartDict
10-
from pymc.sampling import RandomSeed, _get_seeds_per_chain, _init_jitter
10+
from pymc.sampling import _init_jitter
11+
from pymc.sampling_utils import RandomSeed, _get_seeds_per_chain
1112

1213
xla_flags = os.getenv("XLA_FLAGS", "")
1314
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()

0 commit comments

Comments
 (0)