1818import warnings
1919
2020from collections import defaultdict
21- from collections .abc import Iterable
2221from itertools import repeat
2322
2423import cloudpickle
3433from pymc .model import modelcontext
3534from pymc .sampling .parallel import _cpu_count
3635from pymc .smc .kernels import IMH
36+ from pymc .util import RandomState , _get_seeds_per_chain
3737
3838
3939def sample_smc (
@@ -42,7 +42,7 @@ def sample_smc(
4242 * ,
4343 start = None ,
4444 model = None ,
45- random_seed = None ,
45+ random_seed : RandomState = None ,
4646 chains = None ,
4747 cores = None ,
4848 compute_convergence_checks = True ,
@@ -64,8 +64,10 @@ def sample_smc(
6464 Starting point in parameter space. It should be a list of dict with length `chains`.
6565 When None (default) the starting point is sampled from the prior distribution.
6666 model: Model (optional if in ``with`` context)).
67- random_seed: int
68- random seed
67+ random_seed : int, array-like of int, RandomState or Generator, optional
68+ Random seed(s) used by the sampling steps. If a list, tuple or array of ints
69+ is passed, each entry will be used to seed each chain. A ValueError will be
70+ raised if the length does not match the number of chains.
6971 chains : int
7072 The number of chains to sample. Running independent chains is important for some
7173 convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
@@ -183,20 +185,7 @@ def sample_smc(
183185 else :
184186 cores = min (chains , cores )
185187
186- if random_seed == - 1 :
187- raise FutureWarning (
188- f"random_seed should be a non-negative integer or None, got: { random_seed } "
189- "This will raise a ValueError in the Future"
190- )
191- random_seed = None
192- if isinstance (random_seed , int ) or random_seed is None :
193- rng = np .random .default_rng (seed = random_seed )
194- random_seed = list (rng .integers (2 ** 30 , size = chains ))
195- elif isinstance (random_seed , Iterable ):
196- if len (random_seed ) != chains :
197- raise ValueError (f"Length of seeds ({ len (seeds )} ) must match number of chains { chains } " )
198- else :
199- raise TypeError ("Invalid value for `random_seed`. Must be tuple, list, int or None" )
188+ random_seed = _get_seeds_per_chain (random_state = random_seed , chains = chains )
200189
201190 model = modelcontext (model )
202191
0 commit comments