Skip to content

Commit 4fd44ea

Browse files
committed
Move non-shared utils to respective files
1 parent 451a207 commit 4fd44ea

File tree

5 files changed

+55
-65
lines changed

5 files changed

+55
-65
lines changed

pymc/sampling.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,7 @@
4646
)
4747
from pymc.model import Model, modelcontext
4848
from pymc.parallel_sampling import Draw, _cpu_count
49-
from pymc.sampling_utils import (
50-
RandomSeed,
51-
RandomState,
52-
_get_seeds_per_chain,
53-
all_continuous,
54-
)
49+
from pymc.sampling_utils import RandomSeed, RandomState, _get_seeds_per_chain
5550
from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks
5651
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
5752
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
@@ -208,6 +203,17 @@ def _print_step_hierarchy(s: Step, level: int = 0) -> None:
208203
_log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]")
209204

210205

206+
def all_continuous(vars):
207+
"""Check that vars not include discrete variables, excepting observed RVs."""
208+
209+
vars_ = [var for var in vars if not hasattr(var.tag, "observations")]
210+
211+
if any([(var.dtype in discrete_types) for var in vars_]):
212+
return False
213+
else:
214+
return True
215+
216+
211217
def sample(
212218
draws: int = 1000,
213219
step=None,

pymc/sampling_forward.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,16 @@
4444
from aesara.tensor.sharedvar import SharedVariable
4545
from arviz import InferenceData
4646
from fastprogress.fastprogress import progress_bar
47+
from typing_extensions import TypeAlias
4748

4849
import pymc as pm
4950

5051
from pymc.aesaraf import compile_pymc
5152
from pymc.backends.arviz import _DefaultTrace
5253
from pymc.backends.base import MultiTrace
54+
from pymc.blocking import PointType
5355
from pymc.model import Model, modelcontext
54-
from pymc.sampling_utils import (
55-
ArrayLike,
56-
PointList,
57-
RandomState,
58-
_get_seeds_per_chain,
59-
get_vars_in_point_list,
60-
)
56+
from pymc.sampling_utils import RandomState, _get_seeds_per_chain
6157
from pymc.util import dataset_to_point_list, get_default_varnames, point_wrapper
6258

6359
__all__ = (
@@ -69,9 +65,22 @@
6965
)
7066

7167

68+
ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
69+
PointList: TypeAlias = List[PointType]
70+
7271
_log = logging.getLogger("pymc")
7372

7473

74+
def get_vars_in_point_list(trace, model):
75+
"""Get the list of Variable instances in the model that have values stored in the trace."""
76+
if not isinstance(trace, MultiTrace):
77+
names_in_trace = list(trace[0])
78+
else:
79+
names_in_trace = trace.varnames
80+
vars_in_trace = [model[v] for v in names_in_trace if v in model]
81+
return vars_in_trace
82+
83+
7584
def compile_forward_sampling_function(
7685
outputs: List[Variable],
7786
vars_in_trace: List[Variable],

pymc/sampling_utils.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,15 @@
1414

1515
"""Helper functions for MCMC, prior and posterior predictive sampling."""
1616

17-
from typing import List, Optional, Sequence, Union
17+
from typing import Optional, Sequence, Union
1818

1919
import numpy as np
2020

2121
from typing_extensions import TypeAlias
2222

2323
from pymc.backends.base import BaseTrace, MultiTrace
2424
from pymc.backends.ndarray import NDArray
25-
from pymc.initial_point import PointType
26-
from pymc.vartypes import discrete_types
2725

28-
ArrayLike: TypeAlias = Union[np.ndarray, List[float]]
29-
PointList: TypeAlias = List[PointType]
3026
Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray]
3127

3228
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
@@ -36,17 +32,6 @@
3632
__all__ = ()
3733

3834

39-
def all_continuous(vars):
40-
"""Check that vars not include discrete variables, excepting observed RVs."""
41-
42-
vars_ = [var for var in vars if not hasattr(var.tag, "observations")]
43-
44-
if any([(var.dtype in discrete_types) for var in vars_]):
45-
return False
46-
else:
47-
return True
48-
49-
5035
def _get_seeds_per_chain(
5136
random_state: RandomState,
5237
chains: int,
@@ -95,13 +80,3 @@ def _get_unique_seeds_per_chain(integers_fn):
9580
)
9681

9782
return random_state
98-
99-
100-
def get_vars_in_point_list(trace, model):
101-
"""Get the list of Variable instances in the model that have values stored in the trace."""
102-
if not isinstance(trace, MultiTrace):
103-
names_in_trace = list(trace[0])
104-
else:
105-
names_in_trace = trace.varnames
106-
vars_in_trace = [model[v] for v in names_in_trace if v in model]
107-
return vars_in_trace

pymc/tests/test_sampling_forward.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535

3636
from pymc.aesaraf import compile_pymc
3737
from pymc.backends.base import MultiTrace
38-
from pymc.sampling_forward import compile_forward_sampling_function
38+
from pymc.sampling_forward import (
39+
compile_forward_sampling_function,
40+
get_vars_in_point_list,
41+
)
3942
from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode
4043

4144

@@ -1635,3 +1638,24 @@ def test_Triangular(
16351638
prior_samples=prior_samples,
16361639
)
16371640
assert prior["target"].shape == (prior_samples,) + shape
1641+
1642+
1643+
def test_get_vars_in_point_list():
1644+
with pm.Model() as modelA:
1645+
pm.Normal("a", 0, 1)
1646+
pm.Normal("b", 0, 1)
1647+
with pm.Model() as modelB:
1648+
a = pm.Normal("a", 0, 1)
1649+
pm.Normal("c", 0, 1)
1650+
1651+
point_list = [{"a": 0, "b": 0}]
1652+
vars_in_trace = get_vars_in_point_list(point_list, modelB)
1653+
assert set(vars_in_trace) == {a}
1654+
1655+
strace = pm.backends.NDArray(model=modelB, vars=modelA.free_RVs)
1656+
strace.setup(1, 1)
1657+
strace.values = point_list[0]
1658+
strace.draw_idx = 1
1659+
trace = MultiTrace([strace])
1660+
vars_in_trace = get_vars_in_point_list(trace, modelB)
1661+
assert set(vars_in_trace) == {a}

pymc/tests/test_sampling_utils.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@
1616
import numpy as np
1717
import pytest
1818

19-
import pymc as pm
20-
21-
from pymc.backends.base import MultiTrace
22-
from pymc.sampling_utils import _get_seeds_per_chain, get_vars_in_point_list
19+
from pymc.sampling_utils import _get_seeds_per_chain
2320

2421

2522
def test_get_seeds_per_chain():
@@ -56,24 +53,3 @@ def test_get_seeds_per_chain():
5653

5754
with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")):
5855
_get_seeds_per_chain({1: 1, 2: 2}, 2)
59-
60-
61-
def test_get_vars_in_point_list():
62-
with pm.Model() as modelA:
63-
pm.Normal("a", 0, 1)
64-
pm.Normal("b", 0, 1)
65-
with pm.Model() as modelB:
66-
a = pm.Normal("a", 0, 1)
67-
pm.Normal("c", 0, 1)
68-
69-
point_list = [{"a": 0, "b": 0}]
70-
vars_in_trace = get_vars_in_point_list(point_list, modelB)
71-
assert set(vars_in_trace) == {a}
72-
73-
strace = pm.backends.NDArray(model=modelB, vars=modelA.free_RVs)
74-
strace.setup(1, 1)
75-
strace.values = point_list[0]
76-
strace.draw_idx = 1
77-
trace = MultiTrace([strace])
78-
vars_in_trace = get_vars_in_point_list(trace, modelB)
79-
assert set(vars_in_trace) == {a}

0 commit comments

Comments
 (0)