Skip to content

Commit c25ab22

Browse files
author
juanitorduz
committed
improve type-hint
1 parent e8f11e0 commit c25ab22

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pymc/smc/sampling.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from pymc.model import modelcontext
3434
from pymc.sampling.parallel import _cpu_count
3535
from pymc.smc.kernels import IMH
36-
from pymc.util import _get_seeds_per_chain
36+
from pymc.util import RandomState, _get_seeds_per_chain
3737

3838

3939
def sample_smc(
@@ -42,7 +42,7 @@ def sample_smc(
4242
*,
4343
start=None,
4444
model=None,
45-
random_seed=None,
45+
random_seed: RandomState = None,
4646
chains=None,
4747
cores=None,
4848
compute_convergence_checks=True,
@@ -64,8 +64,10 @@ def sample_smc(
6464
Starting point in parameter space. It should be a list of dict with length `chains`.
6565
When None (default) the starting point is sampled from the prior distribution.
6666
model: Model (optional if in ``with`` context)).
67-
random_seed: int
68-
random seed
67+
random_seed : int, array-like of int, RandomState or Generator, optional
68+
Random seed(s) used by the sampling steps. If a list, tuple or array of ints
69+
is passed, each entry will be used to seed each chain. A ValueError will be
70+
raised if the length does not match the number of chains.
6971
chains : int
7072
The number of chains to sample. Running independent chains is important for some
7173
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever

0 commit comments

Comments
 (0)