Skip to content

Commit 31365e6

Browse files
committed
Move get_seeds_per_chain and Random type-hint variables to util.py
1 parent 4fd44ea commit 31365e6

File tree

10 files changed

+119
-127
lines changed

10 files changed

+119
-127
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ jobs:
149149
python-version: ["3.8"]
150150
test-subset:
151151
- pymc/tests/variational/test_approximations.py pymc/tests/variational/test_callbacks.py pymc/tests/variational/test_inference.py pymc/tests/variational/test_opvi.py pymc/tests/test_initial_point.py
152-
- pymc/tests/test_model.py pymc/tests/test_sampling_utils.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
152+
- pymc/tests/test_model.py pymc/tests/step_methods/test_compound.py pymc/tests/step_methods/hmc/test_hmc.py
153153
- pymc/tests/gp/test_cov.py pymc/tests/gp/test_gp.py pymc/tests/gp/test_mean.py pymc/tests/gp/test_util.py pymc/tests/ode/test_ode.py pymc/tests/ode/test_utils.py pymc/tests/smc/test_smc.py pymc/tests/test_parallel_sampling.py
154154
- pymc/tests/test_sampling.py pymc/tests/step_methods/test_metropolis.py pymc/tests/step_methods/test_slicer.py pymc/tests/step_methods/hmc/test_nuts.py
155155

pymc/parallel_sampling.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import traceback
2222

2323
from collections import namedtuple
24-
from typing import TYPE_CHECKING, Dict, List, Sequence
24+
from typing import Dict, List, Sequence
2525

2626
import cloudpickle
2727
import numpy as np
@@ -31,10 +31,7 @@
3131
from pymc import aesaraf
3232
from pymc.blocking import DictToArrayBijection
3333
from pymc.exceptions import SamplingError
34-
35-
# Avoid circular import
36-
if TYPE_CHECKING:
37-
from pymc.sampling import RandomSeed
34+
from pymc.util import RandomSeed
3835

3936
logger = logging.getLogger("pymc")
4037

pymc/sampling.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,18 @@
4646
)
4747
from pymc.model import Model, modelcontext
4848
from pymc.parallel_sampling import Draw, _cpu_count
49-
from pymc.sampling_utils import RandomSeed, RandomState, _get_seeds_per_chain
5049
from pymc.stats.convergence import SamplerWarning, log_warning, run_convergence_checks
5150
from pymc.step_methods import NUTS, CompoundStep, DEMetropolis
5251
from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
5352
from pymc.step_methods.hmc import quadpotential
54-
from pymc.util import drop_warning_stat, get_untransformed_name, is_transformed_name
53+
from pymc.util import (
54+
RandomSeed,
55+
RandomState,
56+
_get_seeds_per_chain,
57+
drop_warning_stat,
58+
get_untransformed_name,
59+
is_transformed_name,
60+
)
5561
from pymc.vartypes import discrete_types
5662

5763
sys.setrecursionlimit(10000)

pymc/sampling_forward.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@
5353
from pymc.backends.base import MultiTrace
5454
from pymc.blocking import PointType
5555
from pymc.model import Model, modelcontext
56-
from pymc.sampling_utils import RandomState, _get_seeds_per_chain
57-
from pymc.util import dataset_to_point_list, get_default_varnames, point_wrapper
56+
from pymc.util import (
57+
RandomState,
58+
_get_seeds_per_chain,
59+
dataset_to_point_list,
60+
get_default_varnames,
61+
point_wrapper,
62+
)
5863

