|
14 | 14 |
|
15 | 15 | import functools |
16 | 16 |
|
17 | | -from typing import Any, Dict, List, Tuple, Union, cast |
| 17 | +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast |
18 | 18 |
|
19 | 19 | import arviz |
20 | 20 | import cloudpickle |
@@ -387,3 +387,57 @@ def wrapped(**kwargs): |
387 | 387 | return core_function(**input_point) |
388 | 388 |
|
389 | 389 | return wrapped |
| 390 | + |
| 391 | + |
| 392 | +RandomSeed = Optional[Union[int, Sequence[int], np.ndarray]] |
| 393 | +RandomState = Union[RandomSeed, np.random.RandomState, np.random.Generator] |
| 394 | + |
| 395 | + |
| 396 | +def _get_seeds_per_chain( |
| 397 | + random_state: RandomState, |
| 398 | + chains: int, |
| 399 | +) -> Union[Sequence[int], np.ndarray]: |
| 400 | + """Obtain or validate specified integer seeds per chain. |
| 401 | +
|
| 402 | + This function process different possible sources of seeding and returns one integer |
| 403 | + seed per chain: |
| 404 | + 1. If the input is an integer and a single chain is requested, the input is |
| 405 | + returned inside a tuple. |
| 406 | + 2. If the input is a sequence or NumPy array with as many entries as chains, |
| 407 | + the input is returned. |
| 408 | + 3. If the input is an integer and multiple chains are requested, new unique seeds |
| 409 | + are generated from NumPy default Generator seeded with that integer. |
| 410 | + 4. If the input is None new unique seeds are generated from an unseeded NumPy default |
| 411 | + Generator. |
| 412 | + 5. If a RandomState or Generator is provided, new unique seeds are generated from it. |
| 413 | +
|
| 414 | + Raises |
| 415 | + ------ |
| 416 | + ValueError |
| 417 | + If none of the conditions above are met |
| 418 | + """ |
| 419 | + |
| 420 | + def _get_unique_seeds_per_chain(integers_fn): |
| 421 | + seeds = [] |
| 422 | + while len(set(seeds)) != chains: |
| 423 | + seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)] |
| 424 | + return seeds |
| 425 | + |
| 426 | + if random_state is None or isinstance(random_state, int): |
| 427 | + if chains == 1 and isinstance(random_state, int): |
| 428 | + return (random_state,) |
| 429 | + return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers) |
| 430 | + if isinstance(random_state, np.random.Generator): |
| 431 | + return _get_unique_seeds_per_chain(random_state.integers) |
| 432 | + if isinstance(random_state, np.random.RandomState): |
| 433 | + return _get_unique_seeds_per_chain(random_state.randint) |
| 434 | + |
| 435 | + if not isinstance(random_state, (list, tuple, np.ndarray)): |
| 436 | + raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.") |
| 437 | + |
| 438 | + if len(random_state) != chains: |
| 439 | + raise ValueError( |
| 440 | + f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})." |
| 441 | + ) |
| 442 | + |
| 443 | + return random_state |
0 commit comments