Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 7 additions & 18 deletions pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import warnings

from collections import defaultdict
from collections.abc import Iterable
from itertools import repeat

import cloudpickle
Expand All @@ -34,6 +33,7 @@
from pymc.model import modelcontext
from pymc.sampling.parallel import _cpu_count
from pymc.smc.kernels import IMH
from pymc.util import RandomState, _get_seeds_per_chain


def sample_smc(
Expand All @@ -42,7 +42,7 @@ def sample_smc(
*,
start=None,
model=None,
random_seed=None,
random_seed: RandomState = None,
chains=None,
cores=None,
compute_convergence_checks=True,
Expand All @@ -64,8 +64,10 @@ def sample_smc(
Starting point in parameter space. It should be a list of dict with length `chains`.
When None (default) the starting point is sampled from the prior distribution.
model: Model (optional if in ``with`` context)).
random_seed: int
random seed
random_seed : int, array-like of int, RandomState or Generator, optional
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
is passed, each entry will be used to seed each chain. A ValueError will be
raised if the length does not match the number of chains.
chains : int
The number of chains to sample. Running independent chains is important for some
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
Expand Down Expand Up @@ -183,20 +185,7 @@ def sample_smc(
else:
cores = min(chains, cores)

if random_seed == -1:
raise FutureWarning(
f"random_seed should be a non-negative integer or None, got: {random_seed}"
"This will raise a ValueError in the Future"
)
random_seed = None
if isinstance(random_seed, int) or random_seed is None:
rng = np.random.default_rng(seed=random_seed)
random_seed = list(rng.integers(2**30, size=chains))
elif isinstance(random_seed, Iterable):
if len(random_seed) != chains:
raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
else:
raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)

model = modelcontext(model)

Expand Down