5964
__all__ = (
6065
"compile_forward_sampling_function",

pymc/sampling_jax.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from pymc.initial_point import StartDict
1010
from pymc.sampling import _init_jitter
11-
from pymc.sampling_utils import RandomSeed, _get_seeds_per_chain
1211

1312
xla_flags = os.getenv("XLA_FLAGS", "")
1413
xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
@@ -33,7 +32,7 @@
3332

3433
from pymc import Model, modelcontext
3534
from pymc.backends.arviz import find_constants, find_observations
36-
from pymc.util import get_default_varnames
35+
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
3736

3837
warnings.warn("This module is experimental.")
3938

pymc/sampling_utils.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414

1515
"""Helper functions for MCMC, prior and posterior predictive sampling."""
1616

17-
from typing import Optional, Sequence, Union
18-
19-
import numpy as np
17+
from typing import Union
2018

2119
from typing_extensions import TypeAlias
2220

@@ -25,58 +23,4 @@
2523

2624
Backend: TypeAlias = Union[BaseTrace, MultiTrace, NDArray]
2725

28-
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
29-
RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator]
30-
31-
3226
__all__ = ()
33-
34-
35-
def _get_seeds_per_chain(
36-
random_state: RandomState,
37-
chains: int,
38-
) -> Union[Sequence[int], np.ndarray]:
39-
"""Obtain or validate specified integer seeds per chain.
40-
41-
This function process different possible sources of seeding and returns one integer
42-
seed per chain:
43-
1. If the input is an integer and a single chain is requested, the input is
44-
returned inside a tuple.
45-
2. If the input is a sequence or NumPy array with as many entries as chains,
46-
the input is returned.
47-
3. If the input is an integer and multiple chains are requested, new unique seeds
48-
are generated from NumPy default Generator seeded with that integer.
49-
4. If the input is None new unique seeds are generated from an unseeded NumPy default
50-
Generator.
51-
5. If a RandomState or Generator is provided, new unique seeds are generated from it.
52-
53-
Raises
54-
------
55-
ValueError
56-
If none of the conditions above are met
57-
"""
58-
59-
def _get_unique_seeds_per_chain(integers_fn):
60-
seeds = []
61-
while len(set(seeds)) != chains:
62-
seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
63-
return seeds
64-
65-
if random_state is None or isinstance(random_state, int):
66-
if chains == 1 and isinstance(random_state, int):
67-
return (random_state,)
68-
return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)
69-
if isinstance(random_state, np.random.Generator):
70-
return _get_unique_seeds_per_chain(random_state.integers)
71-
if isinstance(random_state, np.random.RandomState):
72-
return _get_unique_seeds_per_chain(random_state.randint)
73-
74-
if not isinstance(random_state, (list, tuple, np.ndarray)):
75-
raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")
76-
77-
if len(random_state) != chains:
78-
raise ValueError(
79-
f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})."
80-
)
81-
82-
return random_state

pymc/tests/test_sampling_utils.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

pymc/tests/test_util.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import re
1415

