Skip to content

Use _get_seeds_per_chain utility in smc/sampling.py #6258

@ricardoV94

Description

@ricardoV94

pymc/pymc/smc/sampling.py

Lines 186 to 199 in bcffce2

if random_seed == -1:
raise FutureWarning(
f"random_seed should be a non-negative integer or None, got: {random_seed}"
"This will raise a ValueError in the Future"
)
random_seed = None
if isinstance(random_seed, int) or random_seed is None:
rng = np.random.default_rng(seed=random_seed)
random_seed = list(rng.integers(2**30, size=chains))
elif isinstance(random_seed, Iterable):
if len(random_seed) != chains:
raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
else:
raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")

This logic is better handled by the newer `_get_seeds_per_chain` utility; the type hints for `random_seed` can also be updated.

It's currently located here:

pymc/pymc/sampling.py

Lines 258 to 305 in e57d1d7

def _get_seeds_per_chain(
random_state: RandomState,
chains: int,
) -> Union[Sequence[int], np.ndarray]:
"""Obtain or validate specified integer seeds per chain.
This function process different possible sources of seeding and returns one integer
seed per chain:
1. If the input is an integer and a single chain is requested, the input is
returned inside a tuple.
2. If the input is a sequence or NumPy array with as many entries as chains,
the input is returned.
3. If the input is an integer and multiple chains are requested, new unique seeds
are generated from NumPy default Generator seeded with that integer.
4. If the input is None new unique seeds are generated from an unseeded NumPy default
Generator.
5. If a RandomState or Generator is provided, new unique seeds are generated from it.
Raises
------
ValueError
If none of the conditions above are met
"""
def _get_unique_seeds_per_chain(integers_fn):
seeds = []
while len(set(seeds)) != chains:
seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
return seeds
if random_state is None or isinstance(random_state, int):
if chains == 1 and isinstance(random_state, int):
return (random_state,)
return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)
if isinstance(random_state, np.random.Generator):
return _get_unique_seeds_per_chain(random_state.integers)
if isinstance(random_state, np.random.RandomState):
return _get_unique_seeds_per_chain(random_state.randint)
if not isinstance(random_state, (list, tuple, np.ndarray)):
raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")
if len(random_state) != chains:
raise ValueError(
f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})."
)
return random_state

But it will likely move after #6257.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions