Skip to content

Commit f8e8309

Browse files
author
juanitorduz
committed
improve random seed processing
1 parent 5d7283e commit f8e8309

File tree

1 file changed

+2
-14
lines changed

1 file changed

+2
-14
lines changed

pymc/smc/sampling.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pymc.model import modelcontext
3535
from pymc.sampling.parallel import _cpu_count
3636
from pymc.smc.kernels import IMH
37+
from pymc.util import _get_seeds_per_chain
3738

3839

3940
def sample_smc(
@@ -183,20 +184,7 @@ def sample_smc(
183184
else:
184185
cores = min(chains, cores)
185186

186-
if random_seed == -1:
187-
raise FutureWarning(
188-
f"random_seed should be a non-negative integer or None, got: {random_seed}"
189-
"This will raise a ValueError in the Future"
190-
)
191-
random_seed = None
192-
if isinstance(random_seed, int) or random_seed is None:
193-
rng = np.random.default_rng(seed=random_seed)
194-
random_seed = list(rng.integers(2**30, size=chains))
195-
elif isinstance(random_seed, Iterable):
196-
if len(random_seed) != chains:
197-
raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
198-
else:
199-
raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")
187+
random_seed = _get_seeds_per_chain(random_state=random_seed, chains=chains)
200188

201189
model = modelcontext(model)
202190

0 commit comments

Comments
 (0)