1516
import arviz
1617
import numpy as np
@@ -24,6 +25,7 @@
2425
from pymc.distributions.transforms import RVTransform
2526
from pymc.util import (
2627
UNSET,
28+
_get_seeds_per_chain,
2729
dataset_to_point_list,
2830
drop_warning_stat,
2931
hash_key,
@@ -198,3 +200,39 @@ def test_drop_warning_stat():
198200
assert "a" in ss
199201
assert "warning" not in ss
200202
assert "warning_dim_0" not in ss
203+
204+
205+
def test_get_seeds_per_chain():
206+
ret = _get_seeds_per_chain(None, chains=1)
207+
assert len(ret) == 1 and isinstance(ret[0], int)
208+
209+
ret = _get_seeds_per_chain(None, chains=2)
210+
assert len(ret) == 2 and isinstance(ret[0], int)
211+
212+
ret = _get_seeds_per_chain(5, chains=1)
213+
assert ret == (5,)
214+
215+
ret = _get_seeds_per_chain(5, chains=3)
216+
assert len(ret) == 3 and isinstance(ret[0], int) and not any(r == 5 for r in ret)
217+
218+
rng = np.random.default_rng(123)
219+
expected_ret = rng.integers(2**30, dtype=np.int64, size=1)
220+
rng = np.random.default_rng(123)
221+
ret = _get_seeds_per_chain(rng, chains=1)
222+
assert ret == expected_ret
223+
224+
rng = np.random.RandomState(456)
225+
expected_ret = rng.randint(2**30, dtype=np.int64, size=2)
226+
rng = np.random.RandomState(456)
227+
ret = _get_seeds_per_chain(rng, chains=2)
228+
assert np.all(ret == expected_ret)
229+
230+
for expected_ret in ([0, 1, 2], (0, 1, 2, 3), np.arange(5)):
231+
ret = _get_seeds_per_chain(expected_ret, chains=len(expected_ret))
232+
assert ret is expected_ret
233+
234+
with pytest.raises(ValueError, match="does not match the number of chains"):
235+
_get_seeds_per_chain(expected_ret, chains=len(expected_ret) + 1)
236+
237+
with pytest.raises(ValueError, match=re.escape("The `seeds` must be array-like")):
238+
_get_seeds_per_chain({1: 1, 2: 2}, 2)

pymc/util.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import functools
1616

17-
from typing import Any, Dict, List, Tuple, Union, cast
17+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
1818

1919
import arviz
2020
import cloudpickle
@@ -387,3 +387,57 @@ def wrapped(**kwargs):
387387
return core_function(**input_point)
388388

389389
return wrapped
390+
391+
392+
RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]]
393+
RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator]
394+
395+
396+
def _get_seeds_per_chain(
397+
random_state: RandomState,
398+
chains: int,
399+
) -> Union[Sequence[int], np.ndarray]:
400+
"""Obtain or validate specified integer seeds per chain.
401+
402+
This function processes different possible sources of seeding and returns one integer
403+
seed per chain:
404+
1. If the input is an integer and a single chain is requested, the input is
405+
returned inside a tuple.
406+
2. If the input is a sequence or NumPy array with as many entries as chains,
407+
the input is returned.
408+
3. If the input is an integer and multiple chains are requested, new unique seeds
409+
are generated from NumPy default Generator seeded with that integer.
410+
4. If the input is None, new unique seeds are generated from an unseeded NumPy default
411+
Generator.
412+
5. If a RandomState or Generator is provided, new unique seeds are generated from it.
413+
414+
Raises
415+
------
416+
ValueError
417+
If none of the conditions above are met
418+
"""
419+
420+
def _get_unique_seeds_per_chain(integers_fn):
421+
seeds = []
422+
while len(set(seeds)) != chains:
423+
seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
424+
return seeds
425+
426+
if random_state is None or isinstance(random_state, int):
427+
if chains == 1 and isinstance(random_state, int):
428+
return (random_state,)
429+
return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)
430+
if isinstance(random_state, np.random.Generator):
431+
return _get_unique_seeds_per_chain(random_state.integers)
432+
if isinstance(random_state, np.random.RandomState):
433+
return _get_unique_seeds_per_chain(random_state.randint)
434+
435+
if not isinstance(random_state, (list, tuple, np.ndarray)):
436+
raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")
437+
438+
if len(random_state) != chains:
439+
raise ValueError(
440+
f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})."
441+
)
442+
443+
return random_state

pymc/variational/opvi.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,12 @@
7070
from pymc.blocking import DictToArrayBijection
7171
from pymc.initial_point import make_initial_point_fn
7272
from pymc.model import modelcontext
73-
from pymc.sampling_utils import RandomState, _get_seeds_per_chain
74-
from pymc.util import WithMemoization, locally_cachedmethod
73+
from pymc.util import (
74+
RandomState,
75+
WithMemoization,
76+
_get_seeds_per_chain,
77+
locally_cachedmethod,
78+
)
7579
from pymc.variational.updates import adagrad_window
7680
from pymc.vartypes import discrete_types
7781

0 commit comments

Comments
 (0)