diff --git a/pymc/smc/sampling.py b/pymc/smc/sampling.py index ddeed25a16..19f783610d 100644 --- a/pymc/smc/sampling.py +++ b/pymc/smc/sampling.py @@ -18,7 +18,6 @@ import warnings from collections import defaultdict -from collections.abc import Iterable from itertools import repeat import cloudpickle @@ -34,6 +33,7 @@ from pymc.model import modelcontext from pymc.sampling.parallel import _cpu_count from pymc.smc.kernels import IMH +from pymc.util import RandomState, _get_seeds_per_chain def sample_smc( @@ -42,7 +42,7 @@ def sample_smc( *, start=None, model=None, - random_seed=None, + random_seed: RandomState = None, chains=None, cores=None, compute_convergence_checks=True, @@ -64,8 +64,10 @@ def sample_smc( Starting point in parameter space. It should be a list of dict with length `chains`. When None (default) the starting point is sampled from the prior distribution. model: Model (optional if in ``with`` context)). - random_seed: int - random seed + random_seed : int, array-like of int, RandomState or Generator, optional + Random seed(s) used by the sampling steps. If a list, tuple or array of ints + is passed, each entry will be used to seed each chain. A ValueError will be + raised if the length does not match the number of chains. chains : int The number of chains to sample. Running independent chains is important for some convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever @@ -183,20 +185,7 @@ def sample_smc( else: cores = min(chains, cores) - if random_seed == -1: - raise FutureWarning( - f"random_seed should be a non-negative integer or None, got: {random_seed}" - "This will raise a ValueError in the Future" - ) - random_seed = None - if isinstance(random_seed, int) or random_seed is None: - rng = np.random.default_rng(seed=random_seed) - random_seed = list(rng.integers(2**30, size=chains)) - elif isinstance(random_seed, Iterable): - if len(random_seed) != chains: - raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}") - else: - raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None") + random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains) model = modelcontext(model)