-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Closed
Labels
Description
Lines 186 to 199 in bcffce2
```python
if random_seed == -1:
    raise FutureWarning(
        f"random_seed should be a non-negative integer or None, got: {random_seed}"
        "This will raise a ValueError in the Future"
    )
    random_seed = None
if isinstance(random_seed, int) or random_seed is None:
    rng = np.random.default_rng(seed=random_seed)
    random_seed = list(rng.integers(2**30, size=chains))
elif isinstance(random_seed, Iterable):
    if len(random_seed) != chains:
        raise ValueError(f"Length of seeds ({len(seeds)}) must match number of chains {chains}")
else:
    raise TypeError("Invalid value for `random_seed`. Must be tuple, list, int or None")
```
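This logic would be better handled by the newer `_get_seeds_per_chain` utility, and the type hints for `random_seed` can be updated as well. For example, the snippet above raises `FutureWarning` as an exception instead of emitting it as a warning, which makes the subsequent `random_seed = None` unreachable. Its length-mismatch error message also refers to `seeds`, a name the snippet never defines.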
It is currently located here:
Lines 258 to 305 in e57d1d7
```python
def _get_seeds_per_chain(
    random_state: RandomState,
    chains: int,
) -> Union[Sequence[int], np.ndarray]:
    """Obtain or validate specified integer seeds per chain.
    This function process different possible sources of seeding and returns one integer
    seed per chain:
    1. If the input is an integer and a single chain is requested, the input is
       returned inside a tuple.
    2. If the input is a sequence or NumPy array with as many entries as chains,
       the input is returned.
    3. If the input is an integer and multiple chains are requested, new unique seeds
       are generated from NumPy default Generator seeded with that integer.
    4. If the input is None new unique seeds are generated from an unseeded NumPy default
       Generator.
    5. If a RandomState or Generator is provided, new unique seeds are generated from it.
    Raises
    ------
    ValueError
        If none of the conditions above are met
    """
    def _get_unique_seeds_per_chain(integers_fn):
        seeds = []
        while len(set(seeds)) != chains:
            seeds = [int(seed) for seed in integers_fn(2**30, dtype=np.int64, size=chains)]
        return seeds
    if random_state is None or isinstance(random_state, int):
        if chains == 1 and isinstance(random_state, int):
            return (random_state,)
        return _get_unique_seeds_per_chain(np.random.default_rng(random_state).integers)
    if isinstance(random_state, np.random.Generator):
        return _get_unique_seeds_per_chain(random_state.integers)
    if isinstance(random_state, np.random.RandomState):
        return _get_unique_seeds_per_chain(random_state.randint)
    if not isinstance(random_state, (list, tuple, np.ndarray)):
        raise ValueError(f"The `seeds` must be array-like. Got {type(random_state)} instead.")
    if len(random_state) != chains:
        raise ValueError(
            f"Number of seeds ({len(random_state)}) does not match the number of chains ({chains})."
        )
    return random_state
```
However, it will likely move after #6257.